From 9931731a31962d4e22810d158d0ad5ceb9193fd6 Mon Sep 17 00:00:00 2001 From: wasamtc <81901970+wasamtc@users.noreply.github.com> Date: Wed, 14 Jan 2026 09:24:43 +0800 Subject: [PATCH 01/50] Disable Docker image build workflow The workflow has been disabled, preventing any actions from being performed. --- .github/workflows/build-spark-image.yaml | 45 ++++-------------------- 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/.github/workflows/build-spark-image.yaml b/.github/workflows/build-spark-image.yaml index 0d3b60d6..f58b3f73 100644 --- a/.github/workflows/build-spark-image.yaml +++ b/.github/workflows/build-spark-image.yaml @@ -5,7 +5,6 @@ on: - cron: '0 3 * * *' workflow_dispatch: - env: IMAGE_NAME: parallax NAMESPACE: gradientservice @@ -23,41 +22,9 @@ jobs: variant: [spark] steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - with: - driver: remote - endpoint: tcp://buildkit-buildkit-service.arc-systems:1234 - - - - name: Log in to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Extract metadata - id: meta - uses: docker/metadata-action@v5 - with: - images: ${{ env.NAMESPACE }}/${{ env.IMAGE_NAME }} - tags: | - type=ref,event=branch,suffix=-${{ matrix.variant }} - type=ref,event=pr,suffix=-${{ matrix.variant }} - type=raw,value=latest-${{ matrix.variant }},enable={{is_default_branch}} - - - name: Build and push Docker image - id: build - uses: docker/build-push-action@v5 - with: - context: . - file: ./docker/Dockerfile.${{ matrix.variant }} - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - cache-from: type=gha - cache-to: type=gha,mode=max - platforms: linux/arm64 + - name: Workflow disabled – no actions performed + run: | + echo "==================================================" + echo "This workflow has been DISABLED." + echo "No Docker images are built or pushed." + echo "==================================================" From dd067d6f9a53b5e118d5d440e521339677d7330a Mon Sep 17 00:00:00 2001 From: wasamtc <81901970+wasamtc@users.noreply.github.com> Date: Mon, 19 Jan 2026 09:48:49 +0800 Subject: [PATCH 02/50] Update build-spark-image.yaml --- .github/workflows/build-spark-image.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-spark-image.yaml b/.github/workflows/build-spark-image.yaml index f58b3f73..1474995d 100644 --- a/.github/workflows/build-spark-image.yaml +++ b/.github/workflows/build-spark-image.yaml @@ -1,9 +1,9 @@ name: Build Spark Image on: - schedule: - - cron: '0 3 * * *' - workflow_dispatch: + # schedule: # 注释掉或删除此部分以停止自动触发 + # - cron: '0 3 * * *' + workflow_dispatch: # 仅保留手动触发 env: IMAGE_NAME: parallax From 89697a7120f9b4d5d99691e98f7e9378846fb721 Mon Sep 17 00:00:00 2001 From: wasamtc <81901970+wasamtc@users.noreply.github.com> Date: Mon, 19 Jan 2026 09:50:02 +0800 Subject: [PATCH 03/50] Update build-images.yaml --- .github/workflows/build-images.yaml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build-images.yaml b/.github/workflows/build-images.yaml index 64844408..751dc8ad 100644 --- a/.github/workflows/build-images.yaml +++ b/.github/workflows/build-images.yaml @@ -1,10 +1,9 @@ name: Build Images on: - schedule: - - cron: '0 0 * * *' - workflow_dispatch: - + # schedule: # 注释掉定时任务以停止自动构建 + # - cron: '0 0 * * *' + workflow_dispatch: # 保留手动触发开关 env: IMAGE_NAME: parallax @@ -29,7 +28,6 @@ jobs: driver: remote endpoint: tcp://buildkit-buildkit-service.arc-systems:1234 - - name: Log in to Docker Hub uses: docker/login-action@v3 with: From fe7a4c8238962e9d45e659323ed2c0935b612f93 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Thu, 29 Jan 2026 10:38:41 +0000 Subject: [PATCH 04/50] base gpu version:single node can run successly --- src/parallax/server/executor/factory.py | 1 + .../server/executor/sglang_executor.py | 143 +++++++++++++++- src/parallax/server/server_args.py | 8 + src/parallax/sglang/batch_info.py | 161 +++++++++++++----- 4 files changed, 261 insertions(+), 52 deletions(-) diff --git a/src/parallax/server/executor/factory.py b/src/parallax/server/executor/factory.py index 1a366cd8..bc530538 100755 --- a/src/parallax/server/executor/factory.py +++ b/src/parallax/server/executor/factory.py @@ -54,6 +54,7 @@ def create_executor_config(args: argparse.Namespace, shared_state=None, conn=Non "max_lora_chunk_size": args.max_lora_chunk_size, "enable_weight_refit": args.enable_weight_refit, "weight_refit_mode": args.weight_refit_mode, + "chunked_prefill_size": args.chunked_prefill_size, } return config diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index e4a7996f..e855e9c0 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -12,6 +12,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache from sglang.srt.model_executor.forward_batch_info import PPProxyTensors from sglang.srt.utils.common import SUPPORTED_LORA_TARGET_MODULES +from sglang.srt.managers.schedule_batch import Req from parallax.server.executor.base_executor import BaseExecutor from parallax.server.request import ( @@ -91,6 +92,8 @@ def __init__( weight_refit_mode: Optional[str] = "disk", # Pipe communication conn: Optional[List[Any]] = [], + chunked_prefill_size: Optional[int] = None, + max_prefill_tokens: Optional[int] = 4 * 1024, ): self.enable_lora = True if lora_paths is not None else enable_lora @@ -102,7 +105,15 @@ def __init__( self.lora_eviction_policy = lora_eviction_policy self.lora_backend = lora_backend self.max_lora_chunk_size = max_lora_chunk_size - + self.chunked_prefill_size = chunked_prefill_size + self.max_prefill_tokens = max_prefill_tokens + # Chunked prefill must be at least one page for correct KV/prefix cache alignment. + if self.chunked_prefill_size is not None and self.chunked_prefill_size < kv_block_size: + logger.info( + f"chunked_prefill_size {self.chunked_prefill_size} < page_size (kv_block_size={kv_block_size}); " + f"clamping to {kv_block_size}" + ) + self.chunked_prefill_size = kv_block_size if self.lora_paths is not None and len(self.lora_paths) > 0: self.check_lora_server_args() @@ -136,6 +147,7 @@ def __init__( "lora_eviction_policy": self.lora_eviction_policy, "lora_backend": self.lora_backend, "max_lora_chunk_size": self.max_lora_chunk_size, + "chunked_prefill_size": self.chunked_prefill_size, } logger.debug( f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" @@ -182,6 +194,11 @@ def __init__( self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) self.tp_group = self.model_runner.tp_group self.tp_cpu_group = self.tp_group.cpu_group + + # add chunked_prefill_size and chunked_req + if chunked_prefill_size is not None and chunked_prefill_size <= 0: + chunked_prefill_size = None + self.chunked_req = None # create a page tree cache for sglang prefill if enable_prefix_cache: @@ -331,11 +348,20 @@ def handle_input_requests(self, requests: List[Request]): # If it's an abort signal (e.g. from OOM), next_token_id might be None or dummy if not req.abort and req.next_token_id is not None: - original_req.commit_new_token(req.next_token_id) - logger.debug( - f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, " - f"output_ids now has {len(original_req.output_ids)} tokens" - ) + # Keep chunked_req in PREFILLING so next round it goes to prefill again. + if ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + ): + original_req.status = RequestStatus.PREFILLING + else: + original_req.commit_new_token(req.next_token_id) + logger.debug( + f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, " + f"output_ids now has {len(original_req.output_ids)} tokens" + ) + if len(req.routing_table) > 0: original_req.routing_table = req.routing_table @@ -344,7 +370,18 @@ def handle_input_requests(self, requests: List[Request]): if req.abort: original_req.abort = True - if self.scheduler.check_and_update_request_status(original_req): + # Chunked req is held by executor.chunked_req and enters next prefill via + # add_chunked_req; do not re-enqueue until all chunks are done. + if ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + ): + self.chunked_req.is_chunked -= 1 + self.scheduler.enque_request(original_req) + continue + + elif self.scheduler.check_and_update_request_status(original_req): logger.debug(f"Releasing resources for finished request {req.request_id}") self.release_and_evict_request(req.request_id) if not self.is_last_peer and not req.abort: @@ -390,6 +427,15 @@ def handle_input_requests(self, requests: List[Request]): else: # This is an active request, add it to the scheduler queue to be processed. self.scheduler.enque_request(req) + + def stash_chunked_request(self, req: Req): + self.page_tree_cache.cache_unfinished_req(req, chunked=True) + # FIX: below code don't now if it is needed + # # Chunked request keeps its rid but will get a new req_pool_idx + # if self.model_runner.mambaish_config is not None: + # self.model_runner.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=False) + # else: + # self.model_runner.req_to_token_pool.free(req.req_pool_idx) def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True): """Process a batch of requests in SGLang.""" @@ -408,9 +454,37 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: pp_proxy_tensors=pp_proxy_tensors, ) logits_output = out.logits_output + # Merge prefill batch into running batch + chunked_req_to_exclude = set() + + if self.chunked_req is not None: + # Move the chunked request out of the batch so that we can merge + # only finished requests to running_batch. + chunked_req_to_exclude.add(self.chunked_req) + self.stash_chunked_request(self.chunked_req) + + + if self.cur_batch and self.cur_batch.forward_mode.is_extend(): + if self.cur_batch.chunked_req is not None: + chunked_req_to_exclude.add(self.cur_batch.chunked_req) + if self.cur_batch: + # Avoid duplicate rids in running_batch: exclude any cur_batch req + # whose rid is already in running_batch (filter_batch uses object identity). + if not self.running_batch.is_empty(): + existing_rids = {req.rid for req in self.running_batch.reqs} + for req in list(self.cur_batch.reqs): + if req.rid in existing_rids: + chunked_req_to_exclude.add(req) + cur_batch_size = self.cur_batch.batch_size() + self.cur_batch.filter_batch( + chunked_req_to_exclude=list[Req](chunked_req_to_exclude) + ) + + if self.cur_batch.batch_size() < cur_batch_size: + self.cur_batch.batch_is_full = False if self.cur_batch.forward_mode.is_extend(): # Merge the new batch into the running batch if not self.cur_batch.is_empty(): @@ -423,8 +497,39 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Return appropriate output based on peer position if return_decoded_tokens: + # Debug: log running_batch vs prepared_inputs["requests"] to detect + # "running_batch still has previous request" causing token mis-assignment + running_batch_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) + running_batch_rids = ( + [] if self.running_batch.is_empty() else [req.rid for req in self.running_batch.reqs] + ) + requests_len = len(requests) + requests_ids = [getattr(r, "request_id", str(r)) for r in requests] + logger.debug( + "[ChunkedPrefill-Debug] process_batch decode: running_batch size=%s rids=%s, " + "prepared_inputs requests len=%s request_ids=%s", + running_batch_size, + running_batch_rids, + requests_len, + requests_ids, + ) + if running_batch_size != requests_len: + logger.warning( + "[ChunkedPrefill-Debug] MISMATCH: running_batch has %s reqs but prepared_inputs " + "has %s requests; token indices may be assigned to wrong request. " + "running_batch_rids=%s, request_ids=%s", + running_batch_size, + requests_len, + running_batch_rids, + requests_ids, + ) + # Last peer: sample and return token IDs next_token_ids = self.model_runner.sample(logits_output, forward_batch) + logger.debug( + "[ChunkedPrefill-Debug] process_batch after sample: len(next_token_ids)=%s", + len(next_token_ids), + ) # Only compute probs if any request in the batch needs it # Check if any InitialRequest has return_probs=True @@ -456,7 +561,28 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: def _release_request(self, rid: str): """Release per-request resources in SGLang.""" try: + # Debug: log running_batch before release to verify request eviction + before_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) + before_rids = ( + [] if self.running_batch.is_empty() else [req.rid for req in self.running_batch.reqs] + ) + logger.debug( + "[ChunkedPrefill-Debug] _release_request BEFORE: releasing rid=%s, " + "running_batch size=%s rids=%s", + rid, + before_size, + before_rids, + ) release_sglang_request(self.running_batch, rid) + after_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) + after_rids = ( + [] if self.running_batch.is_empty() else [req.rid for req in self.running_batch.reqs] + ) + logger.debug( + "[ChunkedPrefill-Debug] _release_request AFTER: running_batch size=%s rids=%s", + after_size, + after_rids, + ) except Exception: pass @@ -599,8 +725,7 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A schedule_batch, forward_batch = form_sgl_batch_prefill( batched_requests, - self.model_runner, - self.page_tree_cache, + self, ) self.cur_batch = schedule_batch diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 85e6350c..097fadca 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -108,6 +108,14 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--enable-prefix-cache", action="store_true", help="Enable prefix cache reuse" ) + + # add --chunked-prefill-size + parser.add_argument( + "--chunked-prefill-size", + type=int, + default=None, + help="Chunk size for chunked prefill processing", + ) # Scheduler configuration parser.add_argument( diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 56720040..bb17a2be 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -5,11 +5,15 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch """ +from __future__ import annotations + from types import SimpleNamespace -from typing import List, Optional +from typing import List, Optional, TYPE_CHECKING +from sglang.srt.managers.schedule_policy import AddReqResult, PrefillAdder import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import ( @@ -18,13 +22,15 @@ from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from parallax.server.executor.sglang_executor import PageRadixCache from parallax.server.request import Request from parallax.server.sampling.sampling_params import ( SamplingParams as ParallaxSamplingParams, ) from parallax_utils.logging_config import get_logger +if TYPE_CHECKING: + from parallax.server.executor.sglang_executor import SGLExecutor + logger = get_logger(__name__) @@ -48,11 +54,41 @@ def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> S def transform_requests_to_sglang( - old_requests: List[Request], page_tree_cache: Optional[PageRadixCache] = None + old_requests: List[Request], + executor: SGLExecutor, ) -> List[Req]: """Transforms Parallax Request to SGLang.Req format""" + model_runner = executor.model_runner + page_tree_cache = executor.page_tree_cache + chunked_prefill_size = executor.chunked_prefill_size + # Prefill policy + adder = PrefillAdder( + model_runner.page_size, + page_tree_cache, + model_runner.token_to_kv_pool_allocator, + None, + None, + executor.max_prefill_tokens, + chunked_prefill_size, + 0, + None, + None, + ) + + if executor.chunked_req is not None: + logger.debug(f"before add_chunked_req, chunked_req is not None") + executor.chunked_req.init_next_round_input(page_tree_cache) + executor.chunked_req = adder.add_chunked_req(executor.chunked_req) + if executor.chunked_req is None: + logger.debug(f"after add_chunked_req, chunked_req is None") + + reqs = [] + logger.debug(f"old_req size: {len(old_requests)}") for old_req in old_requests: + # Chunked req is added via add_chunked_req above; skip to avoid double-add. + if executor.chunked_req is not None and old_req.request_id == executor.chunked_req.rid: + continue sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) req = Req( rid=old_req.request_id, @@ -71,6 +107,18 @@ def transform_requests_to_sglang( ) req.init_next_round_input(page_tree_cache) + + res = adder.add_one_req( + req, executor.chunked_req is not None, None, + ) + + if res != AddReqResult.CONTINUE: + logger.warning(f"Request {old_req.request_id} failed to add to prefill batch, result: {res},\ + req_len: {len(req.origin_input_ids)}") + if res == AddReqResult.NO_TOKEN: + logger.warning(f"there is no token to add to prefill batch") + executor.running_batch.batch_is_full = True + break # Debug: Log after cache lookup if page_tree_cache is not None: @@ -85,17 +133,31 @@ def transform_requests_to_sglang( ) reqs.append(req) - return reqs + + logger.debug(f"new reqs size: {len(reqs)}") + + if adder.new_chunked_req is not None: + # Update chunked prefill + assert executor.chunked_req is None + executor.chunked_req = adder.new_chunked_req + logger.debug(f"new chunked_req is {executor.chunked_req}") + + if executor.chunked_req is not None: + executor.chunked_req.is_chunked += 1 + + return adder.can_run_list def form_sgl_batch_prefill( requests: List[Request], - model_runner: ModelRunner, - page_tree_cache: Optional[PageRadixCache] = None, + executor: SGLExecutor, ) -> ForwardBatch: """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" + model_runner = executor.model_runner + page_tree_cache = executor.page_tree_cache + - sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache) + sgl_reqs = transform_requests_to_sglang(requests, executor) def dummy_evict(*args): pass @@ -115,6 +177,7 @@ def dummy_evict(*args): model_config=model_runner.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, + chunked_req=executor.chunked_req, ) schedule_batch.prepare_for_extend() @@ -199,6 +262,14 @@ def find_index(running_batch: ScheduleBatch, request_id: str): return -1 +def find_index_safe(running_batch: ScheduleBatch, request_id: str) -> int: + """Find first index of request_id in running_batch; return -1 if not found (no log).""" + for index, req in enumerate(running_batch.reqs): + if req.rid == request_id: + return index + return -1 + + def form_sgl_batch_decode( requests: List[Request], model_runner: ModelRunner, @@ -249,42 +320,46 @@ def form_sgl_batch_decode( def release_sglang_request(running_batch: ScheduleBatch, request_id: str): - """Release KV Cache and other resources for finished/aborted requests.""" + """Release KV Cache and other resources for finished/aborted requests. + Removes all entries with the given request_id (handles duplicate rids).""" if running_batch is None or running_batch.is_empty(): return - seq_lens_cpu = running_batch.seq_lens.cpu().numpy() - idx = find_index(running_batch, request_id) - req = running_batch.reqs.pop(idx) - - # use running batch's tree cache to release kv cache - tree_cache = running_batch.tree_cache - - # for completed requests, is_insert=True to insert into prefix cache - # for aborted requests, is_insert=False to not insert into prefix cache - is_insert = True # can be adjusted based on request status - - if isinstance(tree_cache, PageRadixCache): - tree_cache.cache_finished_req(req) - else: - page_size = running_batch.token_to_kv_pool_allocator.page_size - last_uncached_pos = (len(req.prefix_indices) // page_size) * page_size - end_pos = last_uncached_pos + seq_lens_cpu[idx] - running_batch.seq_lens = torch.cat( - (running_batch.seq_lens[:idx], running_batch.seq_lens[idx + 1 :]) - ) - running_batch.seq_lens_cpu = torch.cat( - (running_batch.seq_lens_cpu[:idx], running_batch.seq_lens_cpu[idx + 1 :]) - ) - running_batch.orig_seq_lens = torch.cat( - (running_batch.orig_seq_lens[:idx], running_batch.orig_seq_lens[idx + 1 :]) - ) + while True: + idx = find_index_safe(running_batch, request_id) + if idx < 0: + break + seq_lens_cpu = running_batch.seq_lens.cpu().numpy() + req = running_batch.reqs.pop(idx) + + # use running batch's tree cache to release kv cache + tree_cache = running_batch.tree_cache + + # for completed requests, is_insert=True to insert into prefix cache + # for aborted requests, is_insert=False to not insert into prefix cache + is_insert = True # can be adjusted based on request status + + if isinstance(tree_cache, PageRadixCache): + tree_cache.cache_finished_req(req) + else: + page_size = running_batch.token_to_kv_pool_allocator.page_size + last_uncached_pos = (len(req.prefix_indices) // page_size) * page_size + end_pos = last_uncached_pos + seq_lens_cpu[idx] + running_batch.seq_lens = torch.cat( + (running_batch.seq_lens[:idx], running_batch.seq_lens[idx + 1 :]) + ) + running_batch.seq_lens_cpu = torch.cat( + (running_batch.seq_lens_cpu[:idx], running_batch.seq_lens_cpu[idx + 1 :]) + ) + running_batch.orig_seq_lens = torch.cat( + (running_batch.orig_seq_lens[:idx], running_batch.orig_seq_lens[idx + 1 :]) + ) - # Free kv cache - token_indices = running_batch.req_to_token_pool.req_to_token[req.req_pool_idx][ - last_uncached_pos:end_pos - ] - running_batch.token_to_kv_pool_allocator.free(token_indices) - running_batch.req_to_token_pool.free(req.req_pool_idx) - running_batch.req_pool_indices = torch.cat( - (running_batch.req_pool_indices[:idx], running_batch.req_pool_indices[idx + 1 :]) - ) + # Free kv cache + token_indices = running_batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + last_uncached_pos:end_pos + ] + running_batch.token_to_kv_pool_allocator.free(token_indices) + running_batch.req_to_token_pool.free(req.req_pool_idx) + running_batch.req_pool_indices = torch.cat( + (running_batch.req_pool_indices[:idx], running_batch.req_pool_indices[idx + 1 :]) + ) From d266c14bc20924cb38a0a5e0561adfd5e1e72159 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Fri, 30 Jan 2026 08:35:12 +0000 Subject: [PATCH 05/50] fix[chunked-prefill]:reordered list use origin list --- src/parallax/server/executor/base_executor.py | 51 +++++++---- .../server/executor/sglang_executor.py | 87 +++++++++++++------ src/parallax/server/server_args.py | 2 +- src/parallax/sglang/batch_info.py | 69 +++++++++++---- 4 files changed, 144 insertions(+), 65 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 8770fb62..eaaa7441 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -386,9 +386,15 @@ def prepare_batch_inputs(self, batched_requests: List[Request]) -> Optional[Dict def prepare_next_batch_requests( self, requests: List[Request], batch_output: Any, context_lengths: Any - ) -> List[Request]: + ) -> Tuple[List[Request], List[Request]]: """Prepares a batch of requests for the next stage of the pipeline. + Returns two lists: + - chunked_reqs: requests that still have chunks to complete (e.g. chunked prefill); + these should be handled locally via handle_input_requests only, not sent to next peer. + - to_forward_reqs: requests to forward (single node: handle_input_requests; + multi-node and tp_rank==0: request_to_proto and send to next peer). + Args: requests: List of requests in the batch batch_output: Output from process_batch. Always a dict with: @@ -439,7 +445,10 @@ def prepare_next_batch_requests( ) batched_requests.append(next_req) - return batched_requests + # Default: no chunked reqs; all go to to_forward. + chunked_reqs: List[Request] = [] + to_forward_reqs: List[Request] = batched_requests + return (chunked_reqs, to_forward_reqs) def release_and_evict_request(self, rid: str): """Release per-request resources and evict from scheduler. Best-effort, never raises.""" @@ -547,28 +556,34 @@ def run_loop(self): except Exception: pass # 7. Prepare requests for the next stage in the pipeline - next_batch = self.prepare_next_batch_requests( + chunked_reqs, to_forward_reqs = self.prepare_next_batch_requests( requests=prepared_inputs["requests"], batch_output=output, context_lengths=prepared_inputs.get("context_lengths"), ) # 8. Dispatch to the appropriate destination - if self.is_last_peer and self.is_first_peer: - # Single node: handle locally - self.handle_input_requests(next_batch) - elif self.tp_rank == 0: - # Send output to next peer - self.send_to_peer_socket.send_multipart( - [ - b"forward", - request_to_proto(next_batch, self.device).SerializeToString(), - ] - ) - logger.debug( - f"Processed batch of type {batch_type} with {len(next_batch)} requests " - f"in {(time.time() - start_time) * 1000:.3f} ms" - ) + # Chunked reqs (e.g. chunked prefill not yet done) stay local. + if chunked_reqs: + self.handle_input_requests(chunked_reqs) + if to_forward_reqs: + if self.is_last_peer and self.is_first_peer: + # Single node: handle to_forward locally + self.handle_input_requests(to_forward_reqs) + elif self.tp_rank == 0: + # Send to_forward to next peer (do not send chunked_reqs) + self.send_to_peer_socket.send_multipart( + [ + b"forward", + request_to_proto( + to_forward_reqs, self.device + ).SerializeToString(), + ] + ) + logger.debug( + f"Processed batch of type {batch_type} with {len(to_forward_reqs)} to_forward " + f"(chunked={len(chunked_reqs)}) in {(time.time() - start_time) * 1000:.3f} ms" + ) except Exception as e: logger.exception(f"Error processing batch: {e}") diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index e855e9c0..dbe64dfd 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -7,12 +7,11 @@ import torch from sglang.srt.lora.lora_registry import LoRARef -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache from sglang.srt.model_executor.forward_batch_info import PPProxyTensors from sglang.srt.utils.common import SUPPORTED_LORA_TARGET_MODULES -from sglang.srt.managers.schedule_batch import Req from parallax.server.executor.base_executor import BaseExecutor from parallax.server.request import ( @@ -194,7 +193,7 @@ def __init__( self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) self.tp_group = self.model_runner.tp_group self.tp_cpu_group = self.tp_group.cpu_group - + # add chunked_prefill_size and chunked_req if chunked_prefill_size is not None and chunked_prefill_size <= 0: chunked_prefill_size = None @@ -319,6 +318,45 @@ def check_lora_server_args(self): and (self.max_lora_chunk_size & (self.max_lora_chunk_size - 1)) == 0 ), "--max-lora-chunk-size must be a power of 2 between 16 and 128." + def stash_chunked_request(self, req: Req): + logger.debug(f"req.fill_ids_size: {len(req.fill_ids)}") + # #endregion + if req.req_pool_idx is None: + logger.warning( + "stash_chunked_request: skipping cache_unfinished_req and free because req.req_pool_idx is None (rid=%s, fill_ids_len=%s)", + getattr(req, "rid", None), + len(req.fill_ids), + ) + return + self.page_tree_cache.cache_unfinished_req(req, chunked=True) + # FIX: below code don't now if it is needed + # # Chunked request keeps its rid but will get a new req_pool_idx + if self.model_runner.mambaish_config is not None: + self.model_runner.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=False) + else: + self.model_runner.req_to_token_pool.free(req.req_pool_idx) + + def prepare_next_batch_requests( + self, requests: List[Request], batch_output: Any, context_lengths: Any + ) -> Tuple[List[Request], List[Request]]: + """Split out chunked prefill reqs; set their status to PREFILLING and return (chunked, to_forward).""" + base_chunked, base_to_forward = super().prepare_next_batch_requests( + requests, batch_output, context_lengths + ) + if self.chunked_req is None or self.chunked_req.is_chunked <= 0: + return (base_chunked, base_to_forward) + chunked_rid = self.chunked_req.rid + self.stash_chunked_request(self.chunked_req) + chunked_reqs: List[Request] = [] + to_forward_reqs: List[Request] = [] + for req in base_to_forward: + if req.request_id == chunked_rid: + req.status = RequestStatus.PREFILLING + chunked_reqs.append(req) + else: + to_forward_reqs.append(req) + return (chunked_reqs, to_forward_reqs) + def handle_input_requests(self, requests: List[Request]): """Update requests states and status in scheduler and cache manager.""" if self.tp_size > 1: @@ -361,7 +399,6 @@ def handle_input_requests(self, requests: List[Request]): f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, " f"output_ids now has {len(original_req.output_ids)} tokens" ) - if len(req.routing_table) > 0: original_req.routing_table = req.routing_table @@ -380,7 +417,7 @@ def handle_input_requests(self, requests: List[Request]): self.chunked_req.is_chunked -= 1 self.scheduler.enque_request(original_req) continue - + elif self.scheduler.check_and_update_request_status(original_req): logger.debug(f"Releasing resources for finished request {req.request_id}") self.release_and_evict_request(req.request_id) @@ -427,15 +464,6 @@ def handle_input_requests(self, requests: List[Request]): else: # This is an active request, add it to the scheduler queue to be processed. self.scheduler.enque_request(req) - - def stash_chunked_request(self, req: Req): - self.page_tree_cache.cache_unfinished_req(req, chunked=True) - # FIX: below code don't now if it is needed - # # Chunked request keeps its rid but will get a new req_pool_idx - # if self.model_runner.mambaish_config is not None: - # self.model_runner.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=False) - # else: - # self.model_runner.req_to_token_pool.free(req.req_pool_idx) def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True): """Process a batch of requests in SGLang.""" @@ -454,22 +482,19 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: pp_proxy_tensors=pp_proxy_tensors, ) logits_output = out.logits_output - # Merge prefill batch into running batch chunked_req_to_exclude = set() - + if self.chunked_req is not None: # Move the chunked request out of the batch so that we can merge # only finished requests to running_batch. chunked_req_to_exclude.add(self.chunked_req) - self.stash_chunked_request(self.chunked_req) - - + if self.cur_batch and self.cur_batch.forward_mode.is_extend(): if self.cur_batch.chunked_req is not None: chunked_req_to_exclude.add(self.cur_batch.chunked_req) - + if self.cur_batch: # Avoid duplicate rids in running_batch: exclude any cur_batch req # whose rid is already in running_batch (filter_batch uses object identity). @@ -479,10 +504,8 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: if req.rid in existing_rids: chunked_req_to_exclude.add(req) cur_batch_size = self.cur_batch.batch_size() - self.cur_batch.filter_batch( - chunked_req_to_exclude=list[Req](chunked_req_to_exclude) - ) - + self.cur_batch.filter_batch(chunked_req_to_exclude=list[Req](chunked_req_to_exclude)) + if self.cur_batch.batch_size() < cur_batch_size: self.cur_batch.batch_is_full = False if self.cur_batch.forward_mode.is_extend(): @@ -499,9 +522,13 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: if return_decoded_tokens: # Debug: log running_batch vs prepared_inputs["requests"] to detect # "running_batch still has previous request" causing token mis-assignment - running_batch_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) + running_batch_size = ( + 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) + ) running_batch_rids = ( - [] if self.running_batch.is_empty() else [req.rid for req in self.running_batch.reqs] + [] + if self.running_batch.is_empty() + else [req.rid for req in self.running_batch.reqs] ) requests_len = len(requests) requests_ids = [getattr(r, "request_id", str(r)) for r in requests] @@ -564,7 +591,9 @@ def _release_request(self, rid: str): # Debug: log running_batch before release to verify request eviction before_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) before_rids = ( - [] if self.running_batch.is_empty() else [req.rid for req in self.running_batch.reqs] + [] + if self.running_batch.is_empty() + else [req.rid for req in self.running_batch.reqs] ) logger.debug( "[ChunkedPrefill-Debug] _release_request BEFORE: releasing rid=%s, " @@ -576,7 +605,9 @@ def _release_request(self, rid: str): release_sglang_request(self.running_batch, rid) after_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) after_rids = ( - [] if self.running_batch.is_empty() else [req.rid for req in self.running_batch.reqs] + [] + if self.running_batch.is_empty() + else [req.rid for req in self.running_batch.reqs] ) logger.debug( "[ChunkedPrefill-Debug] _release_request AFTER: running_batch size=%s rids=%s", diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 097fadca..e12f778f 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -108,7 +108,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--enable-prefix-cache", action="store_true", help="Enable prefix cache reuse" ) - + # add --chunked-prefill-size parser.add_argument( "--chunked-prefill-size", diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index bb17a2be..423114e6 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -8,11 +8,11 @@ from __future__ import annotations from types import SimpleNamespace -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, List -from sglang.srt.managers.schedule_policy import AddReqResult, PrefillAdder import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.schedule_policy import AddReqResult, PrefillAdder from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner @@ -53,18 +53,35 @@ def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> S return params +def _dummy_tree_cache_for_adder(): + """Dummy tree cache when prefix cache is disabled. PrefillAdder expects tree_cache + to have evictable_size(), inc_lock_ref(), dec_lock_ref() etc.; pass this instead of None.""" + return SimpleNamespace( + evictable_size=lambda: 0, + full_evictable_size=lambda: 0, + swa_evictable_size=lambda: 0, + inc_lock_ref=lambda node: None, + dec_lock_ref=lambda node, uuid=None: None, + ) + + def transform_requests_to_sglang( - old_requests: List[Request], + old_requests: List[Request], executor: SGLExecutor, ) -> List[Req]: """Transforms Parallax Request to SGLang.Req format""" model_runner = executor.model_runner page_tree_cache = executor.page_tree_cache chunked_prefill_size = executor.chunked_prefill_size + # When prefix cache is disabled, sglang's PrefillAdder still expects tree_cache to have + # evictable_size(), inc_lock_ref(), etc. Pass a dummy instead of None. + tree_cache_for_adder = ( + page_tree_cache if page_tree_cache is not None else _dummy_tree_cache_for_adder() + ) # Prefill policy adder = PrefillAdder( model_runner.page_size, - page_tree_cache, + tree_cache_for_adder, model_runner.token_to_kv_pool_allocator, None, None, @@ -74,15 +91,16 @@ def transform_requests_to_sglang( None, None, ) - + # Save rid of chunked req added this round so we can reorder can_run_list to match old_requests. + chunked_rid = executor.chunked_req.rid if executor.chunked_req is not None else None + if executor.chunked_req is not None: logger.debug(f"before add_chunked_req, chunked_req is not None") executor.chunked_req.init_next_round_input(page_tree_cache) executor.chunked_req = adder.add_chunked_req(executor.chunked_req) if executor.chunked_req is None: logger.debug(f"after add_chunked_req, chunked_req is None") - - + reqs = [] logger.debug(f"old_req size: {len(old_requests)}") for old_req in old_requests: @@ -107,17 +125,21 @@ def transform_requests_to_sglang( ) req.init_next_round_input(page_tree_cache) - + res = adder.add_one_req( - req, executor.chunked_req is not None, None, + req, + executor.chunked_req is not None, + None, ) - + if res != AddReqResult.CONTINUE: - logger.warning(f"Request {old_req.request_id} failed to add to prefill batch, result: {res},\ - req_len: {len(req.origin_input_ids)}") + logger.warning( + f"Request {old_req.request_id} failed to add to prefill batch, result: {res},\ + req_len: {len(req.origin_input_ids)}" + ) if res == AddReqResult.NO_TOKEN: logger.warning(f"there is no token to add to prefill batch") - executor.running_batch.batch_is_full = True + executor.running_batch.batch_is_full = True break # Debug: Log after cache lookup @@ -133,9 +155,9 @@ def transform_requests_to_sglang( ) reqs.append(req) - + logger.debug(f"new reqs size: {len(reqs)}") - + if adder.new_chunked_req is not None: # Update chunked prefill assert executor.chunked_req is None @@ -144,8 +166,16 @@ def transform_requests_to_sglang( if executor.chunked_req is not None: executor.chunked_req.is_chunked += 1 - - return adder.can_run_list + + # Reorder so returned list follows old_requests order and each element is the Req + # that corresponds to that old_req (same rid). Use rid to map instead of assuming + # can_run_list order, so the relationship with reqs is explicit. + can_run_list = adder.can_run_list + if chunked_rid is None: + return can_run_list + rid_to_req = {req.rid: req for req in can_run_list} + reordered: List[Req] = [rid_to_req[old_req.request_id] for old_req in old_requests] + return reordered def form_sgl_batch_prefill( @@ -155,7 +185,6 @@ def form_sgl_batch_prefill( """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" model_runner = executor.model_runner page_tree_cache = executor.page_tree_cache - sgl_reqs = transform_requests_to_sglang(requests, executor) @@ -180,6 +209,10 @@ def dummy_evict(*args): chunked_req=executor.chunked_req, ) schedule_batch.prepare_for_extend() + if executor.chunked_req is not None: + logger.debug( + f"chunked_req.rid={executor.chunked_req.rid}, chunked_req.req_pool_idx: {executor.chunked_req.req_pool_idx}" + ) num_tokens = schedule_batch.extend_num_tokens dp_size = model_runner.dp_size From 1ed5d675ef4923cdcf101e48cea3ae614c6d8c87 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Fri, 30 Jan 2026 09:16:08 +0000 Subject: [PATCH 06/50] fix[chunked-prefill]:transfer all chunks instead of last chunk --- .../server/executor/sglang_executor.py | 18 ++++++++++++++++++ tests/test_executor.py | 12 ++++++++---- tests/test_server_args.py | 2 ++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index dbe64dfd..5f6d8514 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -198,6 +198,9 @@ def __init__( if chunked_prefill_size is not None and chunked_prefill_size <= 0: chunked_prefill_size = None self.chunked_req = None + # Per-request accumulator for chunked prefill hidden states (non-last peers only). + # When forwarding after all chunks are done, replace with concat of all chunks. + self._chunked_prefill_hidden_accumulator: Dict[str, List[torch.Tensor]] = {} # create a page tree cache for sglang prefill if enable_prefix_cache: @@ -344,6 +347,15 @@ def prepare_next_batch_requests( requests, batch_output, context_lengths ) if self.chunked_req is None or self.chunked_req.is_chunked <= 0: + # Chunked prefill just finished (or no chunked): replace hidden_states with full concat when forwarding. + if not self.is_last_peer and self._chunked_prefill_hidden_accumulator: + for req in base_to_forward: + rid = req.request_id + if rid in self._chunked_prefill_hidden_accumulator: + chunks = self._chunked_prefill_hidden_accumulator[rid] + chunks.append(req.hidden_states) + req.hidden_states = torch.cat(chunks, dim=0) + del self._chunked_prefill_hidden_accumulator[rid] return (base_chunked, base_to_forward) chunked_rid = self.chunked_req.rid self.stash_chunked_request(self.chunked_req) @@ -353,6 +365,11 @@ def prepare_next_batch_requests( if req.request_id == chunked_rid: req.status = RequestStatus.PREFILLING chunked_reqs.append(req) + # Accumulate this chunk's hidden_states for full prefill when forwarding (non-last peers only). + if not self.is_last_peer: + self._chunked_prefill_hidden_accumulator.setdefault(req.request_id, []).append( + req.hidden_states + ) else: to_forward_reqs.append(req) return (chunked_reqs, to_forward_reqs) @@ -588,6 +605,7 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: def _release_request(self, rid: str): """Release per-request resources in SGLang.""" try: + self._chunked_prefill_hidden_accumulator.pop(rid, None) # Debug: log running_batch before release to verify request eviction before_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) before_rids = ( diff --git a/tests/test_executor.py b/tests/test_executor.py index e20d4174..3586d8e2 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -65,7 +65,9 @@ def create_executor(start_layer, end_layer, device, kv_cache_memory_fraction=0.3 def run_executor_pipeline_stage(executor, requests, batch_type, is_last_peer): - """Run executor pipeline stage. Input and output should be requests""" + """Run executor pipeline stage. Input and output should be requests. + Returns (to_forward_reqs, batch_output); chunked_reqs are handled locally via handle_input_requests. + """ executor.handle_input_requests(requests) executor.scheduler.admit_requests() input_batch = executor.scheduler.form_batch() @@ -73,12 +75,14 @@ def run_executor_pipeline_stage(executor, requests, batch_type, is_last_peer): assert prepared_batch is not None, "Failed to prepare batch inputs" batch_data = prepared_batch[batch_type] batch_output = executor.process_batch(batch_data, return_decoded_tokens=is_last_peer) - output_reqs = executor.prepare_next_batch_requests( + chunked_reqs, to_forward_reqs = executor.prepare_next_batch_requests( requests=batch_data["requests"], batch_output=batch_output, context_lengths=batch_data.get("context_lengths"), ) - return output_reqs, batch_output + if chunked_reqs: + executor.handle_input_requests(chunked_reqs) + return to_forward_reqs, batch_output @pytest.mark.parametrize( @@ -119,8 +123,8 @@ def test_decode_pipeline_multiple_steps(pipeline_devices, pp_end_layers, num_dec ref_cuda_model = AutoModelForCausalLM.from_pretrained( CUDA_MODEL_REPO, torch_dtype=torch.bfloat16, - device_map="cuda:0", ) + ref_cuda_model = ref_cuda_model.to("cuda:0") ref_cuda_tokenizer = AutoTokenizer.from_pretrained(CUDA_MODEL_REPO) if ref_cuda_tokenizer.pad_token is None: ref_cuda_tokenizer.pad_token = ref_cuda_tokenizer.eos_token diff --git a/tests/test_server_args.py b/tests/test_server_args.py index 3564dabd..19332129 100644 --- a/tests/test_server_args.py +++ b/tests/test_server_args.py @@ -143,6 +143,8 @@ def test_create_config(self): max_lora_chunk_size=128, enable_weight_refit=False, weight_refit_mode="cpu", + chunked_prefill_size=128, + max_prefill_tokens=1024, ) config = create_executor_config(args) From 876f13126b9ca3ba1d9459aec202288e2a3f154a Mon Sep 17 00:00:00 2001 From: wasamtc Date: Fri, 30 Jan 2026 09:42:00 +0000 Subject: [PATCH 07/50] fix[chunked-prefill]:restore last chunk --- .../server/executor/sglang_executor.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 5f6d8514..dbe64dfd 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -198,9 +198,6 @@ def __init__( if chunked_prefill_size is not None and chunked_prefill_size <= 0: chunked_prefill_size = None self.chunked_req = None - # Per-request accumulator for chunked prefill hidden states (non-last peers only). - # When forwarding after all chunks are done, replace with concat of all chunks. - self._chunked_prefill_hidden_accumulator: Dict[str, List[torch.Tensor]] = {} # create a page tree cache for sglang prefill if enable_prefix_cache: @@ -347,15 +344,6 @@ def prepare_next_batch_requests( requests, batch_output, context_lengths ) if self.chunked_req is None or self.chunked_req.is_chunked <= 0: - # Chunked prefill just finished (or no chunked): replace hidden_states with full concat when forwarding. - if not self.is_last_peer and self._chunked_prefill_hidden_accumulator: - for req in base_to_forward: - rid = req.request_id - if rid in self._chunked_prefill_hidden_accumulator: - chunks = self._chunked_prefill_hidden_accumulator[rid] - chunks.append(req.hidden_states) - req.hidden_states = torch.cat(chunks, dim=0) - del self._chunked_prefill_hidden_accumulator[rid] return (base_chunked, base_to_forward) chunked_rid = self.chunked_req.rid self.stash_chunked_request(self.chunked_req) @@ -365,11 +353,6 @@ def prepare_next_batch_requests( if req.request_id == chunked_rid: req.status = RequestStatus.PREFILLING chunked_reqs.append(req) - # Accumulate this chunk's hidden_states for full prefill when forwarding (non-last peers only). - if not self.is_last_peer: - self._chunked_prefill_hidden_accumulator.setdefault(req.request_id, []).append( - req.hidden_states - ) else: to_forward_reqs.append(req) return (chunked_reqs, to_forward_reqs) @@ -605,7 +588,6 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: def _release_request(self, rid: str): """Release per-request resources in SGLang.""" try: - self._chunked_prefill_hidden_accumulator.pop(rid, None) # Debug: log running_batch before release to verify request eviction before_size = 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) before_rids = ( From fd6370698fafd6b61874fb7af3fcf588761d8a73 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sat, 31 Jan 2026 09:18:16 +0000 Subject: [PATCH 08/50] feat[chunked-prefill]: base multi-node chunked prefill for gpu --- src/parallax/server/executor/base_executor.py | 19 ++++-------- .../server/executor/sglang_executor.py | 30 +++++++++---------- src/parallax/sglang/batch_info.py | 10 +++---- 3 files changed, 25 insertions(+), 34 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index eaaa7441..257c0a51 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -386,13 +386,10 @@ def prepare_batch_inputs(self, batched_requests: List[Request]) -> Optional[Dict def prepare_next_batch_requests( self, requests: List[Request], batch_output: Any, context_lengths: Any - ) -> Tuple[List[Request], List[Request]]: + ) -> List[Request]: """Prepares a batch of requests for the next stage of the pipeline. - Returns two lists: - - chunked_reqs: requests that still have chunks to complete (e.g. chunked prefill); - these should be handled locally via handle_input_requests only, not sent to next peer. - - to_forward_reqs: requests to forward (single node: handle_input_requests; + Returns a list of requests to forward (single node: handle_input_requests; multi-node and tp_rank==0: request_to_proto and send to next peer). Args: @@ -445,10 +442,7 @@ def prepare_next_batch_requests( ) batched_requests.append(next_req) - # Default: no chunked reqs; all go to to_forward. - chunked_reqs: List[Request] = [] - to_forward_reqs: List[Request] = batched_requests - return (chunked_reqs, to_forward_reqs) + return batched_requests def release_and_evict_request(self, rid: str): """Release per-request resources and evict from scheduler. Best-effort, never raises.""" @@ -556,16 +550,13 @@ def run_loop(self): except Exception: pass # 7. Prepare requests for the next stage in the pipeline - chunked_reqs, to_forward_reqs = self.prepare_next_batch_requests( + to_forward_reqs = self.prepare_next_batch_requests( requests=prepared_inputs["requests"], batch_output=output, context_lengths=prepared_inputs.get("context_lengths"), ) # 8. Dispatch to the appropriate destination - # Chunked reqs (e.g. chunked prefill not yet done) stay local. - if chunked_reqs: - self.handle_input_requests(chunked_reqs) if to_forward_reqs: if self.is_last_peer and self.is_first_peer: # Single node: handle to_forward locally @@ -582,7 +573,7 @@ def run_loop(self): ) logger.debug( f"Processed batch of type {batch_type} with {len(to_forward_reqs)} to_forward " - f"(chunked={len(chunked_reqs)}) in {(time.time() - start_time) * 1000:.3f} ms" + f"in {(time.time() - start_time) * 1000:.3f} ms" ) except Exception as e: diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index dbe64dfd..9346f01e 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -338,24 +338,19 @@ def stash_chunked_request(self, req: Req): def prepare_next_batch_requests( self, requests: List[Request], batch_output: Any, context_lengths: Any - ) -> Tuple[List[Request], List[Request]]: + ) -> List[Request]: """Split out chunked prefill reqs; set their status to PREFILLING and return (chunked, to_forward).""" - base_chunked, base_to_forward = super().prepare_next_batch_requests( + base_to_forward = super().prepare_next_batch_requests( requests, batch_output, context_lengths ) - if self.chunked_req is None or self.chunked_req.is_chunked <= 0: - return (base_chunked, base_to_forward) + if self.chunked_req is None or self.chunked_req.is_chunked <= 0 or self.chunked_req.rid not in [req.request_id for req in requests]: + return base_to_forward chunked_rid = self.chunked_req.rid self.stash_chunked_request(self.chunked_req) - chunked_reqs: List[Request] = [] - to_forward_reqs: List[Request] = [] - for req in base_to_forward: - if req.request_id == chunked_rid: - req.status = RequestStatus.PREFILLING - chunked_reqs.append(req) - else: - to_forward_reqs.append(req) - return (chunked_reqs, to_forward_reqs) + # delete chunked_req from base_to_forward if self is last_peer + if self.is_last_peer: + base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] + return base_to_forward def handle_input_requests(self, requests: List[Request]): """Update requests states and status in scheduler and cache manager.""" @@ -701,7 +696,12 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A return None # Pre-check: Verify KV cache has enough space for prefill - total_tokens_needed = sum(req.total_length for req in batched_requests) + def _get_total_length(req: Request) -> int: + if hasattr(req, "hidden_states") and req.hidden_states is not None: + return req.hidden_states.shape[0] + return req.total_length + + total_tokens_needed = sum(_get_total_length(req) for req in batched_requests) if not self._check_kv_cache_available(total_tokens_needed): self._abort_requests_due_to_kv_cache( batched_requests, @@ -751,7 +751,7 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A if self.lora_paths and len(self.lora_paths) > 0 else None ) - lengths.append(req.total_length) + lengths.append(_get_total_length(req)) lengths_tensor = torch.tensor(lengths, device=self.device) schedule_batch, forward_batch = form_sgl_batch_prefill( diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 423114e6..ade55b88 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -94,7 +94,7 @@ def transform_requests_to_sglang( # Save rid of chunked req added this round so we can reorder can_run_list to match old_requests. chunked_rid = executor.chunked_req.rid if executor.chunked_req is not None else None - if executor.chunked_req is not None: + if executor.chunked_req is not None and executor.chunked_req.rid in [req.request_id for req in old_requests]: logger.debug(f"before add_chunked_req, chunked_req is not None") executor.chunked_req.init_next_round_input(page_tree_cache) executor.chunked_req = adder.add_chunked_req(executor.chunked_req) @@ -164,14 +164,14 @@ def transform_requests_to_sglang( executor.chunked_req = adder.new_chunked_req logger.debug(f"new chunked_req is {executor.chunked_req}") - if executor.chunked_req is not None: + if executor.chunked_req is not None and executor.chunked_req.rid in [req.request_id for req in old_requests]: executor.chunked_req.is_chunked += 1 # Reorder so returned list follows old_requests order and each element is the Req # that corresponds to that old_req (same rid). Use rid to map instead of assuming # can_run_list order, so the relationship with reqs is explicit. can_run_list = adder.can_run_list - if chunked_rid is None: + if chunked_rid is None or executor.chunked_req.rid not in [req.request_id for req in old_requests]: return can_run_list rid_to_req = {req.rid: req for req in can_run_list} reordered: List[Req] = [rid_to_req[old_req.request_id] for old_req in old_requests] @@ -206,10 +206,10 @@ def dummy_evict(*args): model_config=model_runner.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, - chunked_req=executor.chunked_req, + chunked_req=executor.chunked_req if executor.chunked_req is not None and executor.chunked_req.rid in [req.request_id for req in requests] else None, ) schedule_batch.prepare_for_extend() - if executor.chunked_req is not None: + if executor.chunked_req is not None and executor.chunked_req.rid in [req.request_id for req in requests]: logger.debug( f"chunked_req.rid={executor.chunked_req.rid}, chunked_req.req_pool_idx: {executor.chunked_req.req_pool_idx}" ) From 13c6773559e6f7c7ecb92073d0cc6e2671460445 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sat, 31 Jan 2026 09:30:15 +0000 Subject: [PATCH 09/50] fix[chunked-prefill]: chunked_req should be send to local and next --- src/parallax/server/executor/base_executor.py | 23 +++++++++++-------- .../server/executor/sglang_executor.py | 9 ++++---- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 257c0a51..0ae2ce8d 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -386,10 +386,13 @@ def prepare_batch_inputs(self, batched_requests: List[Request]) -> Optional[Dict def prepare_next_batch_requests( self, requests: List[Request], batch_output: Any, context_lengths: Any - ) -> List[Request]: + ) -> Tuple[List[Request], List[Request]]: """Prepares a batch of requests for the next stage of the pipeline. - Returns a list of requests to forward (single node: handle_input_requests; + Returns two lists: + - chunked_reqs: requests that still have chunks to complete (e.g. chunked prefill); + these should be handled locally via handle_input_requests only, not sent to next peer. + - to_forward_reqs: requests to forward (single node: handle_input_requests; multi-node and tp_rank==0: request_to_proto and send to next peer). Args: @@ -406,7 +409,8 @@ def prepare_next_batch_requests( hidden_states = batch_output["hidden_states"] token_probs = batch_output["probs"] - batched_requests = [] + chunked_reqs = [] + to_forward_reqs = [] pre_length = 0 for i, src_request in enumerate(requests): if self.is_last_peer: @@ -440,9 +444,9 @@ def prepare_next_batch_requests( next_req = self._prepare_next_single_request( src_request, hidden_state_for_req, token_prob ) - batched_requests.append(next_req) + to_forward_reqs.append(next_req) - return batched_requests + return chunked_reqs, to_forward_reqs def release_and_evict_request(self, rid: str): """Release per-request resources and evict from scheduler. Best-effort, never raises.""" @@ -550,12 +554,13 @@ def run_loop(self): except Exception: pass # 7. Prepare requests for the next stage in the pipeline - to_forward_reqs = self.prepare_next_batch_requests( + chunked_reqs, to_forward_reqs = self.prepare_next_batch_requests( requests=prepared_inputs["requests"], batch_output=output, context_lengths=prepared_inputs.get("context_lengths"), ) - + if chunked_reqs: + self.handle_input_requests(chunked_reqs) # 8. Dispatch to the appropriate destination if to_forward_reqs: if self.is_last_peer and self.is_first_peer: @@ -567,12 +572,12 @@ def run_loop(self): [ b"forward", request_to_proto( - to_forward_reqs, self.device + to_forward_reqs + chunked_reqs, self.device ).SerializeToString(), ] ) logger.debug( - f"Processed batch of type {batch_type} with {len(to_forward_reqs)} to_forward " + f"Processed batch of type {batch_type} with {len(to_forward_reqs + chunked_reqs)} to_forward " f"in {(time.time() - start_time) * 1000:.3f} ms" ) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 9346f01e..12ea64e4 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -338,19 +338,20 @@ def stash_chunked_request(self, req: Req): def prepare_next_batch_requests( self, requests: List[Request], batch_output: Any, context_lengths: Any - ) -> List[Request]: + ) -> Tuple[List[Request], List[Request]]: """Split out chunked prefill reqs; set their status to PREFILLING and return (chunked, to_forward).""" - base_to_forward = super().prepare_next_batch_requests( + base_chunked, base_to_forward = super().prepare_next_batch_requests( requests, batch_output, context_lengths ) if self.chunked_req is None or self.chunked_req.is_chunked <= 0 or self.chunked_req.rid not in [req.request_id for req in requests]: - return base_to_forward + return base_chunked, base_to_forward chunked_rid = self.chunked_req.rid self.stash_chunked_request(self.chunked_req) + base_chunked.append(self.chunked_req) # delete chunked_req from base_to_forward if self is last_peer if self.is_last_peer: base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] - return base_to_forward + return base_chunked, base_to_forward def handle_input_requests(self, requests: List[Request]): """Update requests states and status in scheduler and cache manager.""" From 3d9f72f6d9d4e2288c683b9ac78eeb0e48862b21 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sat, 31 Jan 2026 09:38:36 +0000 Subject: [PATCH 10/50] fix[chunked-prefill]: chunked_req should add it is origin req to base_chunks --- src/parallax/server/executor/sglang_executor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 12ea64e4..0f0f3f06 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -347,7 +347,10 @@ def prepare_next_batch_requests( return base_chunked, base_to_forward chunked_rid = self.chunked_req.rid self.stash_chunked_request(self.chunked_req) - base_chunked.append(self.chunked_req) + for req in base_to_forward: + if req.request_id == chunked_rid: + base_chunked.append(req) + break # delete chunked_req from base_to_forward if self is last_peer if self.is_last_peer: base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] From b232959f02dd8df220506abf4e3510467c180b7c Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sat, 31 Jan 2026 09:45:55 +0000 Subject: [PATCH 11/50] fix[chunked-prefill]: return can_run_list if not reorder reqs --- src/parallax/sglang/batch_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index ade55b88..29678e10 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -171,7 +171,7 @@ def transform_requests_to_sglang( # that corresponds to that old_req (same rid). Use rid to map instead of assuming # can_run_list order, so the relationship with reqs is explicit. can_run_list = adder.can_run_list - if chunked_rid is None or executor.chunked_req.rid not in [req.request_id for req in old_requests]: + if chunked_rid is None and (executor.chunked_req is None or executor.chunked_req.rid not in [req.request_id for req in old_requests]): return can_run_list rid_to_req = {req.rid: req for req in can_run_list} reordered: List[Req] = [rid_to_req[old_req.request_id] for old_req in old_requests] From 4a14ecbb8d2e98997abcaaa603425337d9329039 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sat, 31 Jan 2026 10:02:19 +0000 Subject: [PATCH 12/50] fix[chunked-prefill]: add log --- src/parallax/server/executor/sglang_executor.py | 4 ++-- src/parallax/sglang/batch_info.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 0f0f3f06..109e5e8c 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -485,13 +485,13 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Merge prefill batch into running batch chunked_req_to_exclude = set() - if self.chunked_req is not None: + if self.chunked_req is not None and self.chunked_req.is_chunked > 0: # Move the chunked request out of the batch so that we can merge # only finished requests to running_batch. chunked_req_to_exclude.add(self.chunked_req) if self.cur_batch and self.cur_batch.forward_mode.is_extend(): - if self.cur_batch.chunked_req is not None: + if self.cur_batch.chunked_req is not None and self.cur_batch.chunked_req.is_chunked > 0: chunked_req_to_exclude.add(self.cur_batch.chunked_req) if self.cur_batch: diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 29678e10..1bc233cd 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -99,7 +99,7 @@ def transform_requests_to_sglang( executor.chunked_req.init_next_round_input(page_tree_cache) executor.chunked_req = adder.add_chunked_req(executor.chunked_req) if executor.chunked_req is None: - logger.debug(f"after add_chunked_req, chunked_req is None") + logger.debug(f"chunked_req{chunked_rid} has chunk all after add_chunked_req, chunked_req is None") reqs = [] logger.debug(f"old_req size: {len(old_requests)}") From c98594f361f07413c304c57dfd91d2cb664b2301 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sat, 31 Jan 2026 10:11:35 +0000 Subject: [PATCH 13/50] fix[chunked-prefill]: add log --- src/parallax/server/executor/sglang_executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 109e5e8c..97e5f6d4 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -344,6 +344,7 @@ def prepare_next_batch_requests( requests, batch_output, context_lengths ) if self.chunked_req is None or self.chunked_req.is_chunked <= 0 or self.chunked_req.rid not in [req.request_id for req in requests]: + logger.debug(f"sglang_executor: prepare_next_batch_requests: return base_chunked and base_to_forward because chunked_req is None or is_chunked <= 0 or rid not in requests") return base_chunked, base_to_forward chunked_rid = self.chunked_req.rid self.stash_chunked_request(self.chunked_req) @@ -354,6 +355,7 @@ def prepare_next_batch_requests( # delete chunked_req from base_to_forward if self is last_peer if self.is_last_peer: base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] + logger.debug(f"sglang_executor: prepare_next_batch_requests: return new_chunked and new_to_forward because chunked_req is not None and is_chunked > 0 and rid in requests") return base_chunked, base_to_forward def handle_input_requests(self, requests: List[Request]): From ffce612ecfa89bacbee3282769aaf6dadfd135a3 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sat, 31 Jan 2026 10:26:10 +0000 Subject: [PATCH 14/50] fix[chunked-prefill]: delete all chunked_reqs from base_forward because send data use chunked_reqs+forward_reqs --- src/parallax/server/executor/sglang_executor.py | 6 ++---- src/parallax/sglang/batch_info.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 97e5f6d4..3bc62faa 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -352,10 +352,8 @@ def prepare_next_batch_requests( if req.request_id == chunked_rid: base_chunked.append(req) break - # delete chunked_req from base_to_forward if self is last_peer - if self.is_last_peer: - base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] - logger.debug(f"sglang_executor: prepare_next_batch_requests: return new_chunked and new_to_forward because chunked_req is not None and is_chunked > 0 and rid in requests") + base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] + logger.debug(f"sglang_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid in requests") return base_chunked, base_to_forward def handle_input_requests(self, requests: List[Request]): diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 1bc233cd..24dd4381 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -105,7 +105,7 @@ def transform_requests_to_sglang( logger.debug(f"old_req size: {len(old_requests)}") for old_req in old_requests: # Chunked req is added via add_chunked_req above; skip to avoid double-add. - if executor.chunked_req is not None and old_req.request_id == executor.chunked_req.rid: + if chunked_rid is not None and old_req.request_id == chunked_rid: continue sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) req = Req( From b1cbd8ccb2ae78d9359bac3d8a40da1d286664e6 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sat, 31 Jan 2026 10:31:41 +0000 Subject: [PATCH 15/50] fix[chunked-prefill]: change base_executor run_loop last judge condition --- src/parallax/server/executor/base_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 0ae2ce8d..119f4554 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -562,7 +562,7 @@ def run_loop(self): if chunked_reqs: self.handle_input_requests(chunked_reqs) # 8. Dispatch to the appropriate destination - if to_forward_reqs: + if to_forward_reqs or chunked_reqs: if self.is_last_peer and self.is_first_peer: # Single node: handle to_forward locally self.handle_input_requests(to_forward_reqs) From b6bd4d3e359b30c9cc1f908b37766f91efa281cd Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 00:51:15 +0000 Subject: [PATCH 16/50] fix[chunked-prefill]: add log --- src/parallax/server/executor/sglang_executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 3bc62faa..22376484 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -488,10 +488,12 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: if self.chunked_req is not None and self.chunked_req.is_chunked > 0: # Move the chunked request out of the batch so that we can merge # only finished requests to running_batch. + logger.debug(f"exclude chunked_req {self.chunked_req.rid} from running_batch") chunked_req_to_exclude.add(self.chunked_req) if self.cur_batch and self.cur_batch.forward_mode.is_extend(): if self.cur_batch.chunked_req is not None and self.cur_batch.chunked_req.is_chunked > 0: + logger.debug(f"exclude chunked_req {self.cur_batch.chunked_req.rid} from running_batch") chunked_req_to_exclude.add(self.cur_batch.chunked_req) if self.cur_batch: From 185c184a3f82f685924cc6b2142077444c093877 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:08:50 +0000 Subject: [PATCH 17/50] fix[chunked-prefill]: change last peer just send to_forward_reqs --- src/parallax/server/executor/base_executor.py | 38 ++++++++++++------- .../server/executor/sglang_executor.py | 1 + 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 119f4554..c8cc4d7a 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -567,19 +567,31 @@ def run_loop(self): # Single node: handle to_forward locally self.handle_input_requests(to_forward_reqs) elif self.tp_rank == 0: - # Send to_forward to next peer (do not send chunked_reqs) - self.send_to_peer_socket.send_multipart( - [ - b"forward", - request_to_proto( - to_forward_reqs + chunked_reqs, self.device - ).SerializeToString(), - ] - ) - logger.debug( - f"Processed batch of type {batch_type} with {len(to_forward_reqs + chunked_reqs)} to_forward " - f"in {(time.time() - start_time) * 1000:.3f} ms" - ) + # Send to_forward to next peer (do not send chunked_reqs if self is last_peer) + if not self.is_last_peer: + self.send_to_peer_socket.send_multipart( + [ + b"forward", + request_to_proto( + to_forward_reqs + chunked_reqs, self.device + ).SerializeToString(), + ] + ) + logger.debug( + f"Processed batch of with {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs " + f"in {(time.time() - start_time) * 1000:.3f} ms" + ) + else: + self.send_to_peer_socket.send_multipart( + [ + b"forward_chunked", + request_to_proto(to_forward_reqs, self.device).SerializeToString(), + ] + ) + logger.debug( + f"Processed batch of with {len(to_forward_reqs)} to_forward " + f"in {(time.time() - start_time) * 1000:.3f} ms" + ) except Exception as e: logger.exception(f"Error processing batch: {e}") diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 22376484..906f14bd 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -350,6 +350,7 @@ def prepare_next_batch_requests( self.stash_chunked_request(self.chunked_req) for req in base_to_forward: if req.request_id == chunked_rid: + req.status = RequestStatus.PREFILLING base_chunked.append(req) break base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] From 6d8268e537d08d8547b359907628e8ab4689059e Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:11:56 +0000 Subject: [PATCH 18/50] fix[chunked-prefill]: change message type bug --- src/parallax/server/executor/base_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index c8cc4d7a..1f6a3310 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -584,7 +584,7 @@ def run_loop(self): else: self.send_to_peer_socket.send_multipart( [ - b"forward_chunked", + b"forward", request_to_proto(to_forward_reqs, self.device).SerializeToString(), ] ) From 0d496121286c50276be9e60333934bc7cba5aaac Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:19:18 +0000 Subject: [PATCH 19/50] fix[chunked-prefill]: last send forward_reqs when forward_reqs is not none --- src/parallax/server/executor/base_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 1f6a3310..4ae5a948 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -581,7 +581,7 @@ def run_loop(self): f"Processed batch of with {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs " f"in {(time.time() - start_time) * 1000:.3f} ms" ) - else: + elif to_forward_reqs is not None: self.send_to_peer_socket.send_multipart( [ b"forward", From 5a6936ff5c35b44aed2354bfd131674d79b85072 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:22:15 +0000 Subject: [PATCH 20/50] fix[chunked-prefill]: last send forward_reqs when forward_reqs is not none --- src/parallax/server/executor/base_executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 4ae5a948..4ff0700e 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -563,7 +563,7 @@ def run_loop(self): self.handle_input_requests(chunked_reqs) # 8. Dispatch to the appropriate destination if to_forward_reqs or chunked_reqs: - if self.is_last_peer and self.is_first_peer: + if self.is_last_peer and self.is_first_peer and (to_forward_reqs is not None and len(to_forward_reqs) > 0): # Single node: handle to_forward locally self.handle_input_requests(to_forward_reqs) elif self.tp_rank == 0: @@ -581,7 +581,7 @@ def run_loop(self): f"Processed batch of with {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs " f"in {(time.time() - start_time) * 1000:.3f} ms" ) - elif to_forward_reqs is not None: + elif to_forward_reqs is not None and len(to_forward_reqs) > 0: self.send_to_peer_socket.send_multipart( [ b"forward", From 8318548408d00cac42c2efd41ea11c0b059c058a Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:32:06 +0000 Subject: [PATCH 21/50] fix[chunked-prefill]: handle_reqs should do other if not last_peer --- src/parallax/server/executor/sglang_executor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 906f14bd..4e791c82 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -461,6 +461,10 @@ def handle_input_requests(self, requests: List[Request]): self.release_and_evict_request(req.request_id) if not self.is_last_peer and not req.abort: self.finished_batch.append(req) + elif self.chunked_req is not None and self.chunked_req.rid == req.request_id and self.chunked_req.is_chunked > 0: + self.chunked_req.is_chunked -= 1 + req.status = RequestStatus.PREFILLING + continue else: # This is an active request, add it to the scheduler queue to be processed. self.scheduler.enque_request(req) From 41536114eb75f2be1c8df78a9e5a001192d846d1 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:38:05 +0000 Subject: [PATCH 22/50] fix[chunked-prefill]: add logs --- src/parallax/server/executor/base_executor.py | 4 +++- src/parallax/server/scheduler.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 4ff0700e..23a7547f 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -466,6 +466,7 @@ def run_loop(self): ) self._should_stop = False while not self._should_stop: + logger.debug(f"Executor for layers [{self.start_layer}, {self.end_layer}) running loop...") received_requests = [] # Receive requests from http frontend @@ -516,9 +517,10 @@ def run_loop(self): self.finished_batch.append(req) except Exception: # Non-fatal; continue serving - pass + pass batch_to_process = self.scheduler.form_batch() if not batch_to_process: + logger.debug(f"No batch to process. continue to next loop...") continue logger.debug(f"Formed batch with {len(batch_to_process)} requests.") diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index e70c11f4..a70ecd11 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -239,6 +239,7 @@ def admit_requests(self): req = self._wait_queue.popleft() rid = req.request_id if rid in self._running_requests: + logger.debug(f"Request {rid} already in running requests. skip admit.") continue # Check kv cache pool @@ -311,6 +312,7 @@ def form_batch(self) -> List[Request]: """ self.admit_requests() if not self._running_requests: + logger.debug(f"No running requests to form batch.") return [] inflight_tokens = 0 From 484d66fc0cfacdc28074cf4ff1ba59070c104e00 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:44:58 +0000 Subject: [PATCH 23/50] fix[chunked-prefill]: add logs --- src/parallax/server/executor/base_executor.py | 8 +++++--- src/parallax/server/scheduler.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 23a7547f..7209bfc1 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -466,7 +466,6 @@ def run_loop(self): ) self._should_stop = False while not self._should_stop: - logger.debug(f"Executor for layers [{self.start_layer}, {self.end_layer}) running loop...") received_requests = [] # Receive requests from http frontend @@ -517,10 +516,9 @@ def run_loop(self): self.finished_batch.append(req) except Exception: # Non-fatal; continue serving - pass + pass batch_to_process = self.scheduler.form_batch() if not batch_to_process: - logger.debug(f"No batch to process. continue to next loop...") continue logger.debug(f"Formed batch with {len(batch_to_process)} requests.") @@ -562,15 +560,18 @@ def run_loop(self): context_lengths=prepared_inputs.get("context_lengths"), ) if chunked_reqs: + logger.debug(f"Handle {len(chunked_reqs)} chunked requests.") self.handle_input_requests(chunked_reqs) # 8. Dispatch to the appropriate destination if to_forward_reqs or chunked_reqs: if self.is_last_peer and self.is_first_peer and (to_forward_reqs is not None and len(to_forward_reqs) > 0): # Single node: handle to_forward locally + logger.debug(f"Handle {len(to_forward_reqs)} to_forward requests.") self.handle_input_requests(to_forward_reqs) elif self.tp_rank == 0: # Send to_forward to next peer (do not send chunked_reqs if self is last_peer) if not self.is_last_peer: + logger.debug(f"Send {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs to next peer.") self.send_to_peer_socket.send_multipart( [ b"forward", @@ -584,6 +585,7 @@ def run_loop(self): f"in {(time.time() - start_time) * 1000:.3f} ms" ) elif to_forward_reqs is not None and len(to_forward_reqs) > 0: + logger.debug(f"Send {len(to_forward_reqs)} to_forward to next peer.") self.send_to_peer_socket.send_multipart( [ b"forward", diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index a70ecd11..f11b4442 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -239,7 +239,6 @@ def admit_requests(self): req = self._wait_queue.popleft() rid = req.request_id if rid in self._running_requests: - logger.debug(f"Request {rid} already in running requests. skip admit.") continue # Check kv cache pool From ecdbd0a65897eae46c7d72210542a03ac37d0aea Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:48:28 +0000 Subject: [PATCH 24/50] fix[chunked-prefill]: add logs --- src/parallax/server/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index f11b4442..e70c11f4 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -311,7 +311,6 @@ def form_batch(self) -> List[Request]: """ self.admit_requests() if not self._running_requests: - logger.debug(f"No running requests to form batch.") return [] inflight_tokens = 0 From a4ebbd46c3e378f8c09300582cebc785e86e0e52 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 01:55:19 +0000 Subject: [PATCH 25/50] fix[chunked-prefill]: add logs --- src/parallax/server/executor/base_executor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 7209bfc1..87c3de8b 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -564,6 +564,7 @@ def run_loop(self): self.handle_input_requests(chunked_reqs) # 8. Dispatch to the appropriate destination if to_forward_reqs or chunked_reqs: + logger.debug(f"dispatch to_forward and chunked_reqs to next peer.") if self.is_last_peer and self.is_first_peer and (to_forward_reqs is not None and len(to_forward_reqs) > 0): # Single node: handle to_forward locally logger.debug(f"Handle {len(to_forward_reqs)} to_forward requests.") From 0d45d66ca7d3de391ccac2cd50d4986e1f9d66fa Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 02:02:57 +0000 Subject: [PATCH 26/50] fix[chunked-prefill]: add logs --- src/parallax/server/executor/base_executor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 87c3de8b..de7e6b34 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -571,6 +571,7 @@ def run_loop(self): self.handle_input_requests(to_forward_reqs) elif self.tp_rank == 0: # Send to_forward to next peer (do not send chunked_reqs if self is last_peer) + logger.debug(f"self is last_peer: {self.is_last_peer}, self is first_peer: {self.is_first_peer}") if not self.is_last_peer: logger.debug(f"Send {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs to next peer.") self.send_to_peer_socket.send_multipart( From dd45122edd9aedcfa90a19a51207be8e4d36c58e Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 02:08:17 +0000 Subject: [PATCH 27/50] fix[chunked-prefill]: add logs --- src/parallax/server/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index e70c11f4..436ef1b5 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -239,6 +239,7 @@ def admit_requests(self): req = self._wait_queue.popleft() rid = req.request_id if rid in self._running_requests: + logger.debug(f"Request {rid} already in running requests. skip admit.") continue # Check kv cache pool From 2b87e5c30ba245e5be736797d6c86cdfcc157e42 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 02:18:36 +0000 Subject: [PATCH 28/50] fix[chunked-prefill]: add logs --- src/parallax/server/scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 436ef1b5..24e2dc1b 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -313,6 +313,9 @@ def form_batch(self) -> List[Request]: self.admit_requests() if not self._running_requests: return [] + + logger.debug(f"Form batch with {len(self._running_requests)} running requests.") + logger.debug(f"Running requests: {[(req.request_id, req.status, req.ready_for_next_step) for req in self._running_requests.values()]}") inflight_tokens = 0 batch: List[Request] = [] From 1f0373bf483cf79d1e0ed305174702ea41c4355e Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 02:29:21 +0000 Subject: [PATCH 29/50] fix[chunked-prefill]: if req in running_batch, change status --- src/parallax/server/scheduler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 24e2dc1b..e760ee62 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -239,7 +239,9 @@ def admit_requests(self): req = self._wait_queue.popleft() rid = req.request_id if rid in self._running_requests: - logger.debug(f"Request {rid} already in running requests. skip admit.") + self._running_requests[rid].ready_for_next_step = True + self._running_requests[rid].last_updated_time = time.time() + logger.debug(f"Request {rid} already in running requests. update ready_for_next_step and last_updated_time.") continue # Check kv cache pool @@ -314,9 +316,6 @@ def form_batch(self) -> List[Request]: if not self._running_requests: return [] - logger.debug(f"Form batch with {len(self._running_requests)} running requests.") - logger.debug(f"Running requests: {[(req.request_id, req.status, req.ready_for_next_step) for req in self._running_requests.values()]}") - inflight_tokens = 0 batch: List[Request] = [] From 4de2713d6efe6fd3f0921b2de9ef6b92e04361cc Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 02:39:54 +0000 Subject: [PATCH 30/50] fix[chunked-prefill]: add logs --- src/parallax/server/executor/base_executor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index de7e6b34..08405eeb 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -299,6 +299,7 @@ def recv_requests_from_peer(self) -> Tuple[List[Request], str]: if recv_req is not None and len(recv_req) > 0: for req in recv_req: if req.hidden_states is not None: + logger.debug(f"recv request {req.request_id} hidden_states.length: {req.hidden_states.size()}") if req.hidden_states.dtype != self.dtype: logger.debug( f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}" From 8a930a6d70d689fbbfa5e82fad90172dd53bca2d Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 02:52:31 +0000 Subject: [PATCH 31/50] fix[chunked-prefill]: admit req update if req exists --- src/parallax/server/scheduler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index e760ee62..4161b792 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -239,8 +239,7 @@ def admit_requests(self): req = self._wait_queue.popleft() rid = req.request_id if rid in self._running_requests: - self._running_requests[rid].ready_for_next_step = True - self._running_requests[rid].last_updated_time = time.time() + self._running_requests[rid] = req logger.debug(f"Request {rid} already in running requests. update ready_for_next_step and last_updated_time.") continue From 2b5f9ee5f163afc8b3b76b9a7a5dbd3eb3ac0748 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 03:23:02 +0000 Subject: [PATCH 32/50] fix[chunked-prefill]: add param for handle-input-reqs --- src/parallax/server/executor/base_executor.py | 4 ++-- src/parallax/server/executor/sglang_executor.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 08405eeb..1ab90a0e 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -211,7 +211,7 @@ def __init__( ) @abstractmethod - def handle_input_requests(self, requests: List[Request]): + def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): """Update requests states and status in scheduler and cache manager.""" @abstractmethod @@ -479,7 +479,7 @@ def run_loop(self): if self.enable_weight_refit: self.check_and_refit_weight(refit_weight_path) - self.handle_input_requests(received_requests) + self.handle_input_requests(received_requests, from_previous_peer=True) # Send abort signals to P2P server to broadcast to all nodes if len(self.finished_batch) > 0 and self.tp_rank == 0: self.send_to_peer_socket.send_multipart( diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 4e791c82..46066271 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -357,7 +357,7 @@ def prepare_next_batch_requests( logger.debug(f"sglang_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid in requests") return base_chunked, base_to_forward - def handle_input_requests(self, requests: List[Request]): + def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): """Update requests states and status in scheduler and cache manager.""" if self.tp_size > 1: requests = self._tensor_parallel_broadcast_pyobj(requests) @@ -461,7 +461,7 @@ def handle_input_requests(self, requests: List[Request]): self.release_and_evict_request(req.request_id) if not self.is_last_peer and not req.abort: self.finished_batch.append(req) - elif self.chunked_req is not None and self.chunked_req.rid == req.request_id and self.chunked_req.is_chunked > 0: + elif self.chunked_req is not None and self.chunked_req.rid == req.request_id and self.chunked_req.is_chunked > 0 and not from_previous_peer: self.chunked_req.is_chunked -= 1 req.status = RequestStatus.PREFILLING continue From f070057de3065b97d0e25a0e3dbd2c420563c995 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Sun, 1 Feb 2026 03:47:33 +0000 Subject: [PATCH 33/50] fix[chunked-prefill]: every req can admit once per admit --- src/parallax/server/scheduler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 4161b792..8f81cbb8 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -21,7 +21,7 @@ import time from collections import OrderedDict, deque -from typing import Deque, Dict, List, Optional +from typing import Deque, Dict, List, Optional, Set from parallax.server.cache_manager import CacheManager from parallax.server.request import InitialRequest, Request, RequestStatus @@ -234,10 +234,15 @@ def admit_requests(self): """Move requests from wait queue into running (inflight) set, up to capacity. Pushes admitted requests directly into the running set. + Each request is updated or added to running at most once per call. """ + seen_this_call: Set[str] = set() while self._wait_queue and len(self._running_requests) < self.max_batch_size: req = self._wait_queue.popleft() rid = req.request_id + if rid in seen_this_call: + continue + seen_this_call.add(rid) if rid in self._running_requests: self._running_requests[rid] = req logger.debug(f"Request {rid} already in running requests. update ready_for_next_step and last_updated_time.") From 80e80f32d2ab6e83d7e2246eab99cab03ae85439 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 01:14:01 +0000 Subject: [PATCH 34/50] feat[chunked-prefill]: base chunked-prefill for multi-gpu --- src/parallax/server/executor/base_executor.py | 26 ++++++-- .../server/executor/sglang_executor.py | 64 ++++++------------- src/parallax/server/scheduler.py | 6 +- src/parallax/sglang/batch_info.py | 28 ++++++-- 4 files changed, 66 insertions(+), 58 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 1ab90a0e..621e9ad8 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -299,7 +299,9 @@ def recv_requests_from_peer(self) -> Tuple[List[Request], str]: if recv_req is not None and len(recv_req) > 0: for req in recv_req: if req.hidden_states is not None: - logger.debug(f"recv request {req.request_id} hidden_states.length: {req.hidden_states.size()}") + logger.debug( + f"recv request {req.request_id} hidden_states.length: {req.hidden_states.size()}" + ) if req.hidden_states.dtype != self.dtype: logger.debug( f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}" @@ -566,15 +568,23 @@ def run_loop(self): # 8. Dispatch to the appropriate destination if to_forward_reqs or chunked_reqs: logger.debug(f"dispatch to_forward and chunked_reqs to next peer.") - if self.is_last_peer and self.is_first_peer and (to_forward_reqs is not None and len(to_forward_reqs) > 0): + if ( + self.is_last_peer + and self.is_first_peer + and (to_forward_reqs is not None and len(to_forward_reqs) > 0) + ): # Single node: handle to_forward locally logger.debug(f"Handle {len(to_forward_reqs)} to_forward requests.") self.handle_input_requests(to_forward_reqs) elif self.tp_rank == 0: # Send to_forward to next peer (do not send chunked_reqs if self is last_peer) - logger.debug(f"self is last_peer: {self.is_last_peer}, self is first_peer: {self.is_first_peer}") + logger.debug( + f"self is last_peer: {self.is_last_peer}, self is first_peer: {self.is_first_peer}" + ) if not self.is_last_peer: - logger.debug(f"Send {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs to next peer.") + logger.debug( + f"Send {len(to_forward_reqs + chunked_reqs)} to_forward and chunked_reqs to next peer." + ) self.send_to_peer_socket.send_multipart( [ b"forward", @@ -588,11 +598,15 @@ def run_loop(self): f"in {(time.time() - start_time) * 1000:.3f} ms" ) elif to_forward_reqs is not None and len(to_forward_reqs) > 0: - logger.debug(f"Send {len(to_forward_reqs)} to_forward to next peer.") + logger.debug( + f"Send {len(to_forward_reqs)} to_forward to next peer." + ) self.send_to_peer_socket.send_multipart( [ b"forward", - request_to_proto(to_forward_reqs, self.device).SerializeToString(), + request_to_proto( + to_forward_reqs, self.device + ).SerializeToString(), ] ) logger.debug( diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 46066271..044ddbab 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -319,7 +319,6 @@ def check_lora_server_args(self): ), "--max-lora-chunk-size must be a power of 2 between 16 and 128." def stash_chunked_request(self, req: Req): - logger.debug(f"req.fill_ids_size: {len(req.fill_ids)}") # #endregion if req.req_pool_idx is None: logger.warning( @@ -343,8 +342,14 @@ def prepare_next_batch_requests( base_chunked, base_to_forward = super().prepare_next_batch_requests( requests, batch_output, context_lengths ) - if self.chunked_req is None or self.chunked_req.is_chunked <= 0 or self.chunked_req.rid not in [req.request_id for req in requests]: - logger.debug(f"sglang_executor: prepare_next_batch_requests: return base_chunked and base_to_forward because chunked_req is None or is_chunked <= 0 or rid not in requests") + if ( + self.chunked_req is None + or self.chunked_req.is_chunked <= 0 + or self.chunked_req.rid not in [req.request_id for req in requests] + ): + logger.debug( + f"sglang_executor: prepare_next_batch_requests: return base_chunked and base_to_forward because chunked_req is None or is_chunked <= 0 or rid not in requests" + ) return base_chunked, base_to_forward chunked_rid = self.chunked_req.rid self.stash_chunked_request(self.chunked_req) @@ -354,7 +359,9 @@ def prepare_next_batch_requests( base_chunked.append(req) break base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] - logger.debug(f"sglang_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid in requests") + logger.debug( + f"sglang_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid in requests" + ) return base_chunked, base_to_forward def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): @@ -461,10 +468,15 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo self.release_and_evict_request(req.request_id) if not self.is_last_peer and not req.abort: self.finished_batch.append(req) - elif self.chunked_req is not None and self.chunked_req.rid == req.request_id and self.chunked_req.is_chunked > 0 and not from_previous_peer: - self.chunked_req.is_chunked -= 1 - req.status = RequestStatus.PREFILLING - continue + elif ( + self.chunked_req is not None + and self.chunked_req.rid == req.request_id + and self.chunked_req.is_chunked > 0 + and not from_previous_peer + ): + self.chunked_req.is_chunked -= 1 + req.status = RequestStatus.PREFILLING + continue else: # This is an active request, add it to the scheduler queue to be processed. self.scheduler.enque_request(req) @@ -498,7 +510,6 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: if self.cur_batch and self.cur_batch.forward_mode.is_extend(): if self.cur_batch.chunked_req is not None and self.cur_batch.chunked_req.is_chunked > 0: - logger.debug(f"exclude chunked_req {self.cur_batch.chunked_req.rid} from running_batch") chunked_req_to_exclude.add(self.cur_batch.chunked_req) if self.cur_batch: @@ -526,43 +537,8 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: # Return appropriate output based on peer position if return_decoded_tokens: - # Debug: log running_batch vs prepared_inputs["requests"] to detect - # "running_batch still has previous request" causing token mis-assignment - running_batch_size = ( - 0 if self.running_batch.is_empty() else len(self.running_batch.reqs) - ) - running_batch_rids = ( - [] - if self.running_batch.is_empty() - else [req.rid for req in self.running_batch.reqs] - ) - requests_len = len(requests) - requests_ids = [getattr(r, "request_id", str(r)) for r in requests] - logger.debug( - "[ChunkedPrefill-Debug] process_batch decode: running_batch size=%s rids=%s, " - "prepared_inputs requests len=%s request_ids=%s", - running_batch_size, - running_batch_rids, - requests_len, - requests_ids, - ) - if running_batch_size != requests_len: - logger.warning( - "[ChunkedPrefill-Debug] MISMATCH: running_batch has %s reqs but prepared_inputs " - "has %s requests; token indices may be assigned to wrong request. " - "running_batch_rids=%s, request_ids=%s", - running_batch_size, - requests_len, - running_batch_rids, - requests_ids, - ) - # Last peer: sample and return token IDs next_token_ids = self.model_runner.sample(logits_output, forward_batch) - logger.debug( - "[ChunkedPrefill-Debug] process_batch after sample: len(next_token_ids)=%s", - len(next_token_ids), - ) # Only compute probs if any request in the batch needs it # Check if any InitialRequest has return_probs=True diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 8f81cbb8..50c59466 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -245,7 +245,9 @@ def admit_requests(self): seen_this_call.add(rid) if rid in self._running_requests: self._running_requests[rid] = req - logger.debug(f"Request {rid} already in running requests. update ready_for_next_step and last_updated_time.") + logger.debug( + f"Request {rid} already in running requests. update ready_for_next_step and last_updated_time." + ) continue # Check kv cache pool @@ -319,7 +321,7 @@ def form_batch(self) -> List[Request]: self.admit_requests() if not self._running_requests: return [] - + inflight_tokens = 0 batch: List[Request] = [] diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 24dd4381..785ceb04 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -94,12 +94,16 @@ def transform_requests_to_sglang( # Save rid of chunked req added this round so we can reorder can_run_list to match old_requests. chunked_rid = executor.chunked_req.rid if executor.chunked_req is not None else None - if executor.chunked_req is not None and executor.chunked_req.rid in [req.request_id for req in old_requests]: + if executor.chunked_req is not None and executor.chunked_req.rid in [ + req.request_id for req in old_requests + ]: logger.debug(f"before add_chunked_req, chunked_req is not None") executor.chunked_req.init_next_round_input(page_tree_cache) executor.chunked_req = adder.add_chunked_req(executor.chunked_req) if executor.chunked_req is None: - logger.debug(f"chunked_req{chunked_rid} has chunk all after add_chunked_req, chunked_req is None") + logger.debug( + f"chunked_req{chunked_rid} has chunk all after add_chunked_req, chunked_req is None" + ) reqs = [] logger.debug(f"old_req size: {len(old_requests)}") @@ -164,14 +168,19 @@ def transform_requests_to_sglang( executor.chunked_req = adder.new_chunked_req logger.debug(f"new chunked_req is {executor.chunked_req}") - if executor.chunked_req is not None and executor.chunked_req.rid in [req.request_id for req in old_requests]: + if executor.chunked_req is not None and executor.chunked_req.rid in [ + req.request_id for req in old_requests + ]: executor.chunked_req.is_chunked += 1 # Reorder so returned list follows old_requests order and each element is the Req # that corresponds to that old_req (same rid). Use rid to map instead of assuming # can_run_list order, so the relationship with reqs is explicit. can_run_list = adder.can_run_list - if chunked_rid is None and (executor.chunked_req is None or executor.chunked_req.rid not in [req.request_id for req in old_requests]): + if chunked_rid is None and ( + executor.chunked_req is None + or executor.chunked_req.rid not in [req.request_id for req in old_requests] + ): return can_run_list rid_to_req = {req.rid: req for req in can_run_list} reordered: List[Req] = [rid_to_req[old_req.request_id] for old_req in old_requests] @@ -206,10 +215,17 @@ def dummy_evict(*args): model_config=model_runner.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, - chunked_req=executor.chunked_req if executor.chunked_req is not None and executor.chunked_req.rid in [req.request_id for req in requests] else None, + chunked_req=( + executor.chunked_req + if executor.chunked_req is not None + and executor.chunked_req.rid in [req.request_id for req in requests] + else None + ), ) schedule_batch.prepare_for_extend() - if executor.chunked_req is not None and executor.chunked_req.rid in [req.request_id for req in requests]: + if executor.chunked_req is not None and executor.chunked_req.rid in [ + req.request_id for req in requests + ]: logger.debug( f"chunked_req.rid={executor.chunked_req.rid}, chunked_req.req_pool_idx: {executor.chunked_req.req_pool_idx}" ) From 635ed65a6bd873ccd78dfd2d0bb7210688f4735c Mon Sep 17 00:00:00 2001 From: wasamtc <81901970+wasamtc@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:21:52 +0800 Subject: [PATCH 35/50] Restore cron schedule for automatic image builds Re-enable scheduled builds with a cron job. --- .github/workflows/build-images.yaml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-images.yaml b/.github/workflows/build-images.yaml index 751dc8ad..64844408 100644 --- a/.github/workflows/build-images.yaml +++ b/.github/workflows/build-images.yaml @@ -1,9 +1,10 @@ name: Build Images on: - # schedule: # 注释掉定时任务以停止自动构建 - # - cron: '0 0 * * *' - workflow_dispatch: # 保留手动触发开关 + schedule: + - cron: '0 0 * * *' + workflow_dispatch: + env: IMAGE_NAME: parallax @@ -28,6 +29,7 @@ jobs: driver: remote endpoint: tcp://buildkit-buildkit-service.arc-systems:1234 + - name: Log in to Docker Hub uses: docker/login-action@v3 with: From a1c81effad22e4322b26ad16f1c8263a2faab358 Mon Sep 17 00:00:00 2001 From: wasamtc <81901970+wasamtc@users.noreply.github.com> Date: Mon, 2 Feb 2026 09:22:12 +0800 Subject: [PATCH 36/50] Enable scheduled workflow for building Spark image --- .github/workflows/build-spark-image.yaml | 51 +++++++++++++++++++----- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build-spark-image.yaml b/.github/workflows/build-spark-image.yaml index 1474995d..0d3b60d6 100644 --- a/.github/workflows/build-spark-image.yaml +++ b/.github/workflows/build-spark-image.yaml @@ -1,9 +1,10 @@ name: Build Spark Image on: - # schedule: # 注释掉或删除此部分以停止自动触发 - # - cron: '0 3 * * *' - workflow_dispatch: # 仅保留手动触发 + schedule: + - cron: '0 3 * * *' + workflow_dispatch: + env: IMAGE_NAME: parallax @@ -22,9 +23,41 @@ jobs: variant: [spark] steps: - - name: Workflow disabled – no actions performed - run: | - echo "==================================================" - echo "This workflow has been DISABLED." - echo "No Docker images are built or pushed." - echo "==================================================" + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + with: + driver: remote + endpoint: tcp://buildkit-buildkit-service.arc-systems:1234 + + + - name: Log in to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.NAMESPACE }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch,suffix=-${{ matrix.variant }} + type=ref,event=pr,suffix=-${{ matrix.variant }} + type=raw,value=latest-${{ matrix.variant }},enable={{is_default_branch}} + + - name: Build and push Docker image + id: build + uses: docker/build-push-action@v5 + with: + context: . + file: ./docker/Dockerfile.${{ matrix.variant }} + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + platforms: linux/arm64 From 383e5bce1a8a6e333af4417795fc4194c47ce59a Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 01:31:44 +0000 Subject: [PATCH 37/50] fix[mlx-executor]: add param from mlx-executor --- src/parallax/server/executor/mlx_executor.py | 2 ++ src/parallax/server/executor/vllm_executor.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 268e035c..41206c9a 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -91,6 +91,8 @@ def __init__( # Weight Refit enable_weight_refit: Optional[bool] = False, weight_refit_mode: Optional[str] = "disk", + # Chunked prefill (SGL-only; accepted here for factory config compatibility) + chunked_prefill_size: Optional[int] = None, # Pipe communication conn: Optional[List[Any]] = [], ): diff --git a/src/parallax/server/executor/vllm_executor.py b/src/parallax/server/executor/vllm_executor.py index c5d7de08..8c96c470 100755 --- a/src/parallax/server/executor/vllm_executor.py +++ b/src/parallax/server/executor/vllm_executor.py @@ -89,6 +89,8 @@ def __init__( weight_refit_mode: Optional[str] = "disk", # Routed experts enable_return_routed_experts: bool = False, + # Chunked prefill (SGL-only; accepted here for factory config compatibility) + chunked_prefill_size: Optional[int] = None, # Pipe communication conn: Optional[List[Any]] = [], ): From 3fcd0745814df0d06058d7a20f473320a62df0c4 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 02:03:09 +0000 Subject: [PATCH 38/50] fix[admit-req]: reenque already exists reqs --- src/parallax/server/executor/sglang_executor.py | 1 + src/parallax/server/scheduler.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 044ddbab..d4c9421f 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -476,6 +476,7 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo ): self.chunked_req.is_chunked -= 1 req.status = RequestStatus.PREFILLING + self.scheduler.evict_request(req.request_id) continue else: # This is an active request, add it to the scheduler queue to be processed. diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index fa12bd17..828c6f26 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -21,7 +21,7 @@ import time from collections import OrderedDict, deque -from typing import Deque, Dict, List, Optional, Set +from typing import Deque, Dict, List, Optional from parallax.server.cache_manager import CacheManager from parallax.server.request import InitialRequest, Request, RequestStatus @@ -244,18 +244,15 @@ def admit_requests(self): Pushes admitted requests directly into the running set. Each request is updated or added to running at most once per call. """ - seen_this_call: Set[str] = set() - while self._wait_queue and len(self._running_requests) < self.max_batch_size: + # One pass over wait_queue to avoid infinite loop when all front items are already in running_requests + initial_len = len(self._wait_queue) + for _ in range(initial_len): + if not self._wait_queue or len(self._running_requests) >= self.max_batch_size: + break req = self._wait_queue.popleft() rid = req.request_id - if rid in seen_this_call: - continue - seen_this_call.add(rid) if rid in self._running_requests: - self._running_requests[rid] = req - logger.debug( - f"Request {rid} already in running requests. update ready_for_next_step and last_updated_time." - ) + self._wait_queue.append(req) continue # Check kv cache pool From d2b10c6370de65587310d00c2a306fbb05223237 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 02:05:40 +0000 Subject: [PATCH 39/50] fix[mac-test]: reenque already exists reqs --- src/parallax/server/executor/mlx_executor.py | 2 +- src/parallax/server/executor/vllm_executor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 41206c9a..cfdb7478 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -272,7 +272,7 @@ def _tensor_parallel_broadcast_pyobj(self, broadcast_obj): data = pickle.loads(np.array(data_arr).tobytes()) return data - def handle_input_requests(self, requests: List[Request]): + def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): """Update requests states and status in scheduler and cache manager.""" if self.tp_size > 1: requests = self._tensor_parallel_broadcast_pyobj(requests) diff --git a/src/parallax/server/executor/vllm_executor.py b/src/parallax/server/executor/vllm_executor.py index 8c96c470..9751dbfd 100755 --- a/src/parallax/server/executor/vllm_executor.py +++ b/src/parallax/server/executor/vllm_executor.py @@ -206,7 +206,7 @@ def check_lora_server_args(self): "--enable-lora is set to False, any provided lora_paths will be ignored." ) - def handle_input_requests(self, requests: List[Request]): + def handle_input_requests(self, requests: List[Request], from_previous_peer: bool = False): """Update requests states and status in scheduler and cache manager.""" if self.tp_size > 1: requests = self._tensor_parallel_broadcast_pyobj(requests) From 0fcbc5a1481b2fe8c4753ff28cffbb7420040ee6 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 00:15:36 -0800 Subject: [PATCH 40/50] feat[chunked-prefill]: change chunked-prefill for mac --- src/parallax/server/executor/base_executor.py | 4 +- src/parallax/server/executor/mlx_executor.py | 101 +++++++++++++++--- src/parallax/server/request.py | 3 + src/parallax/utils/mac_prefill_addr.py | 68 ++++++++++++ 4 files changed, 160 insertions(+), 16 deletions(-) create mode 100644 src/parallax/utils/mac_prefill_addr.py diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index ddcda724..c4bf96ec 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -299,8 +299,10 @@ def recv_requests_from_peer(self) -> Tuple[List[Request], str]: if recv_req is not None and len(recv_req) > 0: for req in recv_req: if req.hidden_states is not None: + size_attr = getattr(req.hidden_states, "size", None) + hidden_size = size_attr() if callable(size_attr) else size_attr logger.debug( - f"recv request {req.request_id} hidden_states.length: {req.hidden_states.size()}" + f"recv request {req.request_id} hidden_states.length: {hidden_size}" ) if req.hidden_states.dtype != self.dtype: logger.debug( diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index cfdb7478..7334873f 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -28,6 +28,7 @@ pad_inputs, ) from parallax_utils.logging_config import get_logger +from parallax.utils.mac_prefill_addr import AddReqResult, MACPrefillAdder logger = get_logger(__name__) @@ -234,14 +235,13 @@ def __init__( # Prefix Cache Manager self.enable_prefix_cache = enable_prefix_cache - # self.prefix_cache = RadixCache( - # num_kv_heads=num_key_value_heads, - # head_dim=head_dim, - # head_dim_v=v_head_dim, - # num_layers=self.num_shard_layers, - # dtype=self.dtype, - # page_size=1, - # ) + if chunked_prefill_size is not None and chunked_prefill_size > 0: + # up align to page size + self.chunked_prefill_size = (chunked_prefill_size + self.cache_manager.block_size - 1) // self.cache_manager.block_size * self.cache_manager.block_size + else: + self.chunked_prefill_size = None + self.chunked_req = None + self.chunked_req_offset = 0 logger.debug( f"mlx_executor initialized; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}, total memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" ) @@ -304,7 +304,10 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo continue if not req.abort and req.next_token_id is not None: - original_req.commit_new_token(req.next_token_id) + if self.chunked_req is not None and req.request_id == self.chunked_req.rid and self.chunked_req.is_chunked > 0: + original_req.status = RequestStatus.PREFILLING + else: + original_req.commit_new_token(req.next_token_id) if len(req.routing_table) > 0: original_req.routing_table = req.routing_table @@ -313,7 +316,12 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo if req.abort: original_req.abort = True - if self.scheduler.check_and_update_request_status(original_req): + + if self.chunked_req is not None and req.request_id == self.chunked_req.rid and self.chunked_req.is_chunked > 0: + self.chunked_req.is_chunked -= 1 + self.scheduler.enque_request(original_req) + continue + elif self.scheduler.check_and_update_request_status(original_req): self.cache_manager.release_request(original_req.request_id) logger.debug( f"Released resources for finished request {req.request_id}, " @@ -357,10 +365,10 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo req, IntermediateRequest ), "Non-first peers must receive IntermediateRequests." if req.is_finished or req.hidden_states is None: - if self.enable_prefix_cache: - keys, values = self.cache_manager.gather_kv_cache(req.request_id) - self.prefix_cache.cache_finished_request(req, keys, values) - self.prefix_cache.evict_request(req.request_id) + # if self.enable_prefix_cache: + # keys, values = self.cache_manager.gather_kv_cache(req.request_id) + # self.prefix_cache.cache_finished_request(req, keys, values) + # self.prefix_cache.evict_request(req.request_id) self.cache_manager.release_request(req.request_id) logger.debug( @@ -370,9 +378,39 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo self.scheduler.evict_request(req.request_id) if not self.is_last_peer and not req.abort: self.finished_batch.append(req) + elif ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + and not from_previous_peer + ): + self.chunked_req.is_chunked -= 1 + req.status = RequestStatus.PREFILLING + self.scheduler.evict_request(req.request_id) + continue else: # This is an active request, add it to the scheduler queue to be processed. self.scheduler.enque_request(req) + + def prepare_next_batch_requests(self, requests: List[Request], batch_output: Any, context_lengths: Any) -> Tuple[List[Request], List[Request]]: + """Prepares a batch of requests for the next stage of the pipeline.""" + base_chunked, base_to_forward = super().prepare_next_batch_requests(requests, batch_output, context_lengths) + if ( + self.chunked_req is None + or self.chunked_req.is_chunked <= 0 + or self.chunked_req.rid not in [req.request_id for req in requests] + ): + logger.debug(f"mlx_executor: prepare_next_batch_requests: return base_chunked{len(base_chunked)} and base_to_forward{len(base_to_forward)} because chunked_req is None or is_chunked <= 0 or rid not in requests") + return base_chunked, base_to_forward + chunked_rid = self.chunked_req.rid + for req in base_to_forward: + if req.request_id == chunked_rid: + req.status = RequestStatus.PREFILLING + base_chunked.append(req) + break + base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] + logger.debug(f"mlx_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid in requests") + return base_chunked, base_to_forward def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True): """Process a batch of requests in MLX.""" @@ -522,6 +560,39 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A batch_size = len(batched_requests) if batch_size == 0: return None + + original_batched_requests = batched_requests + logger.debug(f"original_batched_requests_size: {len(original_batched_requests)}") + + adder = MACPrefillAdder(self.cache_manager.block_size, self.chunked_prefill_size, self.chunked_req_offset) + + chunked_rid = self.chunked_req.rid if self.chunked_req is not None else None + + if self.chunked_req is not None and chunked_rid in [req.request_id for req in original_batched_requests]: + logger.debug(f"before add_chunked_req, chunked_req is not None") + self.chunked_req = adder.add_chunked_req(self.chunked_req) + if self.chunked_req is None: + logger.debug(f"chunked_req{chunked_rid} has chunk all after add_chunked_req, chunked_req is None") + + for old_req in original_batched_requests: + if chunked_rid is not None and old_req.request_id == chunked_rid: + continue + res = adder.add_one_req(old_req) + if res != AddReqResult.CONTINUE: + logger.debug(f"macprefilladder has no token to add to prefill batch") + break + + if adder.new_chunked_req is not None: + self.chunked_req = adder.new_chunked_req + logger.debug(f"new chunked_req is {self.chunked_req.rid}") + + if self.chunked_req is not None and self.chunked_req.rid in [req.request_id for req in original_batched_requests]: + self.chunked_req.is_chunked += 1 + + can_run_by_id = {req.request_id: req for req in adder.can_run_list} + batched_requests = [can_run_by_id[req.request_id] for req in original_batched_requests if req.request_id in can_run_by_id] + self.chunked_req_offset = adder.chunked_req_offset + logger.debug(f"after add_one_req, batched_requests size: {len(batched_requests)}") h_or_tokens_list = [] block_tables_list = [] @@ -652,7 +723,7 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A "h_or_tokens": padded_inputs, "cache": self.cache_manager.get_caches(), "mask": mask, - "requests": batched_requests, + "requests": original_batched_requests, "block_tables": block_tables_tensor, "context_lengths": context_lengths_tensor, "slot_mapping": slot_mapping_tensor, diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index f7c9bb90..e2d0bb9b 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -109,6 +109,9 @@ def __init__( self.last_updated_time: Optional[float] = None self.lora_id: Optional[str] = None self.lora_path = lora_path + self.is_chunked = 0 + self.rid = self.request_id + self.origin_input_ids = input_ids @property def is_finished(self) -> bool: diff --git a/src/parallax/utils/mac_prefill_addr.py b/src/parallax/utils/mac_prefill_addr.py new file mode 100644 index 00000000..77bf5b56 --- /dev/null +++ b/src/parallax/utils/mac_prefill_addr.py @@ -0,0 +1,68 @@ +from enum import Enum, auto + +from mpmath import extend +from parallax.server.request import Request + + +class AddReqResult(Enum): + CONTINUE = auto() # Continue to add requests + NO_TOKEN = auto() # No token left + OTHER = auto() # Other reasons to stop adding requests + +class MACPrefillAdder: + """ + MACPrefillAdder is a class that adds prefill requests to the MAC prefill batch. + """ + def __init__( + self, + page_size: int, + rem_chunk_tokens: int, + chunked_req_offset: int + ): + self.page_size = page_size + self.rem_chunk_tokens = rem_chunk_tokens + self.can_run_list = [] + self.new_chunked_req = None + self.chunked_req_offset = chunked_req_offset + def add_chunked_req(self, chunked_req: Request) -> Request: + if chunked_req is None: + return None + extend_input_len = len(chunked_req.origin_input_ids) - self.chunked_req_offset + truncated = extend_input_len > self.rem_chunk_tokens + self.chunked_req_offset += min(self.rem_chunk_tokens, extend_input_len) + self.chunked_req_offset = min(self.chunked_req_offset, len(chunked_req.origin_input_ids)) + chunked_req.input_ids = chunked_req.origin_input_ids[: self.chunked_req_offset] + chunked_req.total_len = self.chunked_req_offset + self.can_run_list.append(chunked_req) + self.rem_chunk_tokens -= min(self.rem_chunk_tokens, extend_input_len) + self.chunked_req_offset = 0 if not truncated else self.chunked_req_offset + return chunked_req if truncated else None + + def add_one_req(self, req: Request) -> AddReqResult: + extend_input_len = len(req.origin_input_ids) - self.chunked_req_offset + # align to page size + extend_input_len = (extend_input_len + self.page_size - 1) // self.page_size * self.page_size + if self.rem_chunk_tokens is None or extend_input_len <= self.rem_chunk_tokens: + self.chunked_req_offset += extend_input_len + self.chunked_req_offset = min(self.chunked_req_offset, len(req.origin_input_ids)) + req.input_ids = req.origin_input_ids[: self.chunked_req_offset] + req.total_len = self.chunked_req_offset + self.can_run_list.append(req) + self.chunked_req_offset = 0 + self.rem_chunk_tokens -= extend_input_len + else: + # make sure at least one page is available + trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size + if trunc_len <= 0: + return AddReqResult.OTHER + extend_input_len = trunc_len + self.chunked_req_offset = extend_input_len + self.chunked_req_offset = min(self.chunked_req_offset, len(req.origin_input_ids)) + req.input_ids = req.origin_input_ids[: self.chunked_req_offset] + req.total_len = self.chunked_req_offset + self.can_run_list.append(req) + self.new_chunked_req = req + self.rem_chunk_tokens -= extend_input_len + if self.rem_chunk_tokens is None or self.rem_chunk_tokens <= 0: + return AddReqResult.OTHER + return AddReqResult.CONTINUE \ No newline at end of file From 3b63715bbd32100b8206bba89a414b6176bc4605 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 01:36:06 -0800 Subject: [PATCH 41/50] fix[chunked-prefill]:prepare next req use origin_input_ids instead of input_ids --- src/parallax/server/executor/base_executor.py | 4 ++-- src/parallax/server/executor/mlx_executor.py | 2 ++ src/parallax/server/request.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index c4bf96ec..08a77267 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -790,7 +790,7 @@ def _prepare_next_single_request( request_id=request.request_id, status=RequestStatus.DECODING, current_position=request.total_length + 1, - input_ids=request.input_ids, + input_ids=request.origin_input_ids, hidden_states=hidden_states, next_token_id=next_token_id, routing_table=request.routing_table, @@ -809,7 +809,7 @@ def _prepare_next_single_request( request_id=request.request_id, status=RequestStatus.DECODING, # Last peer always changes status to DECODING current_position=request.total_length, - input_ids=request.input_ids, + input_ids=request.origin_input_ids, hidden_states=hidden_states, next_token_id=next_token_id, routing_table=request.routing_table, diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 7334873f..2cce1447 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -319,6 +319,7 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo if self.chunked_req is not None and req.request_id == self.chunked_req.rid and self.chunked_req.is_chunked > 0: self.chunked_req.is_chunked -= 1 + self.cache_manager.release_request(original_req.request_id) self.scheduler.enque_request(original_req) continue elif self.scheduler.check_and_update_request_status(original_req): @@ -386,6 +387,7 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo ): self.chunked_req.is_chunked -= 1 req.status = RequestStatus.PREFILLING + self.cache_manager.release_request(req.request_id) self.scheduler.evict_request(req.request_id) continue else: diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index e2d0bb9b..60bfafee 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -338,7 +338,7 @@ def from_initial_request( return IntermediateRequest( request_id=initial_request.request_id, status=initial_request.status, - input_ids=initial_request.input_ids, + input_ids=initial_request.origin_input_ids, next_token_id=next_token_id, current_position=initial_request.total_length, hidden_states=hidden_states, @@ -365,7 +365,7 @@ def from_intermediate_request( request_id=old_request.request_id, status=old_request.status, current_position=old_request.total_length, - input_ids=old_request.input_ids, + input_ids=old_request.origin_input_ids, next_token_id=old_request.next_token_id, hidden_states=new_hidden_states, routing_table=old_request.routing_table, From 4fda2468f163fb73f880fcd88264fefc6bcba8e4 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 01:44:16 -0800 Subject: [PATCH 42/50] fix[chunked-prefill]:enqueue req use origin_input_ids instead of input_ids --- src/parallax/server/scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 828c6f26..d2d7a562 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -126,6 +126,8 @@ def enque_request(self, request: Request | str): request.ready_for_next_step = True request.last_updated_time = time.time() + if request.origin_input_ids is not None: + request.input_ids = request.origin_input_ids # TODO: Handle chunked prefill. if request.is_decoding: rid = request.request_id From 92def224bbdf4980f1603de3a6c096e2d39ee8ef Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 18:36:03 -0800 Subject: [PATCH 43/50] fix[chunked-prefill]: get matched tokens before chunked --- src/parallax/server/executor/mlx_executor.py | 9 +++-- src/parallax/server/scheduler.py | 42 ++++++++++---------- src/parallax/utils/mac_prefill_addr.py | 37 ++++++++--------- 3 files changed, 45 insertions(+), 43 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 2cce1447..c4ce80a4 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -241,7 +241,6 @@ def __init__( else: self.chunked_prefill_size = None self.chunked_req = None - self.chunked_req_offset = 0 logger.debug( f"mlx_executor initialized; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}, total memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" ) @@ -565,8 +564,11 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A original_batched_requests = batched_requests logger.debug(f"original_batched_requests_size: {len(original_batched_requests)}") - - adder = MACPrefillAdder(self.cache_manager.block_size, self.chunked_prefill_size, self.chunked_req_offset) + adder = MACPrefillAdder( + self.cache_manager.block_size, + self.chunked_prefill_size, + self.cache_manager + ) chunked_rid = self.chunked_req.rid if self.chunked_req is not None else None @@ -593,7 +595,6 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A can_run_by_id = {req.request_id: req for req in adder.can_run_list} batched_requests = [can_run_by_id[req.request_id] for req in original_batched_requests if req.request_id in can_run_by_id] - self.chunked_req_offset = adder.chunked_req_offset logger.debug(f"after add_one_req, batched_requests size: {len(batched_requests)}") h_or_tokens_list = [] diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index d2d7a562..086fadb0 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -257,27 +257,27 @@ def admit_requests(self): self._wait_queue.append(req) continue - # Check kv cache pool - if self.cache_manager is not None: - if not self.cache_manager.has_request(req.request_id): - # TODO: Handle chunked prefill, and support preemption. - # Pass input_ids for prefix cache matching - token_ids = getattr(req, "input_ids", None) - success, matched_tokens = self.cache_manager.allocate_request( - req.request_id, req.total_length, token_ids=token_ids - ) - if not success: - logger.warning( - f"Request {rid} can't be admit to running batch due to KV cache size." - ) - # Put back to wait queue if allocation fails - self._wait_queue.appendleft(req) - # Stop admitting since we are out of memory - break - if matched_tokens > 0: - logger.debug( - f"Request {rid} matched {matched_tokens} tokens from prefix cache" - ) + # # Check kv cache pool + # if self.cache_manager is not None: + # if not self.cache_manager.has_request(req.request_id): + # # TODO: Handle chunked prefill, and support preemption. + # # Pass input_ids for prefix cache matching + # token_ids = getattr(req, "input_ids", None) + # success, matched_tokens = self.cache_manager.allocate_request( + # req.request_id, req.total_length, token_ids=token_ids + # ) + # if not success: + # logger.warning( + # f"Request {rid} can't be admit to running batch due to KV cache size." + # ) + # # Put back to wait queue if allocation fails + # self._wait_queue.appendleft(req) + # # Stop admitting since we are out of memory + # break + # if matched_tokens > 0: + # logger.debug( + # f"Request {rid} matched {matched_tokens} tokens from prefix cache" + # ) # Add request to running requests self._running_requests[rid] = req diff --git a/src/parallax/utils/mac_prefill_addr.py b/src/parallax/utils/mac_prefill_addr.py index 77bf5b56..7675b03e 100644 --- a/src/parallax/utils/mac_prefill_addr.py +++ b/src/parallax/utils/mac_prefill_addr.py @@ -1,6 +1,7 @@ from enum import Enum, auto from mpmath import extend +from parallax.server.cache_manager import CacheManager from parallax.server.request import Request @@ -17,38 +18,39 @@ def __init__( self, page_size: int, rem_chunk_tokens: int, - chunked_req_offset: int + cache_manager: CacheManager ): self.page_size = page_size self.rem_chunk_tokens = rem_chunk_tokens self.can_run_list = [] self.new_chunked_req = None - self.chunked_req_offset = chunked_req_offset + self.cache_manager = cache_manager def add_chunked_req(self, chunked_req: Request) -> Request: if chunked_req is None: return None - extend_input_len = len(chunked_req.origin_input_ids) - self.chunked_req_offset + matched_tokens = 0 + if self.cache_manager.prefix_cache is not None: + _, matched_tokens = self.cache_manager.prefix_cache.match_prefix(chunked_req.origin_input_ids) + extend_input_len = len(chunked_req.origin_input_ids) - matched_tokens + extend_input_len = 1 if extend_input_len <= 0 else extend_input_len truncated = extend_input_len > self.rem_chunk_tokens - self.chunked_req_offset += min(self.rem_chunk_tokens, extend_input_len) - self.chunked_req_offset = min(self.chunked_req_offset, len(chunked_req.origin_input_ids)) - chunked_req.input_ids = chunked_req.origin_input_ids[: self.chunked_req_offset] - chunked_req.total_len = self.chunked_req_offset + chunked_req_offset = min(self.rem_chunk_tokens, extend_input_len) + matched_tokens + chunked_req.input_ids = chunked_req.origin_input_ids[: chunked_req_offset] + chunked_req.total_len = chunked_req_offset self.can_run_list.append(chunked_req) self.rem_chunk_tokens -= min(self.rem_chunk_tokens, extend_input_len) - self.chunked_req_offset = 0 if not truncated else self.chunked_req_offset return chunked_req if truncated else None def add_one_req(self, req: Request) -> AddReqResult: - extend_input_len = len(req.origin_input_ids) - self.chunked_req_offset + matched_tokens = 0 + if self.cache_manager.prefix_cache is not None: + _, matched_tokens = self.cache_manager.prefix_cache.match_prefix(req.origin_input_ids) + extend_input_len = len(req.origin_input_ids) - matched_tokens + extend_input_len = 1 if extend_input_len <= 0 else extend_input_len # align to page size extend_input_len = (extend_input_len + self.page_size - 1) // self.page_size * self.page_size if self.rem_chunk_tokens is None or extend_input_len <= self.rem_chunk_tokens: - self.chunked_req_offset += extend_input_len - self.chunked_req_offset = min(self.chunked_req_offset, len(req.origin_input_ids)) - req.input_ids = req.origin_input_ids[: self.chunked_req_offset] - req.total_len = self.chunked_req_offset self.can_run_list.append(req) - self.chunked_req_offset = 0 self.rem_chunk_tokens -= extend_input_len else: # make sure at least one page is available @@ -56,10 +58,9 @@ def add_one_req(self, req: Request) -> AddReqResult: if trunc_len <= 0: return AddReqResult.OTHER extend_input_len = trunc_len - self.chunked_req_offset = extend_input_len - self.chunked_req_offset = min(self.chunked_req_offset, len(req.origin_input_ids)) - req.input_ids = req.origin_input_ids[: self.chunked_req_offset] - req.total_len = self.chunked_req_offset + chunked_req_offset = extend_input_len + matched_tokens + req.input_ids = req.origin_input_ids[: chunked_req_offset] + req.total_len = chunked_req_offset self.can_run_list.append(req) self.new_chunked_req = req self.rem_chunk_tokens -= extend_input_len From a16f853a83f8610442c7d7a43e1ebbe644805b50 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 19:01:57 -0800 Subject: [PATCH 44/50] fix[chunked-prefill]: use effective total_len if it exits --- src/parallax/server/executor/mlx_executor.py | 2 +- src/parallax/server/request.py | 10 ++++++++-- src/parallax/utils/mac_prefill_addr.py | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index c4ce80a4..46f60f72 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -611,7 +611,7 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A token_ids = None if self.enable_prefix_cache and req.input_ids is not None: token_ids = req.input_ids - + logger.debug(f"before allocate_request: {req.request_id}, token_ids length: {len(token_ids)}, req.total_length: {req.total_length}") success, matched_tokens = self.cache_manager.allocate_request( req.request_id, req.total_length, token_ids=token_ids ) diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index 60bfafee..62ae59dd 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -112,6 +112,8 @@ def __init__( self.is_chunked = 0 self.rid = self.request_id self.origin_input_ids = input_ids + # When set (e.g. by chunked prefill), total_length property returns this instead of computed value. + self._effective_total_length: Optional[int] = None @property def is_finished(self) -> bool: @@ -202,7 +204,9 @@ def output_length(self) -> int: @property def total_length(self) -> int: - """Total length of the sequence (input + output).""" + """Total length of the sequence (input + output). Overridable via _effective_total_length (e.g. chunked prefill).""" + if self._effective_total_length is not None: + return self._effective_total_length return self.prompt_len + self.output_length def get_model_input_for_first_peer(self) -> List[int]: @@ -303,7 +307,9 @@ def input_length(self) -> int: @property def total_length(self) -> int: - """Total length of the sequence (input + output).""" + """Total length of the sequence (input + output). Overridable via _effective_total_length (e.g. chunked prefill).""" + if self._effective_total_length is not None: + return self._effective_total_length return self.current_position @classmethod diff --git a/src/parallax/utils/mac_prefill_addr.py b/src/parallax/utils/mac_prefill_addr.py index 7675b03e..32ffc5f8 100644 --- a/src/parallax/utils/mac_prefill_addr.py +++ b/src/parallax/utils/mac_prefill_addr.py @@ -36,7 +36,7 @@ def add_chunked_req(self, chunked_req: Request) -> Request: truncated = extend_input_len > self.rem_chunk_tokens chunked_req_offset = min(self.rem_chunk_tokens, extend_input_len) + matched_tokens chunked_req.input_ids = chunked_req.origin_input_ids[: chunked_req_offset] - chunked_req.total_len = chunked_req_offset + chunked_req._effective_total_length = chunked_req_offset self.can_run_list.append(chunked_req) self.rem_chunk_tokens -= min(self.rem_chunk_tokens, extend_input_len) return chunked_req if truncated else None @@ -60,7 +60,7 @@ def add_one_req(self, req: Request) -> AddReqResult: extend_input_len = trunc_len chunked_req_offset = extend_input_len + matched_tokens req.input_ids = req.origin_input_ids[: chunked_req_offset] - req.total_len = chunked_req_offset + req._effective_total_length = chunked_req_offset self.can_run_list.append(req) self.new_chunked_req = req self.rem_chunk_tokens -= extend_input_len From c4de011d9b2adcfc049f66bd68ac5095eb30b8da Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 22:40:34 -0800 Subject: [PATCH 45/50] fix[chunked-prefill]: use new reqs instead of old chunked_reqs --- src/parallax/server/executor/base_executor.py | 2 ++ src/parallax/server/executor/mlx_executor.py | 21 ++++++++++++++++++- src/parallax/server/scheduler.py | 2 ++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 08a77267..392f1577 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -304,6 +304,8 @@ def recv_requests_from_peer(self) -> Tuple[List[Request], str]: logger.debug( f"recv request {req.request_id} hidden_states.length: {hidden_size}" ) + shape = req.hidden_states.shape + logger.debug(f"recv request {req.request_id} hidden_states.shape: {shape}") if req.hidden_states.dtype != self.dtype: logger.debug( f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}" diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 46f60f72..81d23163 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -564,6 +564,8 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A original_batched_requests = batched_requests logger.debug(f"original_batched_requests_size: {len(original_batched_requests)}") + for req in original_batched_requests: + logger.debug(f"before prepare_prefill_batch, req {req.request_id} hidden_states.shape: {req.hidden_states.shape}") adder = MACPrefillAdder( self.cache_manager.block_size, self.chunked_prefill_size, @@ -573,6 +575,8 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A chunked_rid = self.chunked_req.rid if self.chunked_req is not None else None if self.chunked_req is not None and chunked_rid in [req.request_id for req in original_batched_requests]: + # update chunked_req use old_req + self.chunked_req = [req for req in original_batched_requests if req.request_id == chunked_rid][0] logger.debug(f"before add_chunked_req, chunked_req is not None") self.chunked_req = adder.add_chunked_req(self.chunked_req) if self.chunked_req is None: @@ -596,6 +600,8 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A can_run_by_id = {req.request_id: req for req in adder.can_run_list} batched_requests = [can_run_by_id[req.request_id] for req in original_batched_requests if req.request_id in can_run_by_id] logger.debug(f"after add_one_req, batched_requests size: {len(batched_requests)}") + for req in batched_requests: + logger.debug(f"after add_one_req, can_run_list, req {req.request_id} hidden_states.shape: {req.hidden_states.shape}") h_or_tokens_list = [] block_tables_list = [] @@ -644,9 +650,16 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A h_or_tokens_list.append(req.input_ids) actual_processed_lengths_list.append(len(req.input_ids)) else: + logger.debug(f"intermediate peer, req {req.request_id} hidden_states.shape: {req.hidden_states.shape}") if matched_tokens > 0 and self.enable_prefix_cache: # Skip the prefix hidden states that correspond to cached tokens - new_hidden = req.hidden_states[matched_tokens:] + # 使用 chunked_rid 判断:最后一个 chunk 时 add_chunked_req 已将 chunked_req 置为 None,但本 batch 中该请求仍应按 chunked 分支用完整 hidden_states + is_chunked_req_in_batch = (self.chunked_req is not None and req.request_id == self.chunked_req.rid) or (chunked_rid is not None and req.request_id == chunked_rid) + if is_chunked_req_in_batch: + keep_len = req.total_length - matched_tokens + new_hidden = req.hidden_states[-keep_len:] + else: + new_hidden = req.hidden_states[matched_tokens:] if new_hidden.shape[0] == 0: # All tokens cached - keep the last hidden state new_hidden = req.hidden_states[-1:] @@ -664,12 +677,18 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A # For prefill, context length after this step will be total_length context_lengths_list.append(req.total_length) + for i, req in enumerate(batched_requests): + logger.debug(f"before pad_inputs, req {req.request_id} h_or_tokens_list length: {len(h_or_tokens_list[i])}") + if self.is_first_peer: padded_inputs, padding_mask = pad_inputs( self.pad_token_id, h_or_tokens_list, self.dtype ) else: padded_inputs, padding_mask = pad_inputs(0, h_or_tokens_list, self.dtype) + + for i, req in enumerate(batched_requests): + logger.debug(f"after pad_inputs, req {req.request_id} h_or_tokens_list length: {len(h_or_tokens_list[i])}") # Generate slot_mapping for prefill (only for NEW tokens, starting from prefix_len) max_len = padded_inputs.shape[1] diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 086fadb0..ae77581e 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -286,6 +286,8 @@ def admit_requests(self): logger.debug( f"Admitted to running: rid={rid}, status={req.status}, running_size={len(self._running_requests)}, ready={req.ready_for_next_step}" ) + if req.hidden_states is not None: + logger.debug(f"Admitted request {rid} to running requests, shape: {req.hidden_states.shape}") # Reflect current running requests metric after admission try: From 719eb96091c0e779235827c7eb9b9393cc421dce Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 2 Feb 2026 23:41:45 -0800 Subject: [PATCH 46/50] feat[chunked-prefill]: the max chunked-prefill --- src/parallax/server/executor/base_executor.py | 2 - src/parallax/server/executor/mlx_executor.py | 110 +++++++++++------- src/parallax/server/scheduler.py | 15 ++- src/parallax/utils/mac_prefill_addr.py | 41 ++++--- 4 files changed, 105 insertions(+), 63 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 392f1577..08a77267 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -304,8 +304,6 @@ def recv_requests_from_peer(self) -> Tuple[List[Request], str]: logger.debug( f"recv request {req.request_id} hidden_states.length: {hidden_size}" ) - shape = req.hidden_states.shape - logger.debug(f"recv request {req.request_id} hidden_states.shape: {shape}") if req.hidden_states.dtype != self.dtype: logger.debug( f"Converting hidden_states dtype from {req.hidden_states.dtype} to {self.dtype} for request {req.request_id}" diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 81d23163..21dd7474 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -20,6 +20,7 @@ ) from parallax.server.sampling.sampler import SamplingBatchInfo from parallax.server.shard_loader import MLXModelLoader +from parallax.utils.mac_prefill_addr import AddReqResult, MACPrefillAdder from parallax.utils.utils import ( combine_padding_and_causal_masks, create_causal_mask, @@ -28,7 +29,6 @@ pad_inputs, ) from parallax_utils.logging_config import get_logger -from parallax.utils.mac_prefill_addr import AddReqResult, MACPrefillAdder logger = get_logger(__name__) @@ -237,9 +237,15 @@ def __init__( self.enable_prefix_cache = enable_prefix_cache if chunked_prefill_size is not None and chunked_prefill_size > 0: # up align to page size - self.chunked_prefill_size = (chunked_prefill_size + self.cache_manager.block_size - 1) // self.cache_manager.block_size * self.cache_manager.block_size + self.chunked_prefill_size = ( + (chunked_prefill_size + self.cache_manager.block_size - 1) + // self.cache_manager.block_size + * self.cache_manager.block_size + ) else: - self.chunked_prefill_size = None + self.chunked_prefill_size = ( + max_sequence_length if max_sequence_length is not None else max_num_tokens_per_batch + ) self.chunked_req = None logger.debug( f"mlx_executor initialized; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}, total memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" @@ -303,7 +309,11 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo continue if not req.abort and req.next_token_id is not None: - if self.chunked_req is not None and req.request_id == self.chunked_req.rid and self.chunked_req.is_chunked > 0: + if ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + ): original_req.status = RequestStatus.PREFILLING else: original_req.commit_new_token(req.next_token_id) @@ -315,8 +325,11 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo if req.abort: original_req.abort = True - - if self.chunked_req is not None and req.request_id == self.chunked_req.rid and self.chunked_req.is_chunked > 0: + if ( + self.chunked_req is not None + and req.request_id == self.chunked_req.rid + and self.chunked_req.is_chunked > 0 + ): self.chunked_req.is_chunked -= 1 self.cache_manager.release_request(original_req.request_id) self.scheduler.enque_request(original_req) @@ -392,16 +405,22 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo else: # This is an active request, add it to the scheduler queue to be processed. self.scheduler.enque_request(req) - - def prepare_next_batch_requests(self, requests: List[Request], batch_output: Any, context_lengths: Any) -> Tuple[List[Request], List[Request]]: + + def prepare_next_batch_requests( + self, requests: List[Request], batch_output: Any, context_lengths: Any + ) -> Tuple[List[Request], List[Request]]: """Prepares a batch of requests for the next stage of the pipeline.""" - base_chunked, base_to_forward = super().prepare_next_batch_requests(requests, batch_output, context_lengths) + base_chunked, base_to_forward = super().prepare_next_batch_requests( + requests, batch_output, context_lengths + ) if ( - self.chunked_req is None - or self.chunked_req.is_chunked <= 0 + self.chunked_req is None + or self.chunked_req.is_chunked <= 0 or self.chunked_req.rid not in [req.request_id for req in requests] ): - logger.debug(f"mlx_executor: prepare_next_batch_requests: return base_chunked{len(base_chunked)} and base_to_forward{len(base_to_forward)} because chunked_req is None or is_chunked <= 0 or rid not in requests") + logger.debug( + f"mlx_executor: prepare_next_batch_requests: return base_chunked{len(base_chunked)} and base_to_forward{len(base_to_forward)} because chunked_req is None or is_chunked <= 0 or rid not in requests" + ) return base_chunked, base_to_forward chunked_rid = self.chunked_req.rid for req in base_to_forward: @@ -410,7 +429,9 @@ def prepare_next_batch_requests(self, requests: List[Request], batch_output: Any base_chunked.append(req) break base_to_forward = [req for req in base_to_forward if req.request_id != chunked_rid] - logger.debug(f"mlx_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid in requests") + logger.debug( + f"mlx_executor: prepare_next_batch_requests: return new_chunked{len(base_chunked)} and new_to_forward{len(base_to_forward)} because chunked_req is not None and is_chunked > 0 and rid in requests" + ) return base_chunked, base_to_forward def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True): @@ -561,27 +582,29 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A batch_size = len(batched_requests) if batch_size == 0: return None - + original_batched_requests = batched_requests logger.debug(f"original_batched_requests_size: {len(original_batched_requests)}") - for req in original_batched_requests: - logger.debug(f"before prepare_prefill_batch, req {req.request_id} hidden_states.shape: {req.hidden_states.shape}") adder = MACPrefillAdder( - self.cache_manager.block_size, - self.chunked_prefill_size, - self.cache_manager - ) - + self.cache_manager.block_size, self.chunked_prefill_size, self.cache_manager + ) + chunked_rid = self.chunked_req.rid if self.chunked_req is not None else None - - if self.chunked_req is not None and chunked_rid in [req.request_id for req in original_batched_requests]: + + if self.chunked_req is not None and chunked_rid in [ + req.request_id for req in original_batched_requests + ]: # update chunked_req use old_req - self.chunked_req = [req for req in original_batched_requests if req.request_id == chunked_rid][0] + self.chunked_req = [ + req for req in original_batched_requests if req.request_id == chunked_rid + ][0] logger.debug(f"before add_chunked_req, chunked_req is not None") self.chunked_req = adder.add_chunked_req(self.chunked_req) if self.chunked_req is None: - logger.debug(f"chunked_req{chunked_rid} has chunk all after add_chunked_req, chunked_req is None") - + logger.debug( + f"chunked_req{chunked_rid} has chunk all after add_chunked_req, chunked_req is None" + ) + for old_req in original_batched_requests: if chunked_rid is not None and old_req.request_id == chunked_rid: continue @@ -589,19 +612,23 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A if res != AddReqResult.CONTINUE: logger.debug(f"macprefilladder has no token to add to prefill batch") break - + if adder.new_chunked_req is not None: self.chunked_req = adder.new_chunked_req logger.debug(f"new chunked_req is {self.chunked_req.rid}") - - if self.chunked_req is not None and self.chunked_req.rid in [req.request_id for req in original_batched_requests]: + + if self.chunked_req is not None and self.chunked_req.rid in [ + req.request_id for req in original_batched_requests + ]: self.chunked_req.is_chunked += 1 - + can_run_by_id = {req.request_id: req for req in adder.can_run_list} - batched_requests = [can_run_by_id[req.request_id] for req in original_batched_requests if req.request_id in can_run_by_id] + batched_requests = [ + can_run_by_id[req.request_id] + for req in original_batched_requests + if req.request_id in can_run_by_id + ] logger.debug(f"after add_one_req, batched_requests size: {len(batched_requests)}") - for req in batched_requests: - logger.debug(f"after add_one_req, can_run_list, req {req.request_id} hidden_states.shape: {req.hidden_states.shape}") h_or_tokens_list = [] block_tables_list = [] @@ -617,7 +644,9 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A token_ids = None if self.enable_prefix_cache and req.input_ids is not None: token_ids = req.input_ids - logger.debug(f"before allocate_request: {req.request_id}, token_ids length: {len(token_ids)}, req.total_length: {req.total_length}") + logger.debug( + f"before allocate_request: {req.request_id}, req.total_length: {req.total_length}" + ) success, matched_tokens = self.cache_manager.allocate_request( req.request_id, req.total_length, token_ids=token_ids ) @@ -650,11 +679,12 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A h_or_tokens_list.append(req.input_ids) actual_processed_lengths_list.append(len(req.input_ids)) else: - logger.debug(f"intermediate peer, req {req.request_id} hidden_states.shape: {req.hidden_states.shape}") if matched_tokens > 0 and self.enable_prefix_cache: # Skip the prefix hidden states that correspond to cached tokens # 使用 chunked_rid 判断:最后一个 chunk 时 add_chunked_req 已将 chunked_req 置为 None,但本 batch 中该请求仍应按 chunked 分支用完整 hidden_states - is_chunked_req_in_batch = (self.chunked_req is not None and req.request_id == self.chunked_req.rid) or (chunked_rid is not None and req.request_id == chunked_rid) + is_chunked_req_in_batch = ( + self.chunked_req is not None and req.request_id == self.chunked_req.rid + ) or (chunked_rid is not None and req.request_id == chunked_rid) if is_chunked_req_in_batch: keep_len = req.total_length - matched_tokens new_hidden = req.hidden_states[-keep_len:] @@ -677,18 +707,12 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A # For prefill, context length after this step will be total_length context_lengths_list.append(req.total_length) - for i, req in enumerate(batched_requests): - logger.debug(f"before pad_inputs, req {req.request_id} h_or_tokens_list length: {len(h_or_tokens_list[i])}") - if self.is_first_peer: padded_inputs, padding_mask = pad_inputs( self.pad_token_id, h_or_tokens_list, self.dtype ) else: padded_inputs, padding_mask = pad_inputs(0, h_or_tokens_list, self.dtype) - - for i, req in enumerate(batched_requests): - logger.debug(f"after pad_inputs, req {req.request_id} h_or_tokens_list length: {len(h_or_tokens_list[i])}") # Generate slot_mapping for prefill (only for NEW tokens, starting from prefix_len) max_len = padded_inputs.shape[1] @@ -745,7 +769,7 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A "h_or_tokens": padded_inputs, "cache": self.cache_manager.get_caches(), "mask": mask, - "requests": original_batched_requests, + "requests": batched_requests, "block_tables": block_tables_tensor, "context_lengths": context_lengths_tensor, "slot_mapping": slot_mapping_tensor, diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index ae77581e..709cae77 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -254,9 +254,19 @@ def admit_requests(self): req = self._wait_queue.popleft() rid = req.request_id if rid in self._running_requests: + logger.debug(f"Request {rid} already in running requests, adding back to wait queue.") self._wait_queue.append(req) continue + # if cache_manager is not None and has allow attribute, break if allow is False + if ( + self.cache_manager is not None + and hasattr(self.cache_manager, "allow") + and not self.cache_manager.allow + ): + logger.debug(f"Cache manager does not allow request {rid},breaking admission.") + break + # # Check kv cache pool # if self.cache_manager is not None: # if not self.cache_manager.has_request(req.request_id): @@ -286,8 +296,6 @@ def admit_requests(self): logger.debug( f"Admitted to running: rid={rid}, status={req.status}, running_size={len(self._running_requests)}, ready={req.ready_for_next_step}" ) - if req.hidden_states is not None: - logger.debug(f"Admitted request {rid} to running requests, shape: {req.hidden_states.shape}") # Reflect current running requests metric after admission try: @@ -330,7 +338,6 @@ def form_batch(self) -> List[Request]: self.admit_requests() if not self._running_requests: return [] - inflight_tokens = 0 batch: List[Request] = [] @@ -343,13 +350,13 @@ def form_batch(self) -> List[Request]: prefill_candidates.append(req) elif req.is_decoding: decode_candidates.append(req) - # 1) Fill with prefills first for req in prefill_candidates: if len(batch) >= self.micro_batch_size: break cost = req.prompt_len if cost + inflight_tokens > self.max_num_tokens_per_batch: + logger.debug(f"prefill request {req.request_id} cost {cost} + inflight_tokens {inflight_tokens} > max_num_tokens_per_batch {self.max_num_tokens_per_batch}, breaking") continue batch.append(req) inflight_tokens += cost diff --git a/src/parallax/utils/mac_prefill_addr.py b/src/parallax/utils/mac_prefill_addr.py index 32ffc5f8..199c63ec 100644 --- a/src/parallax/utils/mac_prefill_addr.py +++ b/src/parallax/utils/mac_prefill_addr.py @@ -1,6 +1,6 @@ from enum import Enum, auto +from typing import Optional -from mpmath import extend from parallax.server.cache_manager import CacheManager from parallax.server.request import Request @@ -10,37 +10,47 @@ class AddReqResult(Enum): NO_TOKEN = auto() # No token left OTHER = auto() # Other reasons to stop adding requests + class MACPrefillAdder: """ MACPrefillAdder is a class that adds prefill requests to the MAC prefill batch. """ + def __init__( self, page_size: int, - rem_chunk_tokens: int, - cache_manager: CacheManager + rem_chunk_tokens: Optional[int], + cache_manager: CacheManager, ): self.page_size = page_size self.rem_chunk_tokens = rem_chunk_tokens self.can_run_list = [] self.new_chunked_req = None self.cache_manager = cache_manager - def add_chunked_req(self, chunked_req: Request) -> Request: + + def add_chunked_req(self, chunked_req: Request) -> Optional[Request]: if chunked_req is None: return None matched_tokens = 0 if self.cache_manager.prefix_cache is not None: - _, matched_tokens = self.cache_manager.prefix_cache.match_prefix(chunked_req.origin_input_ids) + _, matched_tokens = self.cache_manager.prefix_cache.match_prefix( + chunked_req.origin_input_ids + ) extend_input_len = len(chunked_req.origin_input_ids) - matched_tokens extend_input_len = 1 if extend_input_len <= 0 else extend_input_len - truncated = extend_input_len > self.rem_chunk_tokens - chunked_req_offset = min(self.rem_chunk_tokens, extend_input_len) + matched_tokens - chunked_req.input_ids = chunked_req.origin_input_ids[: chunked_req_offset] + if self.rem_chunk_tokens is None: + truncated = False + chunked_req_offset = len(chunked_req.origin_input_ids) + else: + truncated = extend_input_len > self.rem_chunk_tokens + chunked_req_offset = min(self.rem_chunk_tokens, extend_input_len) + matched_tokens + chunked_req.input_ids = chunked_req.origin_input_ids[:chunked_req_offset] chunked_req._effective_total_length = chunked_req_offset self.can_run_list.append(chunked_req) - self.rem_chunk_tokens -= min(self.rem_chunk_tokens, extend_input_len) + if self.rem_chunk_tokens is not None: + self.rem_chunk_tokens -= min(self.rem_chunk_tokens, extend_input_len) return chunked_req if truncated else None - + def add_one_req(self, req: Request) -> AddReqResult: matched_tokens = 0 if self.cache_manager.prefix_cache is not None: @@ -48,10 +58,13 @@ def add_one_req(self, req: Request) -> AddReqResult: extend_input_len = len(req.origin_input_ids) - matched_tokens extend_input_len = 1 if extend_input_len <= 0 else extend_input_len # align to page size - extend_input_len = (extend_input_len + self.page_size - 1) // self.page_size * self.page_size + extend_input_len = ( + (extend_input_len + self.page_size - 1) // self.page_size * self.page_size + ) if self.rem_chunk_tokens is None or extend_input_len <= self.rem_chunk_tokens: self.can_run_list.append(req) - self.rem_chunk_tokens -= extend_input_len + if self.rem_chunk_tokens is not None: + self.rem_chunk_tokens -= extend_input_len else: # make sure at least one page is available trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size @@ -59,11 +72,11 @@ def add_one_req(self, req: Request) -> AddReqResult: return AddReqResult.OTHER extend_input_len = trunc_len chunked_req_offset = extend_input_len + matched_tokens - req.input_ids = req.origin_input_ids[: chunked_req_offset] + req.input_ids = req.origin_input_ids[:chunked_req_offset] req._effective_total_length = chunked_req_offset self.can_run_list.append(req) self.new_chunked_req = req self.rem_chunk_tokens -= extend_input_len if self.rem_chunk_tokens is None or self.rem_chunk_tokens <= 0: return AddReqResult.OTHER - return AddReqResult.CONTINUE \ No newline at end of file + return AddReqResult.CONTINUE From 634269762c0604b082a1b5322316eed69743bba6 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Tue, 3 Feb 2026 00:45:00 -0800 Subject: [PATCH 47/50] fix[prefix-cache]: use correct hidden_states --- src/parallax/server/executor/mlx_executor.py | 16 ++++++++-------- src/parallax/server/executor/sglang_executor.py | 4 ++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 21dd7474..e9c0d184 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -682,14 +682,14 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A if matched_tokens > 0 and self.enable_prefix_cache: # Skip the prefix hidden states that correspond to cached tokens # 使用 chunked_rid 判断:最后一个 chunk 时 add_chunked_req 已将 chunked_req 置为 None,但本 batch 中该请求仍应按 chunked 分支用完整 hidden_states - is_chunked_req_in_batch = ( - self.chunked_req is not None and req.request_id == self.chunked_req.rid - ) or (chunked_rid is not None and req.request_id == chunked_rid) - if is_chunked_req_in_batch: - keep_len = req.total_length - matched_tokens - new_hidden = req.hidden_states[-keep_len:] - else: - new_hidden = req.hidden_states[matched_tokens:] + # is_chunked_req_in_batch = ( + # self.chunked_req is not None and req.request_id == self.chunked_req.rid + # ) or (chunked_rid is not None and req.request_id == chunked_rid) + # if is_chunked_req_in_batch: + keep_len = req.total_length - matched_tokens + new_hidden = req.hidden_states[-keep_len:] + # else: + # new_hidden = req.hidden_states[matched_tokens:] if new_hidden.shape[0] == 0: # All tokens cached - keep the last hidden state new_hidden = req.hidden_states[-1:] diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index d4c9421f..aad72f20 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -113,6 +113,10 @@ def __init__( f"clamping to {kv_block_size}" ) self.chunked_prefill_size = kv_block_size + elif self.chunked_prefill_size is not None: + self.chunked_prefill_size = chunked_prefill_size + else: + self.chunked_prefill_size = max_sequence_length if max_sequence_length is not None else max_num_tokens_per_batch if self.lora_paths is not None and len(self.lora_paths) > 0: self.check_lora_server_args() From eac6229cd8bfbd2d326f2dab9dbbd5da247e2160 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Tue, 3 Feb 2026 00:57:20 -0800 Subject: [PATCH 48/50] fix[chunked-prefill]: delete kig --- src/parallax/server/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 709cae77..f139e329 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -254,7 +254,6 @@ def admit_requests(self): req = self._wait_queue.popleft() rid = req.request_id if rid in self._running_requests: - logger.debug(f"Request {rid} already in running requests, adding back to wait queue.") self._wait_queue.append(req) continue From c2091cf73db01a23a15c733b87c4536764187a2f Mon Sep 17 00:00:00 2001 From: wasamtc Date: Tue, 3 Feb 2026 16:46:46 -0800 Subject: [PATCH 49/50] fix[chunked-prefill]: fix format --- src/parallax/server/executor/sglang_executor.py | 4 +++- src/parallax/server/scheduler.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index f14abc4f..8f18507d 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -116,7 +116,9 @@ def __init__( elif self.chunked_prefill_size is not None: self.chunked_prefill_size = chunked_prefill_size else: - self.chunked_prefill_size = max_sequence_length if max_sequence_length is not None else max_num_tokens_per_batch + self.chunked_prefill_size = ( + max_sequence_length if max_sequence_length is not None else max_num_tokens_per_batch + ) if self.lora_paths is not None and len(self.lora_paths) > 0: self.check_lora_server_args() diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index f139e329..4b885c4c 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -355,7 +355,9 @@ def form_batch(self) -> List[Request]: break cost = req.prompt_len if cost + inflight_tokens > self.max_num_tokens_per_batch: - logger.debug(f"prefill request {req.request_id} cost {cost} + inflight_tokens {inflight_tokens} > max_num_tokens_per_batch {self.max_num_tokens_per_batch}, breaking") + logger.debug( + f"prefill request {req.request_id} cost {cost} + inflight_tokens {inflight_tokens} > max_num_tokens_per_batch {self.max_num_tokens_per_batch}, breaking" + ) continue batch.append(req) inflight_tokens += cost From f68fa32fea2bf5908a6434e522dee59405ac0a50 Mon Sep 17 00:00:00 2001 From: wasamtc Date: Mon, 9 Feb 2026 01:12:54 -0800 Subject: [PATCH 50/50] delete comments and add verify for chunked-prfill param --- src/parallax/server/executor/mlx_executor.py | 13 ------------ src/parallax/server/scheduler.py | 22 -------------------- src/parallax/server/server_args.py | 10 +++++++++ 3 files changed, 10 insertions(+), 35 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index e9c0d184..6fcf2c2c 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -378,11 +378,6 @@ def handle_input_requests(self, requests: List[Request], from_previous_peer: boo req, IntermediateRequest ), "Non-first peers must receive IntermediateRequests." if req.is_finished or req.hidden_states is None: - # if self.enable_prefix_cache: - # keys, values = self.cache_manager.gather_kv_cache(req.request_id) - # self.prefix_cache.cache_finished_request(req, keys, values) - # self.prefix_cache.evict_request(req.request_id) - self.cache_manager.release_request(req.request_id) logger.debug( f"Released resources for finished request {req.request_id}, " @@ -680,16 +675,8 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A actual_processed_lengths_list.append(len(req.input_ids)) else: if matched_tokens > 0 and self.enable_prefix_cache: - # Skip the prefix hidden states that correspond to cached tokens - # 使用 chunked_rid 判断:最后一个 chunk 时 add_chunked_req 已将 chunked_req 置为 None,但本 batch 中该请求仍应按 chunked 分支用完整 hidden_states - # is_chunked_req_in_batch = ( - # self.chunked_req is not None and req.request_id == self.chunked_req.rid - # ) or (chunked_rid is not None and req.request_id == chunked_rid) - # if is_chunked_req_in_batch: keep_len = req.total_length - matched_tokens new_hidden = req.hidden_states[-keep_len:] - # else: - # new_hidden = req.hidden_states[matched_tokens:] if new_hidden.shape[0] == 0: # All tokens cached - keep the last hidden state new_hidden = req.hidden_states[-1:] diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 4b885c4c..1c9afef2 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -266,28 +266,6 @@ def admit_requests(self): logger.debug(f"Cache manager does not allow request {rid},breaking admission.") break - # # Check kv cache pool - # if self.cache_manager is not None: - # if not self.cache_manager.has_request(req.request_id): - # # TODO: Handle chunked prefill, and support preemption. - # # Pass input_ids for prefix cache matching - # token_ids = getattr(req, "input_ids", None) - # success, matched_tokens = self.cache_manager.allocate_request( - # req.request_id, req.total_length, token_ids=token_ids - # ) - # if not success: - # logger.warning( - # f"Request {rid} can't be admit to running batch due to KV cache size." - # ) - # # Put back to wait queue if allocation fails - # self._wait_queue.appendleft(req) - # # Stop admitting since we are out of memory - # break - # if matched_tokens > 0: - # logger.debug( - # f"Request {rid} matched {matched_tokens} tokens from prefix cache" - # ) - # Add request to running requests self._running_requests[rid] = req # Initialize timing for timeout enforcement diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 80b875a0..445bb487 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -355,6 +355,16 @@ def validate_args(args: argparse.Namespace) -> None: if args.kv_block_size <= 0: raise ValueError("kv_block_size must be positive") + # chunked-prefill 依赖 prefix-cache,未开启 prefix-cache 时不能单独使用 chunked-prefill + chunked_prefill_size = getattr(args, "chunked_prefill_size", None) + if chunked_prefill_size is not None and not args.enable_prefix_cache: + raise ValueError( + "chunked-prefill requires prefix-cache to be enabled. " + "Use --enable-prefix-cache when specifying --chunked-prefill-size." + ) + if chunked_prefill_size is not None and chunked_prefill_size <= 0: + raise ValueError("chunked_prefill_size must be positive") + if args.micro_batch_ratio <= 0: raise ValueError("micro_batch_ratio must be positive")