From 4f96e468e3d72f19cb4d6d0c46b340510c5b0017 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 2 Feb 2026 09:24:43 +0800 Subject: [PATCH 01/36] init --- src/parallax/server/executor/base_executor.py | 13 +++++++++++++ src/parallax/server/request.py | 4 ++++ src/parallax/sglang/batch_info.py | 1 + 3 files changed, 18 insertions(+) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 7180242c..ffdf49fa 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -660,6 +660,18 @@ def _handle_raw_request(self, raw_request: Dict): lora_path = raw_request.get("lora_path") return_probs = raw_request.get("return_probs", False) # Get return_probs parameter + # Extract multimodal params if present + multimodal_params = None + if "messages" in raw_request: + for message in raw_request["messages"]: + content = message.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image_url": + if multimodal_params is None: + multimodal_params = {"images": []} + multimodal_params["images"].append(part["image_url"]) + raw_sampling_params = raw_request.get("sampling_params") if raw_sampling_params is None: sampling_params = SamplingParams() @@ -684,6 +696,7 @@ def _handle_raw_request(self, raw_request: Dict): max_total_length=max_total_length, lora_path=lora_path, return_probs=return_probs, + multimodal_params=multimodal_params, ) if "routing_table" in raw_request: req.routing_table = raw_request["routing_table"] diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index f7c9bb90..50c1ed47 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -96,6 +96,7 @@ def __init__( routing_table: Optional[List[str]] = [], sampling_params: Optional[SamplingParams] = None, lora_path: Optional[str] = None, + multimodal_params: Optional[Dict] = None, ): self.request_id 
= request_id or str(uuid.uuid4()) self.status = status @@ -109,6 +110,7 @@ def __init__( self.last_updated_time: Optional[float] = None self.lora_id: Optional[str] = None self.lora_path = lora_path + self.multimodal_params = multimodal_params @property def is_finished(self) -> bool: @@ -161,6 +163,7 @@ def __init__( status: RequestStatus = RequestStatus.PREFILLING, lora_path: Optional[str] = None, return_probs: bool = False, + multimodal_params: Optional[Dict] = None, ): if not prompt and not input_ids: raise ValueError("prompt or input_ids cannot be empty.") @@ -171,6 +174,7 @@ def __init__( input_ids=input_ids, sampling_params=sampling_params, lora_path=lora_path, + multimodal_params=multimodal_params, ) self.prompt = prompt self.return_probs = return_probs diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 56720040..ec46589f 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -60,6 +60,7 @@ def transform_requests_to_sglang( origin_input_ids=old_req.input_ids, sampling_params=sampling_params, lora_id=old_req.lora_id, + multimodal_inputs=old_req.multimodal_params, ) # Debug: Log before cache lookup From 7d6f7188a7eb435b16445ea4df925fbfc82b708a Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 2 Feb 2026 10:46:47 +0800 Subject: [PATCH 02/36] update --- .../server/executor/sglang_executor.py | 3 +- src/parallax/sglang/batch_info.py | 74 ++++++++++++++++++- src/parallax/sglang/model_runner.py | 12 ++- src/parallax/vllm/batch_info.py | 22 +++++- 4 files changed, 103 insertions(+), 8 deletions(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index e4a7996f..8485fdc2 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -140,7 +140,7 @@ def __init__( logger.debug( f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" ) - 
self.model_runner, self.config, self.tokenizer = initialize_sgl_model_runner( + self.model_runner, self.config, self.tokenizer, self.processor = initialize_sgl_model_runner( **model_runner_params ) logger.debug( @@ -601,6 +601,7 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A batched_requests, self.model_runner, self.page_tree_cache, + self.processor, ) self.cur_batch = schedule_batch diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index ec46589f..b31d26e0 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -48,9 +48,13 @@ def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> S def transform_requests_to_sglang( - old_requests: List[Request], page_tree_cache: Optional[PageRadixCache] = None + old_requests: List[Request], + page_tree_cache: Optional[PageRadixCache] = None, + processor: Optional[Any] = None, ) -> List[Req]: """Transforms Parallax Request to SGLang.Req format""" + from sglang.srt.managers.schedule_batch import MultimodalInputs + reqs = [] for old_req in old_requests: sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) @@ -60,8 +64,71 @@ def transform_requests_to_sglang( origin_input_ids=old_req.input_ids, sampling_params=sampling_params, lora_id=old_req.lora_id, - multimodal_inputs=old_req.multimodal_params, ) + if old_req.multimodal_params is not None: + # Construct MultimodalInputs from dict + try: + if "mm_items" in old_req.multimodal_params: + # Case 1: Already structured data + req.multimodal_inputs = MultimodalInputs.from_dict(old_req.multimodal_params) + elif "images" in old_req.multimodal_params and processor is not None: + # Case 2: List of image URLs, need processing + # Import necessary modules here to avoid top-level dependency + from PIL import Image + import requests + from io import BytesIO + from sglang.srt.managers.schedule_batch import MultimodalDataItem, Modality + + 
image_urls = old_req.multimodal_params["images"] + mm_items = [] + + for url in image_urls: + try: + # Basic image downloading logic + if url.startswith("http"): + response = requests.get(url, timeout=10) + response.raise_for_status() + image_data = BytesIO(response.content) + image = Image.open(image_data).convert("RGB") + elif url.startswith("data:image"): + # Handle base64 encoded images if needed, or skip + continue + else: + # Assume local path + image = Image.open(url).convert("RGB") + + # Process image using the provided processor + # Note: Different processors have different call signatures. + # Standard HF processor usage: + inputs = processor(images=image, return_tensors="pt") + pixel_values = inputs.pixel_values + + # Construct MultimodalDataItem + # NOTE: SGLang expects features to be stored in specific fields depending on modality + item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=pixel_values, + ) + mm_items.append(item) + except Exception as img_err: + logger.error(f"Failed to process image {url}: {img_err}") + + if mm_items: + req.multimodal_inputs = MultimodalInputs(mm_items=mm_items) + logger.debug(f"Successfully processed {len(mm_items)} images for request {req.rid}") + + else: + # Fallback + logger.warning( + f"Assigning raw multimodal_params to req.multimodal_inputs. " + f"SGLang might expect MultimodalInputs object with Tensors. 
" + f"Params keys: {old_req.multimodal_params.keys()}, Processor: {processor is not None}" + ) + req.multimodal_inputs = old_req.multimodal_params + + except Exception as e: + logger.warning(f"Failed to construct MultimodalInputs: {e}") + req.multimodal_inputs = old_req.multimodal_params # Debug: Log before cache lookup if page_tree_cache is not None: @@ -93,10 +160,11 @@ def form_sgl_batch_prefill( requests: List[Request], model_runner: ModelRunner, page_tree_cache: Optional[PageRadixCache] = None, + processor: Optional[Any] = None, ) -> ForwardBatch: """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" - sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache) + sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache, processor) def dummy_evict(*args): pass diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 0c2667a9..ee06ca7e 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -278,6 +278,7 @@ def initialize_sgl_model_runner( - model_runner: SGL model runner - config: model config driven by mlx-lm - tokenizer: tokenizer driven by mlx-lm + - processor: optional processor for multimodal models """ apply_parallax_sglang_monkey_patch() @@ -302,6 +303,15 @@ def initialize_sgl_model_runner( tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") + # Load processor if available (for multimodal models) + processor = None + try: + from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + logger.info(f"Loaded processor for model {model_path}") + except Exception as e: + logger.debug(f"No processor loaded (normal for text-only models): {e}") + if nccl_port is None: nccl_port = random.randint(4000, 5000) @@ -374,7 +384,7 @@ def initialize_sgl_model_runner( dp_rank=dp_rank, dp_size=dp_size, ) - return 
model_runner, config, tokenizer + return model_runner, config, tokenizer, processor def refit_sgl_model( diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py index d7ea47b3..7b23f87c 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -141,6 +141,7 @@ def _build_vllm_request( model_runner: Any, *, include_outputs: bool, + multimodal_params: Optional[Dict] = None, ) -> VLLMRequest: block_hasher = getattr(model_runner, "request_block_hasher", None) @@ -171,6 +172,7 @@ def _build_vllm_request( arrival_time=getattr(req, "arrival_time", 0.0), block_hasher=block_hasher, lora_request=lora_req, + multi_modal_data=multimodal_params, ) if include_outputs: output_ids = getattr(req, "output_ids", None) or [] @@ -204,8 +206,15 @@ def form_vllm_batch_prefill( for req in batched_requests: sampling_params = transform_sampling_params_to_vllm(req.sampling_params) - - vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=False) + multimodal_params = getattr(req, "multimodal_params", None) + + vllm_req = _build_vllm_request( + req, + sampling_params, + model_runner, + include_outputs=False, + multimodal_params=multimodal_params, + ) created_vllm_requests.append(vllm_req) computed_blocks, num_computed_tokens = kv_cache_manager.get_computed_blocks(vllm_req) @@ -328,7 +337,14 @@ def form_vllm_batch_decode( new_token_ids.append([]) sampling_params = transform_sampling_params_to_vllm(req.sampling_params) - vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=True) + multimodal_params = getattr(req, "multimodal_params", None) + vllm_req = _build_vllm_request( + req, + sampling_params, + model_runner, + include_outputs=True, + multimodal_params=multimodal_params, + ) prompt_ids = getattr(req, "input_ids", None) or [] # For decode stage, computed_token_count should be the total number of tokens From 4358906e6d1335d69aa4dfa9e0e5861548cd9eb1 Mon Sep 17 00:00:00 2001 From: 
yuhao-zh Date: Mon, 2 Feb 2026 10:48:09 +0800 Subject: [PATCH 03/36] update --- src/parallax/sglang/batch_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index b31d26e0..8d5e49c5 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -6,7 +6,7 @@ """ from types import SimpleNamespace -from typing import List, Optional +from typing import List, Optional, Any import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch From 104a8a63134dbae071dfc7b14013e6933fb15620 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 2 Feb 2026 11:30:08 +0000 Subject: [PATCH 04/36] tmp add --- src/parallax/launch.py | 14 +++++++++++--- src/parallax/server/request.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index b20f6250..75054150 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -120,23 +120,31 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr check_latest_release() config = fetch_model_from_hf(args.model_path, local_files_only=args.use_hfcache) + num_layers = config.get("num_hidden_layers") or config.get("n_layer") or config.get("num_layers") + + # If not found in top level, check text_config (common in multimodal models) + if num_layers is None and "text_config" in config: + text_config = config["text_config"] + if isinstance(text_config, dict): + num_layers = text_config.get("num_hidden_layers") or text_config.get("n_layer") or text_config.get("num_layers") + if args.start_layer is None: args.start_layer = 0 if args.end_layer is None: - args.end_layer = config.get("num_hidden_layers") + args.end_layer = num_layers # only launch http server on head node if args.start_layer == 0: http_server_process = launch_http_server(args) # Launch P2P server as subprocess - if not (args.start_layer == 0 and args.end_layer == 
config.get("num_hidden_layers")): + if not (args.start_layer == 0 and args.end_layer == num_layers): p2p_server_process = launch_p2p_server_process( initial_peers=args.initial_peers, scheduler_addr=args.scheduler_addr, relay_servers=args.relay_servers, pp_start_layer=args.start_layer, pp_end_layer=args.end_layer, - hidden_layers=config.get("num_hidden_layers"), + hidden_layers=num_layers, tp_size=args.tp_size, dp_size=args.dp_size, tcp_port=args.tcp_port, diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index 50c1ed47..5354dae6 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -60,7 +60,7 @@ import uuid from enum import Enum -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional from parallax.server.sampling.sampling_params import SamplingParams from parallax_utils.logging_config import get_logger From 872384145a85d33bc4f810c1931a10f1676abdee Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 2 Feb 2026 11:30:48 +0000 Subject: [PATCH 05/36] 11 --- src/parallax/sglang/batch_info.py | 93 ++++++++++++++++++------------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 8d5e49c5..aeb791b6 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -7,9 +7,18 @@ from types import SimpleNamespace from typing import List, Optional, Any +import requests +from io import BytesIO +from PIL import Image import torch -from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.schedule_batch import ( + Req, + ScheduleBatch, + MultimodalInputs, + MultimodalDataItem, + Modality, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import ( @@ -28,6 +37,48 @@ logger = get_logger(__name__) +def 
_process_images(image_urls: List[str], processor: Any) -> List[MultimodalDataItem]: + mm_items = [] + for url in image_urls: + try: + if url.startswith("http"): + response = requests.get(url, timeout=10) + response.raise_for_status() + image_data = BytesIO(response.content) + image = Image.open(image_data).convert("RGB") + elif url.startswith("data:image"): + # TODO: Handle base64 + continue + else: + image = Image.open(url).convert("RGB") + + # Process image + inputs = processor(images=image, return_tensors="pt") + + # Extract features (adapt based on processor output) + pixel_values = inputs.get("pixel_values") + if pixel_values is None: + logger.error(f"Processor output missing pixel_values for {url}") + continue + + # Extract extra fields if available + model_specific_data = {} + if "image_grid_thw" in inputs: + model_specific_data["image_grid_thw"] = inputs["image_grid_thw"] + if "image_sizes" in inputs: + model_specific_data["image_sizes"] = inputs["image_sizes"] + + item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=pixel_values, + model_specific_data=model_specific_data, + ) + mm_items.append(item) + except Exception as e: + logger.error(f"Failed to process image {url}: {e}") + return mm_items + + def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> SGLSamplingParams: """Transforms Parallax SamplingParams to SGLang.SamplingParams format""" params = SGLSamplingParams( @@ -53,7 +104,6 @@ def transform_requests_to_sglang( processor: Optional[Any] = None, ) -> List[Req]: """Transforms Parallax Request to SGLang.Req format""" - from sglang.srt.managers.schedule_batch import MultimodalInputs reqs = [] for old_req in old_requests: @@ -73,45 +123,8 @@ def transform_requests_to_sglang( req.multimodal_inputs = MultimodalInputs.from_dict(old_req.multimodal_params) elif "images" in old_req.multimodal_params and processor is not None: # Case 2: List of image URLs, need processing - # Import necessary modules here to avoid top-level 
dependency - from PIL import Image - import requests - from io import BytesIO - from sglang.srt.managers.schedule_batch import MultimodalDataItem, Modality - image_urls = old_req.multimodal_params["images"] - mm_items = [] - - for url in image_urls: - try: - # Basic image downloading logic - if url.startswith("http"): - response = requests.get(url, timeout=10) - response.raise_for_status() - image_data = BytesIO(response.content) - image = Image.open(image_data).convert("RGB") - elif url.startswith("data:image"): - # Handle base64 encoded images if needed, or skip - continue - else: - # Assume local path - image = Image.open(url).convert("RGB") - - # Process image using the provided processor - # Note: Different processors have different call signatures. - # Standard HF processor usage: - inputs = processor(images=image, return_tensors="pt") - pixel_values = inputs.pixel_values - - # Construct MultimodalDataItem - # NOTE: SGLang expects features to be stored in specific fields depending on modality - item = MultimodalDataItem( - modality=Modality.IMAGE, - feature=pixel_values, - ) - mm_items.append(item) - except Exception as img_err: - logger.error(f"Failed to process image {url}: {img_err}") + mm_items = _process_images(image_urls, processor) if mm_items: req.multimodal_inputs = MultimodalInputs(mm_items=mm_items) From 605577adbba08adc4cc2214d2c577a231a8ba6ca Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Tue, 3 Feb 2026 02:37:44 +0000 Subject: [PATCH 06/36] done sglang --- src/parallax/launch.py | 4 +- src/parallax/server/executor/base_executor.py | 48 +++- .../server/executor/sglang_executor.py | 6 + src/parallax/sglang/batch_info.py | 269 +++++++++++++++--- src/parallax/sglang/multimodal_utils.py | 146 ++++++++++ 5 files changed, 424 insertions(+), 49 deletions(-) create mode 100644 src/parallax/sglang/multimodal_utils.py diff --git a/src/parallax/launch.py b/src/parallax/launch.py index 75054150..a62892f3 100644 --- a/src/parallax/launch.py +++ 
b/src/parallax/launch.py @@ -120,13 +120,13 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr check_latest_release() config = fetch_model_from_hf(args.model_path, local_files_only=args.use_hfcache) - num_layers = config.get("num_hidden_layers") or config.get("n_layer") or config.get("num_layers") + num_layers = config.get("num_hidden_layers") # If not found in top level, check text_config (common in multimodal models) if num_layers is None and "text_config" in config: text_config = config["text_config"] if isinstance(text_config, dict): - num_layers = text_config.get("num_hidden_layers") or text_config.get("n_layer") or text_config.get("num_layers") + num_layers = text_config.get("num_hidden_layers") if args.start_layer is None: args.start_layer = 0 diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index ffdf49fa..d0511089 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -51,7 +51,6 @@ class BaseExecutor: """High-level executor for managing model shards, scheduler, and cache pool on each Peer.""" - def __init__( self, # Model Configs @@ -112,8 +111,22 @@ def __init__( # Pipe communication self.conn = conn + def _config_get(key, default=None): + if isinstance(self.config, dict): + return self.config.get(key, default) + return getattr(self.config, key, default) + + num_hidden_layers = _config_get("num_hidden_layers") + text_config = _config_get("text_config") + if num_hidden_layers is None and isinstance(text_config, dict): + num_hidden_layers = text_config.get("num_hidden_layers") + if num_hidden_layers is None: + num_hidden_layers = _config_get("n_layer") or _config_get("num_layers") + if isinstance(text_config, dict): + num_hidden_layers = text_config.get("num_hidden_layers") + self.is_first_peer = start_layer == 0 - self.is_last_peer = end_layer == self.config.get("num_hidden_layers") + self.is_last_peer = 
end_layer == num_hidden_layers self.tp_size = tp_size self.tp_rank = tp_rank self.dp_size = dp_size @@ -143,7 +156,25 @@ def __init__( else: self.pad_token_id = self.tokenizer.pad_token_id - self.eos_token_id = self.config.get("eos_token_id", None) + self.eos_token_id = _config_get("eos_token_id") + if self.eos_token_id is None and isinstance(text_config, dict): + self.eos_token_id = text_config.get("eos_token_id") + + vision_config = _config_get("vision_config", {}) + if not isinstance(vision_config, dict): + vision_config = { + "spatial_merge_size": getattr(vision_config, "spatial_merge_size", None), + "tokens_per_second": getattr(vision_config, "tokens_per_second", None), + } + self.mm_config = { + "model_type": _config_get("model_type"), + "image_token_id": _config_get("image_token_id"), + "vision_start_token_id": _config_get("vision_start_token_id"), + "vision_end_token_id": _config_get("vision_end_token_id"), + "video_token_id": _config_get("video_token_id"), + "audio_token_id": _config_get("audio_token_id"), + "vision_config": vision_config, + } # Scheduler: derive final max_batch_size with KV constraints # Remove this for now as it's not working on gpu devices @@ -617,7 +648,16 @@ def _handle_raw_request(self, raw_request: Dict): rid = raw_request["rid"] if self.tokenizer.chat_template: messages = raw_request["messages"] - process_message_content(messages) + has_non_text_content = any( + isinstance(msg.get("content"), list) + and any( + isinstance(part, dict) and part.get("type") != "text" + for part in msg.get("content") + ) + for msg in messages + ) + if not has_non_text_content: + process_message_content(messages) chat_template_kwargs = raw_request.get("chat_template_kwargs", {}) # check extra_body for backward compatibility if "extra_body" in raw_request and "chat_template_kwargs" in raw_request["extra_body"]: diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 8485fdc2..6c94d372 100755 
--- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -6,6 +6,8 @@ from typing import Any, Dict, List, Optional, Tuple import torch +from sglang.srt.environ import envs +from sglang.srt.managers.mm_utils import init_mm_embedding_cache from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.cache_init_params import CacheInitParams @@ -146,6 +148,8 @@ def __init__( logger.debug( f"SGLang model runner initialized. num_layers={self.config.get('num_hidden_layers')}" ) + embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get() + init_mm_embedding_cache(embedding_cache_size * 1024 * 1024) # Set device to specific CUDA device based on tp_rank # This ensures tensors are moved to the correct GPU @@ -602,6 +606,8 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A self.model_runner, self.page_tree_cache, self.processor, + self.mm_config, + self.tokenizer, ) self.cur_batch = schedule_batch diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index aeb791b6..06000507 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -19,6 +19,7 @@ MultimodalDataItem, Modality, ) +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import ( @@ -32,51 +33,128 @@ from parallax.server.sampling.sampling_params import ( SamplingParams as ParallaxSamplingParams, ) +from parallax.sglang.multimodal_utils import prepare_sglang_multimodal_inputs from parallax_utils.logging_config import get_logger logger = get_logger(__name__) -def _process_images(image_urls: List[str], processor: Any) -> List[MultimodalDataItem]: +def _load_image(url: Any) -> Image.Image: 
+ """Load a single image from URL, file path, or base64.""" + import base64 + + if isinstance(url, dict): + url = url.get("url") + if not isinstance(url, str): + raise ValueError(f"Unsupported image url type: {type(url)}") + + if url.startswith("http"): + response = requests.get(url, timeout=10) + response.raise_for_status() + image_data = BytesIO(response.content) + return Image.open(image_data).convert("RGB") + elif url.startswith("data:image"): + # Handle base64 encoded images + # Format: data:image/png;base64, + header, encoded = url.split(",", 1) + image_data = BytesIO(base64.b64decode(encoded)) + return Image.open(image_data).convert("RGB") + else: + return Image.open(url).convert("RGB") + + +def _process_images( + image_urls: List[Any], + processor: Any, + input_text: str = "", + mm_config: Optional[dict] = None, +) -> tuple[List[MultimodalDataItem], Optional[torch.Tensor], List[int]]: + if not image_urls: + return [], None, [] + mm_items = [] + images = [] + for url in image_urls: try: - if url.startswith("http"): - response = requests.get(url, timeout=10) - response.raise_for_status() - image_data = BytesIO(response.content) - image = Image.open(image_data).convert("RGB") - elif url.startswith("data:image"): - # TODO: Handle base64 - continue - else: - image = Image.open(url).convert("RGB") - - # Process image - inputs = processor(images=image, return_tensors="pt") + image = _load_image(url) + images.append(image) + except Exception as e: + logger.exception(f"Failed to load image {url}: {e}") + continue + + if not images: + return [], None, [] + + try: + inputs = processor(text=input_text, images=images, return_tensors="pt") + + if inputs is None: + logger.error("Processor returned None") + return [], None, [] + + pixel_values = inputs.get("pixel_values") + if pixel_values is None: + logger.error("Processor output missing pixel_values") + return [], None, [] + + expanded_input_ids = inputs.get("input_ids") + if expanded_input_ids is not None: + 
expanded_input_ids = expanded_input_ids.flatten().tolist() + else: + expanded_input_ids = [] + + logger.debug(f"Processor expanded input_ids length: {len(expanded_input_ids)}") + + image_grid_thw = inputs.get("image_grid_thw") + image_sizes = inputs.get("image_sizes") + + model_specific_data = {} + if image_grid_thw is not None: + model_specific_data["image_grid_thw"] = image_grid_thw + if image_sizes is not None: + model_specific_data["image_sizes"] = image_sizes + + if image_grid_thw is not None and len(image_grid_thw) == len(images): + num_images = len(images) - # Extract features (adapt based on processor output) - pixel_values = inputs.get("pixel_values") - if pixel_values is None: - logger.error(f"Processor output missing pixel_values for {url}") - continue - - # Extract extra fields if available - model_specific_data = {} - if "image_grid_thw" in inputs: - model_specific_data["image_grid_thw"] = inputs["image_grid_thw"] - if "image_sizes" in inputs: - model_specific_data["image_sizes"] = inputs["image_sizes"] - + patches_per_image = [] + for i in range(num_images): + grid = image_grid_thw[i] + if isinstance(grid, torch.Tensor): + num_patches = int(torch.prod(grid).item()) + else: + num_patches = int(torch.prod(torch.tensor(grid)).item()) + patches_per_image.append(num_patches) + + patch_start = 0 + for i in range(num_images): + num_patches = patches_per_image[i] + item_pixel_values = pixel_values[patch_start:patch_start + num_patches] + item_grid_thw = image_grid_thw[i:i+1] + + item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=item_pixel_values, + model_specific_data={ + "image_grid_thw": item_grid_thw, + }, + ) + mm_items.append(item) + patch_start += num_patches + else: item = MultimodalDataItem( modality=Modality.IMAGE, feature=pixel_values, model_specific_data=model_specific_data, ) mm_items.append(item) - except Exception as e: - logger.error(f"Failed to process image {url}: {e}") - return mm_items + + return mm_items, image_grid_thw, 
expanded_input_ids + + except Exception as e: + logger.exception(f"Failed to process images: {e}") + return [], None, [] def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> SGLSamplingParams: @@ -98,40 +176,142 @@ def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> S return params +def _get_image_token_offsets( + input_ids: List[int], + image_token_id: Optional[int], + vision_start_id: Optional[int] = None, + vision_end_id: Optional[int] = None, +) -> List[tuple[int, int]]: + offsets = [] + + if vision_start_id is not None and vision_end_id is not None: + start_indices = [i for i, tok in enumerate(input_ids) if tok == vision_start_id] + end_indices = [i for i, tok in enumerate(input_ids) if tok == vision_end_id] + + for start, end in zip(start_indices, end_indices): + if start < end: + offsets.append((start + 1, end - 1)) + elif image_token_id is not None: + start = None + for i, tok in enumerate(input_ids): + if tok == image_token_id: + if start is None: + start = i + elif start is not None: + offsets.append((start, i - 1)) + start = None + if start is not None: + offsets.append((start, len(input_ids) - 1)) + + return offsets + + def transform_requests_to_sglang( old_requests: List[Request], page_tree_cache: Optional[PageRadixCache] = None, processor: Optional[Any] = None, + hf_config: Optional[dict] = None, + tokenizer: Optional[Any] = None, ) -> List[Req]: - """Transforms Parallax Request to SGLang.Req format""" - reqs = [] + mm_config = hf_config or {} + for old_req in old_requests: sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) req = Req( rid=old_req.request_id, origin_input_text="", - origin_input_ids=old_req.input_ids, + origin_input_ids=list(old_req.input_ids), sampling_params=sampling_params, lora_id=old_req.lora_id, ) + if old_req.multimodal_params is not None: - # Construct MultimodalInputs from dict try: if "mm_items" in old_req.multimodal_params: - # Case 1: Already 
structured data req.multimodal_inputs = MultimodalInputs.from_dict(old_req.multimodal_params) + elif "images" in old_req.multimodal_params and processor is not None: - # Case 2: List of image URLs, need processing image_urls = old_req.multimodal_params["images"] - mm_items = _process_images(image_urls, processor) + + input_text = "" + if tokenizer is not None: + try: + input_text = tokenizer.decode(old_req.input_ids, skip_special_tokens=False) + logger.debug(f"Decoded input text (length={len(input_text)}): {input_text[:100]}...") + except Exception as e: + logger.warning(f"Failed to decode input_ids: {e}") + + mm_items, image_grid_thw, expanded_input_ids = _process_images( + image_urls, + processor, + input_text=input_text, + mm_config=mm_config, + ) if mm_items: - req.multimodal_inputs = MultimodalInputs(mm_items=mm_items) - logger.debug(f"Successfully processed {len(mm_items)} images for request {req.rid}") + if expanded_input_ids and len(expanded_input_ids) > len(old_req.input_ids): + logger.debug( + f"Using expanded input_ids: {len(old_req.input_ids)} -> {len(expanded_input_ids)}" + ) + req.origin_input_ids = expanded_input_ids + input_ids_for_offsets = expanded_input_ids + else: + input_ids_for_offsets = list(old_req.input_ids) + + image_token_id = mm_config.get("image_token_id") + vision_start_id = mm_config.get("vision_start_token_id") + vision_end_id = mm_config.get("vision_end_token_id") + + offsets = _get_image_token_offsets( + input_ids_for_offsets, + image_token_id, + vision_start_id, + vision_end_id, + ) + + if len(offsets) == len(mm_items): + for item, offset in zip(mm_items, offsets): + item.offsets = [offset] + elif len(offsets) > 0: + for item in mm_items: + item.offsets = offsets + + for item in mm_items: + item.set_pad_value() + + combined_grid_thw = None + if image_grid_thw is not None: + combined_grid_thw = image_grid_thw + else: + grids = [] + for item in mm_items: + grid = item.model_specific_data.get("image_grid_thw") + if grid is not None: 
+ grids.append(grid) + if grids: + combined_grid_thw = torch.cat(grids, dim=0) + + req.multimodal_inputs = prepare_sglang_multimodal_inputs( + mm_items=mm_items, + image_grid_thw=combined_grid_thw, + mm_config=mm_config, + input_ids=input_ids_for_offsets, + ) + + padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens() + padded_input_ids = padding_pattern.pad_input_tokens( + input_ids_for_offsets, + req.multimodal_inputs + ) + req.origin_input_ids = padded_input_ids + + logger.debug( + f"Successfully processed {len(mm_items)} images for request {req.rid}, " + f"offsets={offsets}, padded_input_ids_len={len(padded_input_ids)}" + ) else: - # Fallback logger.warning( f"Assigning raw multimodal_params to req.multimodal_inputs. " f"SGLang might expect MultimodalInputs object with Tensors. " @@ -140,7 +320,7 @@ def transform_requests_to_sglang( req.multimodal_inputs = old_req.multimodal_params except Exception as e: - logger.warning(f"Failed to construct MultimodalInputs: {e}") + logger.exception(f"Failed to construct MultimodalInputs: {e}") req.multimodal_inputs = old_req.multimodal_params # Debug: Log before cache lookup @@ -174,10 +354,13 @@ def form_sgl_batch_prefill( model_runner: ModelRunner, page_tree_cache: Optional[PageRadixCache] = None, processor: Optional[Any] = None, + hf_config: Optional[dict] = None, + tokenizer: Optional[Any] = None, ) -> ForwardBatch: - """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" - sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache, processor) + sgl_reqs = transform_requests_to_sglang( + requests, page_tree_cache, processor, hf_config, tokenizer + ) def dummy_evict(*args): pass diff --git a/src/parallax/sglang/multimodal_utils.py b/src/parallax/sglang/multimodal_utils.py new file mode 100644 index 00000000..508357ce --- /dev/null +++ b/src/parallax/sglang/multimodal_utils.py @@ -0,0 +1,146 @@ +""" +Multimodal utilities for SGLang backend. 
+ +This module provides utilities for processing multimodal inputs in the Parallax +framework when using SGLang as the GPU backend. +""" + +from typing import Any, List, Optional, Tuple + +import torch +from sglang.srt.managers.schedule_batch import MultimodalInputs, MultimodalDataItem + + +def find_token_offsets(input_ids: List[int], token_id: int) -> List[Tuple[int, int]]: + offsets = [] + start = None + for i, tok in enumerate(input_ids): + if tok == token_id: + if start is None: + start = i + elif start is not None: + offsets.append((start, i - 1)) + start = None + if start is not None: + offsets.append((start, len(input_ids) - 1)) + return offsets + + +def find_token_offsets_by_pair( + input_ids: List[int], + start_token_id: int, + end_token_id: int, +) -> List[Tuple[int, int]]: + start_indices = [i for i, tok in enumerate(input_ids) if tok == start_token_id] + end_indices = [i for i, tok in enumerate(input_ids) if tok == end_token_id] + + offsets = [] + for start, end in zip(start_indices, end_indices): + if start < end: + # Content is between start+1 and end-1 + offsets.append((start + 1, end - 1)) + + return offsets + + +def compute_mrope_positions( + input_ids: List[int], + image_grid_thw: Optional[torch.Tensor], + mm_config: dict, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + + import logging + logger = logging.getLogger(__name__) + + seq_len = len(input_ids) + + def get_default_positions(): + positions = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(3, 1) + delta = torch.zeros(1, dtype=torch.long) + return positions, delta + + if image_grid_thw is None: + return get_default_positions() + + image_token_id = mm_config.get("image_token_id") + vision_start_id = mm_config.get("vision_start_token_id") + video_token_id = mm_config.get("video_token_id") + model_type = mm_config.get("model_type") + vision_config = mm_config.get("vision_config", {}) + + spatial_merge_size = ( + vision_config.get("spatial_merge_size") if 
isinstance(vision_config, dict) else None + ) + tokens_per_second = ( + vision_config.get("tokens_per_second") if isinstance(vision_config, dict) else None + ) + + if image_token_id is None or spatial_merge_size is None: + logger.debug( + f"Missing mrope config: image_token_id={image_token_id}, " + f"spatial_merge_size={spatial_merge_size}. Using default positions." + ) + return get_default_positions() + + try: + from sglang.srt.layers.rotary_embedding import MRotaryEmbedding + + input_ids_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=spatial_merge_size, + image_token_id=image_token_id, + video_token_id=video_token_id, + vision_start_token_id=vision_start_id, + model_type=model_type, + input_ids=input_ids_tensor, + image_grid_thw=image_grid_thw, + tokens_per_second=tokens_per_second, + ) + + mrope_positions = mrope_positions.squeeze(1) + return mrope_positions, mrope_position_delta + + except Exception as e: + logger.warning(f"Failed to compute mrope_positions: {e}. 
Using default positions.") + return get_default_positions() + + +def prepare_sglang_multimodal_inputs( + mm_items: List[MultimodalDataItem], + image_grid_thw: Optional[torch.Tensor] = None, + mm_config: Optional[dict] = None, + input_ids: Optional[List[int]] = None, +) -> MultimodalInputs: + mm_config = mm_config or {} + + # Extract token IDs from config + image_token_id = mm_config.get("image_token_id") + vision_start_id = mm_config.get("vision_start_token_id") + vision_end_id = mm_config.get("vision_end_token_id") + video_token_id = mm_config.get("video_token_id") + audio_token_id = mm_config.get("audio_token_id") + + + for item in mm_items: + if item.pad_value is None: + item.set_pad_value() + + mrope_positions = None + mrope_position_delta = None + + if input_ids is not None: + mrope_positions, mrope_position_delta = compute_mrope_positions( + input_ids, image_grid_thw, mm_config + ) + + return MultimodalInputs( + mm_items=mm_items, + im_token_id=image_token_id, + im_start_id=vision_start_id, + im_end_id=vision_end_id, + video_token_id=video_token_id, + audio_token_id=audio_token_id, + mrope_positions=mrope_positions, + mrope_position_delta=mrope_position_delta, + ) From adec10108621818d8a33f8ec458218cc21e9a759 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Tue, 3 Feb 2026 03:17:53 +0000 Subject: [PATCH 07/36] refactor sglang vlm --- src/parallax/sglang/batch_info.py | 246 +--------------------- src/parallax/sglang/multimodal_utils.py | 258 +++++++++++++++++++++--- 2 files changed, 239 insertions(+), 265 deletions(-) diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 06000507..1369180d 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -1,25 +1,12 @@ """ Store information about a SGLang batch. 
-The following is the flow of data structures for a batch in SGLang: - -ScheduleBatch -> ModelWorkerBatch -> ForwardBatch """ from types import SimpleNamespace from typing import List, Optional, Any -import requests -from io import BytesIO -from PIL import Image import torch -from sglang.srt.managers.schedule_batch import ( - Req, - ScheduleBatch, - MultimodalInputs, - MultimodalDataItem, - Modality, -) -from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch, MultimodalInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import ( @@ -33,130 +20,12 @@ from parallax.server.sampling.sampling_params import ( SamplingParams as ParallaxSamplingParams, ) -from parallax.sglang.multimodal_utils import prepare_sglang_multimodal_inputs +from parallax.sglang.multimodal_utils import process_multimodal_request from parallax_utils.logging_config import get_logger logger = get_logger(__name__) -def _load_image(url: Any) -> Image.Image: - """Load a single image from URL, file path, or base64.""" - import base64 - - if isinstance(url, dict): - url = url.get("url") - if not isinstance(url, str): - raise ValueError(f"Unsupported image url type: {type(url)}") - - if url.startswith("http"): - response = requests.get(url, timeout=10) - response.raise_for_status() - image_data = BytesIO(response.content) - return Image.open(image_data).convert("RGB") - elif url.startswith("data:image"): - # Handle base64 encoded images - # Format: data:image/png;base64, - header, encoded = url.split(",", 1) - image_data = BytesIO(base64.b64decode(encoded)) - return Image.open(image_data).convert("RGB") - else: - return Image.open(url).convert("RGB") - - -def _process_images( - image_urls: List[Any], - processor: Any, - input_text: str = "", - mm_config: Optional[dict] = None, 
-) -> tuple[List[MultimodalDataItem], Optional[torch.Tensor], List[int]]: - if not image_urls: - return [], None, [] - - mm_items = [] - images = [] - - for url in image_urls: - try: - image = _load_image(url) - images.append(image) - except Exception as e: - logger.exception(f"Failed to load image {url}: {e}") - continue - - if not images: - return [], None, [] - - try: - inputs = processor(text=input_text, images=images, return_tensors="pt") - - if inputs is None: - logger.error("Processor returned None") - return [], None, [] - - pixel_values = inputs.get("pixel_values") - if pixel_values is None: - logger.error("Processor output missing pixel_values") - return [], None, [] - - expanded_input_ids = inputs.get("input_ids") - if expanded_input_ids is not None: - expanded_input_ids = expanded_input_ids.flatten().tolist() - else: - expanded_input_ids = [] - - logger.debug(f"Processor expanded input_ids length: {len(expanded_input_ids)}") - - image_grid_thw = inputs.get("image_grid_thw") - image_sizes = inputs.get("image_sizes") - - model_specific_data = {} - if image_grid_thw is not None: - model_specific_data["image_grid_thw"] = image_grid_thw - if image_sizes is not None: - model_specific_data["image_sizes"] = image_sizes - - if image_grid_thw is not None and len(image_grid_thw) == len(images): - num_images = len(images) - - patches_per_image = [] - for i in range(num_images): - grid = image_grid_thw[i] - if isinstance(grid, torch.Tensor): - num_patches = int(torch.prod(grid).item()) - else: - num_patches = int(torch.prod(torch.tensor(grid)).item()) - patches_per_image.append(num_patches) - - patch_start = 0 - for i in range(num_images): - num_patches = patches_per_image[i] - item_pixel_values = pixel_values[patch_start:patch_start + num_patches] - item_grid_thw = image_grid_thw[i:i+1] - - item = MultimodalDataItem( - modality=Modality.IMAGE, - feature=item_pixel_values, - model_specific_data={ - "image_grid_thw": item_grid_thw, - }, - ) - mm_items.append(item) - 
patch_start += num_patches - else: - item = MultimodalDataItem( - modality=Modality.IMAGE, - feature=pixel_values, - model_specific_data=model_specific_data, - ) - mm_items.append(item) - - return mm_items, image_grid_thw, expanded_input_ids - - except Exception as e: - logger.exception(f"Failed to process images: {e}") - return [], None, [] - - def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> SGLSamplingParams: """Transforms Parallax SamplingParams to SGLang.SamplingParams format""" params = SGLSamplingParams( @@ -176,36 +45,6 @@ def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> S return params -def _get_image_token_offsets( - input_ids: List[int], - image_token_id: Optional[int], - vision_start_id: Optional[int] = None, - vision_end_id: Optional[int] = None, -) -> List[tuple[int, int]]: - offsets = [] - - if vision_start_id is not None and vision_end_id is not None: - start_indices = [i for i, tok in enumerate(input_ids) if tok == vision_start_id] - end_indices = [i for i, tok in enumerate(input_ids) if tok == vision_end_id] - - for start, end in zip(start_indices, end_indices): - if start < end: - offsets.append((start + 1, end - 1)) - elif image_token_id is not None: - start = None - for i, tok in enumerate(input_ids): - if tok == image_token_id: - if start is None: - start = i - elif start is not None: - offsets.append((start, i - 1)) - start = None - if start is not None: - offsets.append((start, len(input_ids) - 1)) - - return offsets - - def transform_requests_to_sglang( old_requests: List[Request], page_tree_cache: Optional[PageRadixCache] = None, @@ -233,83 +72,16 @@ def transform_requests_to_sglang( elif "images" in old_req.multimodal_params and processor is not None: image_urls = old_req.multimodal_params["images"] - - input_text = "" - if tokenizer is not None: - try: - input_text = tokenizer.decode(old_req.input_ids, skip_special_tokens=False) - logger.debug(f"Decoded input text 
(length={len(input_text)}): {input_text[:100]}...") - except Exception as e: - logger.warning(f"Failed to decode input_ids: {e}") - - mm_items, image_grid_thw, expanded_input_ids = _process_images( - image_urls, - processor, - input_text=input_text, + multimodal_inputs, padded_input_ids = process_multimodal_request( + image_urls=image_urls, + input_ids=old_req.input_ids, + processor=processor, + tokenizer=tokenizer, mm_config=mm_config, ) - - if mm_items: - if expanded_input_ids and len(expanded_input_ids) > len(old_req.input_ids): - logger.debug( - f"Using expanded input_ids: {len(old_req.input_ids)} -> {len(expanded_input_ids)}" - ) - req.origin_input_ids = expanded_input_ids - input_ids_for_offsets = expanded_input_ids - else: - input_ids_for_offsets = list(old_req.input_ids) - - image_token_id = mm_config.get("image_token_id") - vision_start_id = mm_config.get("vision_start_token_id") - vision_end_id = mm_config.get("vision_end_token_id") - - offsets = _get_image_token_offsets( - input_ids_for_offsets, - image_token_id, - vision_start_id, - vision_end_id, - ) - - if len(offsets) == len(mm_items): - for item, offset in zip(mm_items, offsets): - item.offsets = [offset] - elif len(offsets) > 0: - for item in mm_items: - item.offsets = offsets - - for item in mm_items: - item.set_pad_value() - - combined_grid_thw = None - if image_grid_thw is not None: - combined_grid_thw = image_grid_thw - else: - grids = [] - for item in mm_items: - grid = item.model_specific_data.get("image_grid_thw") - if grid is not None: - grids.append(grid) - if grids: - combined_grid_thw = torch.cat(grids, dim=0) - - req.multimodal_inputs = prepare_sglang_multimodal_inputs( - mm_items=mm_items, - image_grid_thw=combined_grid_thw, - mm_config=mm_config, - input_ids=input_ids_for_offsets, - ) - - padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens() - padded_input_ids = padding_pattern.pad_input_tokens( - input_ids_for_offsets, - req.multimodal_inputs - ) + if multimodal_inputs 
is not None: + req.multimodal_inputs = multimodal_inputs req.origin_input_ids = padded_input_ids - - logger.debug( - f"Successfully processed {len(mm_items)} images for request {req.rid}, " - f"offsets={offsets}, padded_input_ids_len={len(padded_input_ids)}" - ) else: logger.warning( diff --git a/src/parallax/sglang/multimodal_utils.py b/src/parallax/sglang/multimodal_utils.py index 508357ce..160e2d96 100644 --- a/src/parallax/sglang/multimodal_utils.py +++ b/src/parallax/sglang/multimodal_utils.py @@ -1,44 +1,161 @@ """ Multimodal utilities for SGLang backend. - -This module provides utilities for processing multimodal inputs in the Parallax -framework when using SGLang as the GPU backend. """ from typing import Any, List, Optional, Tuple +from io import BytesIO +import logging +import requests import torch -from sglang.srt.managers.schedule_batch import MultimodalInputs, MultimodalDataItem +from PIL import Image +from sglang.srt.managers.schedule_batch import MultimodalInputs, MultimodalDataItem, Modality +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens +logger = logging.getLogger(__name__) -def find_token_offsets(input_ids: List[int], token_id: int) -> List[Tuple[int, int]]: - offsets = [] - start = None - for i, tok in enumerate(input_ids): - if tok == token_id: - if start is None: - start = i - elif start is not None: - offsets.append((start, i - 1)) - start = None - if start is not None: - offsets.append((start, len(input_ids) - 1)) - return offsets +def load_image(url: Any) -> Image.Image: + import base64 + + if isinstance(url, dict): + url = url.get("url") + if not isinstance(url, str): + raise ValueError(f"Unsupported image url type: {type(url)}") + + if url.startswith("http"): + response = requests.get(url, timeout=10) + response.raise_for_status() + image_data = BytesIO(response.content) + return Image.open(image_data).convert("RGB") + elif url.startswith("data:image"): + header, encoded = url.split(",", 1) + 
image_data = BytesIO(base64.b64decode(encoded)) + return Image.open(image_data).convert("RGB") + else: + return Image.open(url).convert("RGB") -def find_token_offsets_by_pair( - input_ids: List[int], - start_token_id: int, - end_token_id: int, -) -> List[Tuple[int, int]]: - start_indices = [i for i, tok in enumerate(input_ids) if tok == start_token_id] - end_indices = [i for i, tok in enumerate(input_ids) if tok == end_token_id] + +def process_images( + image_urls: List[Any], + processor: Any, + input_text: str = "", + mm_config: Optional[dict] = None, +) -> tuple[List[MultimodalDataItem], Optional[torch.Tensor], List[int]]: + if not image_urls: + return [], None, [] + + mm_items = [] + images = [] + + for url in image_urls: + try: + image = load_image(url) + images.append(image) + except Exception as e: + logger.exception(f"Failed to load image {url}: {e}") + continue + if not images: + return [], None, [] + + try: + inputs = processor(text=input_text, images=images, return_tensors="pt") + + if inputs is None: + logger.error("Processor returned None") + return [], None, [] + + pixel_values = inputs.get("pixel_values") + if pixel_values is None: + logger.error("Processor output missing pixel_values") + return [], None, [] + + expanded_input_ids = inputs.get("input_ids") + if expanded_input_ids is not None: + expanded_input_ids = expanded_input_ids.flatten().tolist() + else: + expanded_input_ids = [] + + logger.debug(f"Processor expanded input_ids length: {len(expanded_input_ids)}") + + image_grid_thw = inputs.get("image_grid_thw") + image_sizes = inputs.get("image_sizes") + + model_specific_data = {} + if image_grid_thw is not None: + model_specific_data["image_grid_thw"] = image_grid_thw + if image_sizes is not None: + model_specific_data["image_sizes"] = image_sizes + + if image_grid_thw is not None and len(image_grid_thw) == len(images): + num_images = len(images) + + patches_per_image = [] + for i in range(num_images): + grid = image_grid_thw[i] + if 
isinstance(grid, torch.Tensor): + num_patches = int(torch.prod(grid).item()) + else: + num_patches = int(torch.prod(torch.tensor(grid)).item()) + patches_per_image.append(num_patches) + + patch_start = 0 + for i in range(num_images): + num_patches = patches_per_image[i] + item_pixel_values = pixel_values[patch_start:patch_start + num_patches] + item_grid_thw = image_grid_thw[i:i+1] + + item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=item_pixel_values, + model_specific_data={ + "image_grid_thw": item_grid_thw, + }, + ) + mm_items.append(item) + patch_start += num_patches + else: + item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=pixel_values, + model_specific_data=model_specific_data, + ) + mm_items.append(item) + + return mm_items, image_grid_thw, expanded_input_ids + + except Exception as e: + logger.exception(f"Failed to process images: {e}") + return [], None, [] + + +def get_image_token_offsets( + input_ids: List[int], + image_token_id: Optional[int], + vision_start_id: Optional[int] = None, + vision_end_id: Optional[int] = None, +) -> List[Tuple[int, int]]: offsets = [] - for start, end in zip(start_indices, end_indices): - if start < end: - # Content is between start+1 and end-1 - offsets.append((start + 1, end - 1)) + + if vision_start_id is not None and vision_end_id is not None: + start_indices = [i for i, tok in enumerate(input_ids) if tok == vision_start_id] + end_indices = [i for i, tok in enumerate(input_ids) if tok == vision_end_id] + + for start, end in zip(start_indices, end_indices): + if start < end: + offsets.append((start + 1, end - 1)) + elif image_token_id is not None: + start = None + for i, tok in enumerate(input_ids): + if tok == image_token_id: + if start is None: + start = i + elif start is not None: + offsets.append((start, i - 1)) + start = None + if start is not None: + offsets.append((start, len(input_ids) - 1)) return offsets @@ -144,3 +261,88 @@ def prepare_sglang_multimodal_inputs( 
mrope_positions=mrope_positions, mrope_position_delta=mrope_position_delta, ) + + +def process_multimodal_request( + image_urls: List[Any], + input_ids: List[int], + processor: Any, + tokenizer: Any, + mm_config: dict, +) -> Tuple[Optional[MultimodalInputs], List[int]]: + input_text = "" + if tokenizer is not None: + try: + input_text = tokenizer.decode(input_ids, skip_special_tokens=False) + logger.debug(f"Decoded input text (length={len(input_text)}): {input_text[:100]}...") + except Exception as e: + logger.warning(f"Failed to decode input_ids: {e}") + + mm_items, image_grid_thw, expanded_input_ids = process_images( + image_urls, + processor, + input_text=input_text, + mm_config=mm_config, + ) + + if not mm_items: + return None, list(input_ids) + + if expanded_input_ids and len(expanded_input_ids) > len(input_ids): + logger.debug(f"Using expanded input_ids: {len(input_ids)} -> {len(expanded_input_ids)}") + input_ids_for_offsets = expanded_input_ids + else: + input_ids_for_offsets = list(input_ids) + + image_token_id = mm_config.get("image_token_id") + vision_start_id = mm_config.get("vision_start_token_id") + vision_end_id = mm_config.get("vision_end_token_id") + + offsets = get_image_token_offsets( + input_ids_for_offsets, + image_token_id, + vision_start_id, + vision_end_id, + ) + + if len(offsets) == len(mm_items): + for item, offset in zip(mm_items, offsets): + item.offsets = [offset] + elif len(offsets) > 0: + for item in mm_items: + item.offsets = offsets + + for item in mm_items: + item.set_pad_value() + + combined_grid_thw = None + if image_grid_thw is not None: + combined_grid_thw = image_grid_thw + else: + grids = [] + for item in mm_items: + grid = item.model_specific_data.get("image_grid_thw") + if grid is not None: + grids.append(grid) + if grids: + combined_grid_thw = torch.cat(grids, dim=0) + + multimodal_inputs = prepare_sglang_multimodal_inputs( + mm_items=mm_items, + image_grid_thw=combined_grid_thw, + mm_config=mm_config, + 
input_ids=input_ids_for_offsets, + ) + + padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens() + padded_input_ids = padding_pattern.pad_input_tokens( + input_ids_for_offsets, + multimodal_inputs + ) + + logger.debug( + f"Successfully processed {len(mm_items)} images, " + f"offsets={offsets}, padded_input_ids_len={len(padded_input_ids)}" + ) + + return multimodal_inputs, padded_input_ids From 9a3707abe2a6a85a70d7e759f27072fc1d1c401c Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Tue, 3 Feb 2026 16:52:54 +0800 Subject: [PATCH 08/36] tmp add mlx --- pyproject.toml | 3 + src/parallax/models/qwen3_vl.py | 167 ++++++++++ src/parallax/server/executor/base_executor.py | 203 ++++++++++-- src/parallax/server/executor/mlx_executor.py | 49 +++ src/parallax/server/model.py | 295 +++++++++++++++++- src/parallax/server/request.py | 89 +++++- src/parallax/server/shard_loader.py | 156 ++++++++- src/parallax/utils/vlm_utils.py | 282 +++++++++++++++++ 8 files changed, 1205 insertions(+), 39 deletions(-) create mode 100644 src/parallax/models/qwen3_vl.py create mode 100644 src/parallax/utils/vlm_utils.py diff --git a/pyproject.toml b/pyproject.toml index 999d1f54..52bf2534 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,18 +47,21 @@ mac = [ "torch==2.8.0", "mlx-lm==0.30.5", "mlx==0.30.4", + "mlx-vlm==0.3.10", ] gpu = [ "sglang[all]==0.5.7", "mlx-lm==0.28.4", "mlx[cpu]==0.30.0", + "mlx-vlm==0.3.10", ] vllm = [ "vllm==0.14.0", "mlx-lm==0.28.4", "mlx[cpu]==0.30.0", + "mlx-vlm==0.3.10", ] benchmark = [ diff --git a/src/parallax/models/qwen3_vl.py b/src/parallax/models/qwen3_vl.py new file mode 100644 index 00000000..0ee4ff87 --- /dev/null +++ b/src/parallax/models/qwen3_vl.py @@ -0,0 +1,167 @@ +""" +Defines the Qwen3VL model for Parallax. + +This module reuses components from mlx-vlm and adds PagedAttention support +for distributed inference. 
+""" + +from typing import Any, List, Optional + +import mlx.core as mx +from mlx import nn + +from parallax.server.cache.base import BaseCache +from parallax_extensions.ops import paged_attention_v1, reshape_and_cache +from parallax_utils.logging_config import get_logger + +# Import from mlx-vlm +from mlx_vlm.models.qwen3_vl.language import ( + Attention as MLXQwen3VLAttention, + MLP, + Qwen3VLDecoderLayer as MLXQwen3VLDecoderLayer, + Qwen3VLRotaryEmbedding, + apply_multimodal_rotary_pos_emb, +) + +logger = get_logger(__name__) + + +class ParallaxQwen3VLAttention(MLXQwen3VLAttention): + """Qwen3VL Attention with PagedAttention support for Parallax.""" + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[BaseCache] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + prefix_lens: Optional[mx.array] = None, + position_ids: Optional[mx.array] = None, + **kwargs, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = self.q_norm( + queries.reshape(B, L, self.n_heads, self.head_dim) + ).transpose(0, 2, 1, 3) + keys = self.k_norm( + keys.reshape(B, L, self.n_kv_heads, self.head_dim) + ).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, self.head_dim) + + # Get KV cache + key_cache_global, value_cache_global = cache.get_cache() + + # Compute RoPE position + if L == 1: + # Decode phase: use context_lengths - 1 as offset + current_pos = context_lengths - 1 + pos_ids = mx.broadcast_to(current_pos[:, None], (B, L)) + pos_ids = mx.broadcast_to(pos_ids[None, :, :], (3, B, L)) + elif position_ids is not None: + # Prefill with MRoPE position_ids + pos_ids = position_ids + elif prefix_lens is not None: + # Prefill with prefix cache + pos_ids = mx.arange(L)[None, :] + prefix_lens[:, None] + pos_ids = mx.broadcast_to(pos_ids[None, :, :], (3, B, L)) + else: + # 
Standard prefill + pos_ids = mx.arange(L)[None, :] + pos_ids = mx.broadcast_to(pos_ids, (B, L)) + pos_ids = mx.broadcast_to(pos_ids[None, :, :], (3, B, L)) + + cos, sin = self.rotary_emb(values, pos_ids) + queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) + + # Cache update with PagedAttention + block_size = key_cache_global.shape[3] + reshape_and_cache( + keys.transpose(0, 2, 1, 3), + values, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + slot_mapping=slot_mapping, + ) + + # Compute attention + if L == 1: + # Decode: use PagedAttention + output = paged_attention_v1( + queries, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + self.scale, + self.n_kv_heads, + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + else: + # Prefill: standard attention + from mlx_lm.models.base import scaled_dot_product_attention + output = scaled_dot_product_attention( + queries, + keys, + values.transpose(0, 2, 1, 3), + scale=self.scale, + mask=mask, + cache=None, + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output) + + +class ParallaxQwen3VLBlock(nn.Module): + """Qwen3VL Transformer block with PagedAttention support.""" + + def __init__(self, args, layer_idx: int, local_layer_idx: int): + super().__init__() + self.hidden_size = args.hidden_size + self.self_attn = ParallaxQwen3VLAttention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[List[Any]] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + 
**kwargs, + ): + r = self.self_attn( + self.input_layernorm(x), + mask, + cache[self.local_layer_idx], + block_tables=block_tables, + context_lengths=context_lengths, + slot_mapping=slot_mapping, + **kwargs, + ) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "Qwen3VLForConditionalGeneration" + + +EntryClass = ParallaxQwen3VLBlock diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index d0511089..68b48201 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -39,11 +39,13 @@ IntermediateRequest, Request, RequestStatus, + VLMInputs, ) from parallax.server.sampling.sampling_params import SamplingParams from parallax.server.scheduler import Scheduler from parallax.utils.shared_state import SharedState from parallax.utils.utils import get_current_device, get_device_dtype, get_zmq_socket +from parallax.utils.vlm_utils import create_vlm_inputs_from_request from parallax_utils.logging_config import get_logger logger = get_logger(__name__) @@ -642,12 +644,9 @@ def shutdown(self): logger.debug("Executor shutdown complete.") - def _handle_raw_request(self, raw_request: Dict): - assert "messages" in raw_request, "Request did not contain messages" - - rid = raw_request["rid"] + def _process_text_request(self, rid: str, messages: list, raw_request: Dict) -> list: + """Process a text-only request using the tokenizer.""" if self.tokenizer.chat_template: - messages = raw_request["messages"] has_non_text_content = any( isinstance(msg.get("content"), list) and any( @@ -659,7 +658,6 @@ def _handle_raw_request(self, raw_request: Dict): if not has_non_text_content: process_message_content(messages) chat_template_kwargs = raw_request.get("chat_template_kwargs", {}) - # check extra_body for backward compatibility if "extra_body" in 
raw_request and "chat_template_kwargs" in raw_request["extra_body"]: chat_template_kwargs.update(raw_request["extra_body"]["chat_template_kwargs"]) @@ -671,8 +669,186 @@ def _handle_raw_request(self, raw_request: Dict): **chat_template_kwargs, ) else: - prompt = convert_chat(raw_request["messages"], raw_request.get("role_mapping")) + prompt = convert_chat(messages, raw_request.get("role_mapping")) prompt = self.tokenizer.encode(prompt) + + return prompt + + def _process_vlm_request(self, rid: str, messages: list, image_urls: list): + """ + Process a VLM (multimodal) request using the VLM processor. + + The processor handles both text formatting and image preprocessing together, + ensuring proper image token insertion and expansion. + + Returns: + Tuple of (input_ids, VLMInputs) + """ + from parallax.utils.vlm_utils import load_image + + try: + # Load images + images = [] + for url in image_urls: + try: + img = load_image(url) + images.append(img) + except Exception as e: + logger.warning(f"Failed to load image {url}: {e}") + + if not images: + logger.warning(f"No images loaded for VLM request {rid}, falling back to text processing") + return self._process_text_request(rid, messages, {}), None + + # Format messages for processor + # Most VLM processors expect a specific format with image placeholders + formatted_messages = self._format_messages_for_vlm(messages) + + # Apply chat template to get the text prompt + if hasattr(self.vlm_processor, 'apply_chat_template'): + # Some processors have their own chat template + text_prompt = self.vlm_processor.apply_chat_template( + formatted_messages, + tokenize=False, + add_generation_prompt=True, + ) + elif self.tokenizer.chat_template: + # Fall back to tokenizer's chat template + text_prompt = self.tokenizer.apply_chat_template( + formatted_messages, + tokenize=False, + add_generation_prompt=True, + ) + else: + # Simple fallback + text_prompt = "\n".join( + f"{msg.get('role', 'user')}: 
{self._extract_text_from_content(msg.get('content', ''))}" + for msg in formatted_messages + ) + + # Use processor to handle text + images together + processor_inputs = self.vlm_processor( + text=text_prompt, + images=images, + return_tensors="np", + ) + + # Extract input_ids + input_ids = processor_inputs.get("input_ids") + if input_ids is None: + raise ValueError("Processor did not return input_ids") + + if hasattr(input_ids, "numpy"): + input_ids = input_ids.numpy() + prompt = input_ids.flatten().tolist() + + # Extract pixel_values and other vision inputs + pixel_values = processor_inputs.get("pixel_values") + if pixel_values is not None and hasattr(pixel_values, "numpy"): + pixel_values = pixel_values.numpy() + + image_grid_thw = processor_inputs.get("image_grid_thw") + if image_grid_thw is not None and hasattr(image_grid_thw, "numpy"): + image_grid_thw = image_grid_thw.numpy() + + # Create VLMInputs + vlm_inputs = VLMInputs( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_sizes=[(img.height, img.width) for img in images], + images_processed=False, + ) + + logger.debug( + f"VLM request {rid}: {len(images)} images, " + f"input_ids length={len(prompt)}, " + f"pixel_values shape={pixel_values.shape if pixel_values is not None else None}" + ) + + return prompt, vlm_inputs + + except Exception as e: + logger.error(f"Failed to process VLM request {rid}: {e}") + # Fall back to text-only processing + return self._process_text_request(rid, messages, {}), None + + def _format_messages_for_vlm(self, messages: list) -> list: + """ + Format messages for VLM processing. + Converts image_url content parts to a format the processor understands. 
+ """ + formatted = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + # Convert content list to text with image placeholders + text_parts = [] + for part in content: + if isinstance(part, dict): + if part.get("type") == "text": + text_parts.append(part.get("text", "")) + elif part.get("type") == "image_url": + # Add image placeholder - processor will handle the actual insertion + # Different models use different placeholders + text_parts.append("") + elif isinstance(part, str): + text_parts.append(part) + formatted.append({ + "role": msg.get("role", "user"), + "content": "".join(text_parts), + }) + else: + formatted.append(msg) + return formatted + + def _extract_text_from_content(self, content) -> str: + """Extract text from message content (handles both string and list formats).""" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + texts.append(part.get("text", "")) + elif isinstance(part, str): + texts.append(part) + return " ".join(texts) + return str(content) + + def _handle_raw_request(self, raw_request: Dict): + assert "messages" in raw_request, "Request did not contain messages" + + rid = raw_request["rid"] + messages = raw_request["messages"] + + # Extract image URLs first to determine if this is a multimodal request + multimodal_params = None + image_urls = [] + for message in messages: + content = message.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image_url": + if multimodal_params is None: + multimodal_params = {"images": []} + image_url = part["image_url"] + multimodal_params["images"].append(image_url) + # Extract URL string for processing + if isinstance(image_url, dict): + image_urls.append(image_url.get("url", image_url)) + else: + image_urls.append(image_url) + + # Process multimodal request with VLM 
processor + vlm_inputs = None + has_vlm_processor = hasattr(self, 'vlm_processor') and self.vlm_processor is not None + + if image_urls and has_vlm_processor: + # Use VLM processor to handle both text and images together + prompt, vlm_inputs = self._process_vlm_request(rid, messages, image_urls) + else: + # Standard text-only processing + prompt = self._process_text_request(rid, messages, raw_request) max_seq_len = self.max_sequence_length if self.max_sequence_length is not None else 4096 max_seq_len = max(max_seq_len, 4096) @@ -700,18 +876,6 @@ def _handle_raw_request(self, raw_request: Dict): lora_path = raw_request.get("lora_path") return_probs = raw_request.get("return_probs", False) # Get return_probs parameter - # Extract multimodal params if present - multimodal_params = None - if "messages" in raw_request: - for message in raw_request["messages"]: - content = message.get("content") - if isinstance(content, list): - for part in content: - if isinstance(part, dict) and part.get("type") == "image_url": - if multimodal_params is None: - multimodal_params = {"images": []} - multimodal_params["images"].append(part["image_url"]) - raw_sampling_params = raw_request.get("sampling_params") if raw_sampling_params is None: sampling_params = SamplingParams() @@ -737,6 +901,7 @@ def _handle_raw_request(self, raw_request: Dict): lora_path=lora_path, return_probs=return_probs, multimodal_params=multimodal_params, + vlm_inputs=vlm_inputs, ) if "routing_table" in raw_request: req.routing_table = raw_request["routing_table"] diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 268e035c..3ce6d82b 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -124,6 +124,20 @@ def __init__( logger.debug( f"MLX sharded model loaded in {(time.time() - t0) * 1000:.1f} ms; num_layers={self.config.get('num_hidden_layers')}" ) + + # Load VLM processor if this is a VLM model 
(first peer only) + self.vlm_processor = None + self.model_type = self.config.get("model_type") + if hasattr(self.model_shard, 'is_vlm') and self.model_shard.is_vlm and start_layer == 0: + try: + from transformers import AutoProcessor + self.vlm_processor = AutoProcessor.from_pretrained( + model_repo, + trust_remote_code=True + ) + logger.info(f"Loaded VLM processor for {self.model_type}") + except Exception as e: + logger.warning(f"Failed to load VLM processor: {e}. VLM image processing will be disabled.") # TODO: Duplicate code to BaseExecutor since num_shard_layers and dtype are needed for initializing kv cache self.num_shard_layers = end_layer - start_layer @@ -392,6 +406,9 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: slot_mapping=prepared_inputs.get("slot_mapping"), state_slot_mapping=prepared_inputs.get("state_slot_mapping"), prefix_lens=prepared_inputs.get("prefix_lens"), # For RoPE offset in prefix cache + # VLM inputs (only present for first peer with VLM model) + pixel_values=prepared_inputs.get("pixel_values"), + image_grid_thw=prepared_inputs.get("image_grid_thw"), ) mx.eval(hidden_states) @@ -658,6 +675,38 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A "prefix_lens": prefix_lens_tensor, # For RoPE offset calculation "actual_processed_lengths": actual_processed_lengths_tensor, # For correct logit selection } + + # VLM support: collect pixel_values and image metadata for first peer + if self.is_first_peer and hasattr(self.model_shard, 'is_vlm') and self.model_shard.is_vlm: + pixel_values_list = [] + image_grid_thw_list = [] + has_vlm_inputs = False + + for req in batched_requests: + if req.vlm_inputs is not None and req.vlm_inputs.has_images(): + has_vlm_inputs = True + pixel_values_list.append(req.vlm_inputs.pixel_values) + if req.vlm_inputs.image_grid_thw is not None: + image_grid_thw_list.append(req.vlm_inputs.image_grid_thw) + else: + pixel_values_list.append(None) + + if 
has_vlm_inputs: + # For now, we only support single-image batching where all requests + # in the batch either have images or don't have images + # TODO: Support mixed batches with proper padding + non_none_pixels = [p for p in pixel_values_list if p is not None] + if len(non_none_pixels) > 0: + # Concatenate all pixel values along batch dimension + ret["pixel_values"] = mx.concatenate( + [mx.array(p) for p in non_none_pixels], axis=0 + ) + if image_grid_thw_list: + ret["image_grid_thw"] = mx.concatenate( + [mx.array(g) for g in image_grid_thw_list], axis=0 + ) + logger.debug(f"VLM batch: pixel_values shape={ret['pixel_values'].shape}") + logger.debug(f"Prepared MLX prefill batch (size={batch_size})") return ret diff --git a/src/parallax/server/model.py b/src/parallax/server/model.py index 54855c94..4ca32b8b 100644 --- a/src/parallax/server/model.py +++ b/src/parallax/server/model.py @@ -2,9 +2,11 @@ Defines the ShardedModel class for distributing MLX models across multiple devices. """ -from typing import Any, List, Optional, Type +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import mlx.core as mx +import numpy as np from mlx import nn from mlx_lm.models.base import BaseModelArgs @@ -14,12 +16,54 @@ logger = get_logger(__name__) +class VisionConfig: + """Dynamic configuration for vision models in VLM. + + This class dynamically accepts all parameters from config.json's vision_config, + making it compatible with different VLM architectures (Qwen-VL, LLaVA, etc.). 
+ """ + + def __init__(self, **kwargs): + # Set all provided parameters as attributes + for key, value in kwargs.items(): + setattr(self, key, value) + + # Set common defaults if not provided + if not hasattr(self, 'model_type'): + self.model_type = "clip_vision_model" + if not hasattr(self, 'hidden_size'): + self.hidden_size = 1024 + + @classmethod + def from_dict(cls, params: Dict[str, Any]) -> "VisionConfig": + """Create VisionConfig from a dictionary, accepting all parameters.""" + if params is None: + return None + return cls(**params) + + +@dataclass +class InputEmbeddingsOutput: + """Output from get_input_embeddings method.""" + inputs_embeds: mx.array + attention_mask: Optional[mx.array] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "inputs_embeds": self.inputs_embeds, + "attention_mask": self.attention_mask, + } + + class ShardedModel(nn.Module): """A general class for MLX sharded model, adapted for Parallax KV cache. Assumes self.layers are composed of modules (e.g., TransformerBlocks) that internally use ParallaxAttention and their __call__ method returns (hidden_state, new_k_for_layer, new_v_for_layer). + + Supports VLM (Vision Language Models) by optionally loading vision_tower + and multi_modal_projector on the first shard. 
""" def __init__( @@ -32,6 +76,13 @@ def __init__( *, has_norm_in: bool = False, dtype: Optional[mx.Dtype] = None, + # VLM support + vision_config: Optional[Dict[str, Any]] = None, + vision_tower_class: Optional[Type[nn.Module]] = None, + multi_modal_projector_class: Optional[Type[nn.Module]] = None, + image_token_index: Optional[int] = None, + vision_feature_layer: Optional[int] = -2, + vision_feature_select_strategy: str = "default", ): super().__init__() self.config = config @@ -48,14 +99,38 @@ def __init__( self.is_first_shard = start_layer == 0 self.is_last_shard = end_layer == config.num_hidden_layers self.n_layers_in_shard = end_layer - start_layer + + # VLM configuration + self.is_vlm = vision_config is not None and vision_tower_class is not None + self.vision_config = VisionConfig.from_dict(vision_config) if vision_config else None + self.image_token_index = image_token_index + self.vision_feature_layer = vision_feature_layer + self.vision_feature_select_strategy = vision_feature_select_strategy if self.is_first_shard: self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size) if has_norm_in: self.norm_in = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + # Initialize vision components for VLM on first shard + if self.is_vlm: + logger.info(f"Initializing VLM components: vision_tower ({self.vision_config.model_type})") + self.vision_tower = vision_tower_class(self.vision_config) + # Some VLMs (e.g., Qwen2-VL, Qwen3-VL) have the projector/merger built into VisionModel + # In these cases, multi_modal_projector_class can be None + if multi_modal_projector_class is not None: + self.multi_modal_projector = multi_modal_projector_class(config) + else: + self.multi_modal_projector = None + logger.info("No separate projector class - projector is integrated into VisionModel") + else: + self.vision_tower = None + self.multi_modal_projector = None else: self.embed_tokens = None self.norm_in = None + self.vision_tower = None + 
self.multi_modal_projector = None self.layers = [ block_class(config, layer_idx, layer_idx - start_layer) @@ -68,7 +143,7 @@ def __init__( else: self.norm = None self.lm_head = None - + def shard_layers(self): group = mx.distributed.init() tp_size = group.size() @@ -81,6 +156,188 @@ def shard_layers(self): f"Model {layer.__class__.__name__} does not have a shard method, does not support tensor parallelism" ) exit(1) + + def get_input_embeddings( + self, + input_ids: mx.array, + pixel_values: Optional[mx.array] = None, + **kwargs, + ) -> InputEmbeddingsOutput: + """Get input embeddings, optionally with vision features merged in. + + This method handles: + 1. Text-only inputs: Simply embed tokens + 2. VLM inputs: Embed tokens, encode images, and merge vision features + + Args: + input_ids: (batch, seq_len) Token IDs + pixel_values: (batch, C, H, W) or (num_patches, C, H, W) Image pixel values + **kwargs: Additional arguments (e.g., image_grid_thw for Qwen2-VL) + + Returns: + InputEmbeddingsOutput with merged embeddings + """ + if not self.is_first_shard: + raise ValueError("get_input_embeddings should only be called on the first shard") + + # Get text embeddings + inputs_embeds = self.embed_tokens(input_ids) + + # If no images or not a VLM, return text embeddings directly + if pixel_values is None or not self.is_vlm: + return InputEmbeddingsOutput(inputs_embeds=inputs_embeds) + + # Process vision features + image_features = self._encode_images(pixel_values, **kwargs) + + # Merge image features with text embeddings + final_embeds = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids + ) + + return InputEmbeddingsOutput(inputs_embeds=final_embeds) + + def _encode_images( + self, + pixel_values: mx.array, + **kwargs, + ) -> mx.array: + """Encode images through vision tower and projector. 
+ + Args: + pixel_values: Image tensor, typically (batch, C, H, W) or (num_patches, C, H, W) + **kwargs: Additional model-specific arguments + + Returns: + Projected image features ready to be merged with text embeddings + """ + if self.vision_tower is None: + raise ValueError("Vision tower not initialized for this model") + + # Convert to vision tower expected format (typically NHWC for MLX) + # pixel_values is usually in NCHW format from processor + if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]: + # NCHW -> NHWC + pixel_values = pixel_values.transpose(0, 2, 3, 1) + + # Ensure correct dtype + if hasattr(self.vision_tower, 'patch_embed') and hasattr(self.vision_tower.patch_embed, 'proj'): + target_dtype = self.vision_tower.patch_embed.proj.weight.dtype + pixel_values = pixel_values.astype(target_dtype) + else: + pixel_values = pixel_values.astype(self.dtype) + + # Get vision features from vision tower + # Different vision models have different output formats + vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + + # Handle different output formats + if isinstance(vision_outputs, tuple): + # CLIP/SigLIP style: (pooler_output, last_hidden_state, hidden_states) + if len(vision_outputs) >= 3: + hidden_states = vision_outputs[2] # All hidden states + if isinstance(self.vision_feature_layer, int): + selected_features = hidden_states[self.vision_feature_layer] + if self.vision_feature_select_strategy == "default": + # Remove CLS token + selected_features = selected_features[:, 1:] + else: + # Multiple layers + hs_pool = [hidden_states[idx] for idx in self.vision_feature_layer] + if self.vision_feature_select_strategy == "default": + hs_pool = [hs[:, 1:] for hs in hs_pool] + selected_features = mx.concatenate(hs_pool, axis=-1) + else: + # Simple (pooler, hidden_state) output + selected_features = vision_outputs[1] + if self.vision_feature_select_strategy == "default": + selected_features = selected_features[:, 1:] + else: + # 
Direct hidden state output (Qwen-VL style) + # These models already output projected features + selected_features = vision_outputs + + # Project to language model dimension if projector exists + # Some VLMs (e.g., Qwen-VL) have projection built into VisionModel + if self.multi_modal_projector is not None: + image_features = self.multi_modal_projector(selected_features) + else: + # VisionModel already outputs projected features + image_features = selected_features + + return image_features + + def _merge_input_ids_with_image_features( + self, + image_features: mx.array, + inputs_embeds: mx.array, + input_ids: mx.array, + ) -> mx.array: + """Merge image features into input embeddings at image token positions. + + This replaces placeholder tokens with actual image feature embeddings. + + Args: + image_features: (num_images, num_patches, hidden_dim) or (total_patches, hidden_dim) + inputs_embeds: (batch, seq_len, hidden_dim) Text embeddings + input_ids: (batch, seq_len) Token IDs for finding image positions + + Returns: + Merged embeddings with image features inserted at image token positions + """ + if self.image_token_index is None: + logger.warning("image_token_index not set, cannot merge image features") + return inputs_embeds + + batch_size, seq_len, hidden_dim = inputs_embeds.shape + + # Find positions of image tokens + image_positions = (input_ids == self.image_token_index) + + # Flatten image features if needed + if image_features.ndim == 3: + # (num_images, num_patches, dim) -> (total_patches, dim) + image_features = image_features.reshape(-1, image_features.shape[-1]) + + # Cast image features to match embedding dtype + image_features = image_features.astype(inputs_embeds.dtype) + + # Process each batch item + batch_outputs = [] + feature_start_idx = 0 + + for batch_idx in range(batch_size): + batch_mask = image_positions[batch_idx] + num_positions = int(mx.sum(batch_mask).item()) + + if num_positions > 0: + # Extract features for this batch + batch_features 
= image_features[feature_start_idx:feature_start_idx + num_positions] + + if batch_features.shape[0] != num_positions: + raise ValueError( + f"Number of image token positions ({num_positions}) does not match " + f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}" + ) + + # Create indices for gathering + cumsum = mx.cumsum(batch_mask.astype(mx.int32)) + feature_indices = mx.where(batch_mask, cumsum - 1, 0) + + # Gather features and create merged output + gathered_features = batch_features[feature_indices] + batch_mask_expanded = mx.expand_dims(batch_mask, axis=-1) + batch_output = mx.where( + batch_mask_expanded, gathered_features, inputs_embeds[batch_idx] + ) + + feature_start_idx += num_positions + else: + batch_output = inputs_embeds[batch_idx] + + batch_outputs.append(batch_output) + + return mx.stack(batch_outputs, axis=0) def logits_to_tokens( self, @@ -128,29 +385,53 @@ def __call__( block_tables: Optional[mx.array] = None, context_lengths: Optional[mx.array] = None, slot_mapping: Optional[mx.array] = None, + pixel_values: Optional[mx.array] = None, + inputs_embeds: Optional[mx.array] = None, **kwargs, ) -> mx.array: """ + Forward pass through the sharded model. + Args: h_or_tokens: (batch, target_len_padded, D) or (batch, target_len_padded) for prefill, (batch, 1, D) or (batch, 1) for decode. cache: List of layer caches (KVCache or LinearCache). Legacy mode: (key_cache_global, value_cache_global) tuple. - lengths: (batch,) true lengths of each sequence in batch. mask: Optional causal mask for the current segment. - window_size: Optional int, if provided, will use a sliding window attention mask. block_tables: (batch, max_blocks) for PagedAttention. context_lengths: (batch,) for PagedAttention. slot_mapping: (total_tokens,) for PagedAttention. + pixel_values: (batch, C, H, W) or (num_patches, C, H, W) for VLM. + Image pixel values to be processed by vision tower. + inputs_embeds: (batch, seq_len, hidden_dim) Pre-computed embeddings. 
+ If provided, skips embedding and vision processing. + **kwargs: Additional model-specific arguments (e.g., image_grid_thw). + + Returns: + For last shard: logits (batch, seq_len, vocab_size) + For other shards: hidden states (batch, seq_len, hidden_dim) """ h = h_or_tokens target_len = h.shape[1] if self.is_first_shard: - if self.embed_tokens is None: - raise ValueError("embed_tokens is None for the first shard.") - h = self.embed_tokens(h) + if inputs_embeds is not None: + # Use pre-computed embeddings (e.g., from get_input_embeddings) + h = inputs_embeds + else: + if self.embed_tokens is None: + raise ValueError("embed_tokens is None for the first shard.") + + # Check if we need to process vision inputs + if pixel_values is not None and self.is_vlm: + # Use get_input_embeddings for VLM processing + embed_output = self.get_input_embeddings(h, pixel_values, **kwargs) + h = embed_output.inputs_embeds + else: + # Standard text embedding + h = self.embed_tokens(h) + if self.has_norm_in and self.norm_in: h = self.norm_in(h) diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index 5354dae6..4735ee07 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -59,8 +59,11 @@ """ import uuid +from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union + +import numpy as np from parallax.server.sampling.sampling_params import SamplingParams from parallax_utils.logging_config import get_logger @@ -68,6 +71,63 @@ logger = get_logger(__name__) +@dataclass +class VLMInputs: + """Container for Vision Language Model inputs. + + This is used to pass image/video data through the pipeline. + Only the first peer needs to process pixel_values; subsequent peers + receive pre-computed image embeddings merged into hidden_states. 
+ """ + # Preprocessed image tensor, shape varies by model: + # - LLaVA: (num_images, C, H, W) or (num_patches, C, patch_H, patch_W) + # - Qwen-VL: (num_patches, C, patch_H, patch_W) with temporal dim for video + pixel_values: Optional[np.ndarray] = None + + # For models with dynamic resolution (e.g., Qwen2-VL): + # Tuple of (temporal, height, width) grid sizes for each image + # Shape: (num_images, 3) where each row is (t, h, w) + image_grid_thw: Optional[np.ndarray] = None + + # Number of image tokens per image (for variable-length image tokens) + image_token_counts: Optional[List[int]] = None + + # Original image sizes before preprocessing (height, width) + # Useful for models that need aspect ratio information + image_sizes: Optional[List[tuple]] = None + + # Whether images have been processed into embeddings + # (set to True after first peer processes images) + images_processed: bool = False + + def has_images(self) -> bool: + """Check if this request contains image inputs.""" + return self.pixel_values is not None and len(self.pixel_values) > 0 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "pixel_values": self.pixel_values, + "image_grid_thw": self.image_grid_thw, + "image_token_counts": self.image_token_counts, + "image_sizes": self.image_sizes, + "images_processed": self.images_processed, + } + + @classmethod + def from_dict(cls, data: Optional[Dict[str, Any]]) -> Optional["VLMInputs"]: + """Create from dictionary.""" + if data is None: + return None + return cls( + pixel_values=data.get("pixel_values"), + image_grid_thw=data.get("image_grid_thw"), + image_token_counts=data.get("image_token_counts"), + image_sizes=data.get("image_sizes"), + images_processed=data.get("images_processed", False), + ) + + class RequestStatus(Enum): """Enumeration of possible request statuses for the First Peer.""" @@ -97,6 +157,7 @@ def __init__( sampling_params: Optional[SamplingParams] = None, lora_path: Optional[str] = 
None, multimodal_params: Optional[Dict] = None, + vlm_inputs: Optional[VLMInputs] = None, ): self.request_id = request_id or str(uuid.uuid4()) self.status = status @@ -111,6 +172,14 @@ def __init__( self.lora_id: Optional[str] = None self.lora_path = lora_path self.multimodal_params = multimodal_params + + # VLM support: structured container for vision inputs + self.vlm_inputs = vlm_inputs + + @property + def has_images(self) -> bool: + """Check if this request contains image inputs.""" + return self.vlm_inputs is not None and self.vlm_inputs.has_images() @property def is_finished(self) -> bool: @@ -164,6 +233,7 @@ def __init__( lora_path: Optional[str] = None, return_probs: bool = False, multimodal_params: Optional[Dict] = None, + vlm_inputs: Optional[VLMInputs] = None, ): if not prompt and not input_ids: raise ValueError("prompt or input_ids cannot be empty.") @@ -175,6 +245,7 @@ def __init__( sampling_params=sampling_params, lora_path=lora_path, multimodal_params=multimodal_params, + vlm_inputs=vlm_inputs, ) self.prompt = prompt self.return_probs = return_probs @@ -272,6 +343,7 @@ def __init__( lora_path: Optional[str] = None, token_prob: Optional[float] = None, return_probs: bool = False, + vlm_inputs: Optional[VLMInputs] = None, ): super().__init__( request_id=request_id, @@ -280,6 +352,7 @@ def __init__( input_ids=input_ids, sampling_params=sampling_params, lora_path=lora_path, + vlm_inputs=vlm_inputs, ) # Hidden states from the previous peer's computation. 
# Shape: @@ -336,6 +409,18 @@ def from_initial_request( else: next_token_id = initial_request.output_ids[-1] + # For VLM: after first peer processes images, mark as processed + # and don't pass pixel_values to subsequent peers (only metadata) + vlm_inputs = None + if initial_request.vlm_inputs is not None: + vlm_inputs = VLMInputs( + pixel_values=None, # Don't pass raw pixels to next peers + image_grid_thw=initial_request.vlm_inputs.image_grid_thw, + image_token_counts=initial_request.vlm_inputs.image_token_counts, + image_sizes=initial_request.vlm_inputs.image_sizes, + images_processed=True, # Mark as processed by first peer + ) + return IntermediateRequest( request_id=initial_request.request_id, status=initial_request.status, @@ -348,6 +433,7 @@ def from_initial_request( lora_path=lora_path, token_prob=token_prob, return_probs=initial_request.return_probs, + vlm_inputs=vlm_inputs, ) @classmethod @@ -374,6 +460,7 @@ def from_intermediate_request( lora_path=lora_path, token_prob=token_prob, return_probs=old_request.return_probs, + vlm_inputs=old_request.vlm_inputs, # Pass through VLM metadata ) def __repr__(self): diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index f8793588..8854ac7b 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -7,7 +7,7 @@ import json import pathlib import types -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Type import mlx.core as mx from huggingface_hub import snapshot_download @@ -28,6 +28,63 @@ "kimi_k2": "mlx_lm.models.deepseek_v3", } +# VLM models that need to use text_config for ModelArgs +# Format: model_type -> base_model_type (for loading ModelArgs from mlx_lm) +VLM_TEXT_CONFIG_MAP = { + "qwen3_vl": "qwen3", + "qwen2_vl": "qwen2", + "qwen2_5_vl": "qwen2", +} + +# VLM models that need special handling (have separate projector class) +# Format: model_type -> (projector_module_path, projector_class_name) 
+# Default: VisionModel from mlx_vlm.models.{model_type}, no separate projector +VLM_SPECIAL_PROJECTOR_MAP = { + "llava": ("mlx_vlm.models.llava.llava", "LlavaMultiModalProjector"), + "llava_next": ("mlx_vlm.models.llava_next.llava_next", "LlavaMultiModalProjector"), +} + + +def _get_vlm_classes( + model_type: str, config: Dict[str, Any] +) -> Tuple[Optional[Type[nn.Module]], Optional[Type[nn.Module]], Optional[Dict[str, Any]]]: + """ + Get VLM-specific classes for a given model type. + + Args: + model_type: The model type from config.json + config: Full model configuration + + Returns: + Tuple of (vision_tower_class, projector_class, vision_config) + Returns (None, None, None) if not a VLM model + """ + vision_config = config.get("vision_config") + if vision_config is None: + return None, None, None + + try: + # Default: load VisionModel from mlx_vlm.models.{model_type} + vision_module_path = f"mlx_vlm.models.{model_type}" + vision_module = importlib.import_module(vision_module_path) + vision_tower_class = getattr(vision_module, "VisionModel") + + # Check if this model needs a separate projector + projector_class = None + if model_type in VLM_SPECIAL_PROJECTOR_MAP: + proj_module_path, proj_class_name = VLM_SPECIAL_PROJECTOR_MAP[model_type] + proj_module = importlib.import_module(proj_module_path) + projector_class = getattr(proj_module, proj_class_name) + logger.info(f"Loaded VLM classes for {model_type}: VisionModel + {proj_class_name}") + else: + logger.info(f"Loaded VLM classes for {model_type}: VisionModel (projector integrated)") + + return vision_tower_class, projector_class, vision_config + + except (ImportError, AttributeError) as e: + logger.warning(f"Failed to load VLM classes for {model_type}: {e}") + return None, None, None + class MLXModelLoader: """ @@ -85,10 +142,13 @@ def register_block_class(self): if hasattr(entry_class, "get_architecture"): architecture = entry_class.get_architecture() self.block_class_map[architecture] = entry_class - # 
logger.info(f"Registered {architecture} -> {entry_class.__name__}") + logger.debug(f"Registered {architecture} -> {entry_class.__name__}") else: logger.warning(f"No architecture attribute found in {entry_class.__name__}") + except ImportError as e: + # Log more details for import errors (often missing dependencies) + logger.warning(f"Failed to import {model_file.name}: {e} (install required dependencies)") except Exception as e: logger.warning(f"Failed to load model from {model_file}: {e}") @@ -253,28 +313,44 @@ def load( if block_class is None: raise ValueError(f"block_class not found for architecture: {architecture}") - num_hidden_layers = config.get("num_hidden_layers", 0) - current_start_layer = self.start_layer if self.start_layer is not None else 0 - current_end_layer = self.end_layer if self.end_layer is not None else num_hidden_layers - # We need the model object to know its structure and which layers it owns. # This part mirrors the logic from the provided utils.py to get model_args. 
model_type = config.get("model_type") if not model_type: raise ValueError("model_type not found in config.json") - if model_type in MODEL_CLASS_MAP: - model_class = MODEL_CLASS_MAP[model_type] + # For VLM models, use text_config for ModelArgs and map to base model type + config_for_args = config + model_class_type = model_type + + if model_type in VLM_TEXT_CONFIG_MAP: + # VLM models have text_config containing the language model config + text_config = config.get("text_config", {}) + if text_config: + # Merge text_config into a flat config for ModelArgs + config_for_args = {**config, **text_config} + # Also get num_hidden_layers from text_config if not in root + if "num_hidden_layers" not in config and "num_hidden_layers" in text_config: + config["num_hidden_layers"] = text_config["num_hidden_layers"] + model_class_type = VLM_TEXT_CONFIG_MAP[model_type] + logger.info(f"VLM model {model_type} using {model_class_type} ModelArgs with text_config") + + num_hidden_layers = config.get("num_hidden_layers", 0) + current_start_layer = self.start_layer if self.start_layer is not None else 0 + current_end_layer = self.end_layer if self.end_layer is not None else num_hidden_layers + + if model_class_type in MODEL_CLASS_MAP: + model_class = MODEL_CLASS_MAP[model_class_type] else: - model_class = f"mlx_lm.models.{model_type}" + model_class = f"mlx_lm.models.{model_class_type}" try: arch_module = importlib.import_module(model_class) model_args_class = getattr(arch_module, "ModelArgs") - model_args = model_args_class.from_dict(config) + model_args = model_args_class.from_dict(config_for_args) except (ImportError, AttributeError) as e: - raise ValueError(f"Failed to load architecture for model_type '{model_type}'.") from e + raise ValueError(f"Failed to load architecture for model_type '{model_type}' (using {model_class}).") from e dtype = getattr(mx, config.get("torch_dtype", "bfloat16")) @@ -284,6 +360,23 @@ def load( model_id = model_id.split("/")[-1] else: # If it's already a 
clean name or a local path (take basename) model_id = pathlib.Path(model_id).name + + # Check for VLM model and get vision classes + vision_tower_class, projector_class, vision_config = _get_vlm_classes(model_type, config) + is_vlm = vision_config is not None and vision_tower_class is not None + + # Get VLM-specific config parameters + image_token_index = config.get("image_token_index") or config.get("image_token_id") + vision_feature_layer = config.get("vision_feature_layer", -2) + vision_feature_select_strategy = config.get("vision_feature_select_strategy", "default") + + if is_vlm: + logger.info( + f"Detected VLM model: {model_type}, " + f"image_token_index={image_token_index}, " + f"vision_feature_layer={vision_feature_layer}" + ) + model_shard = ShardedModel( config=model_args, model_id=model_id, @@ -291,6 +384,13 @@ def load( end_layer=current_end_layer, block_class=block_class, dtype=dtype, + # VLM parameters + vision_config=vision_config if is_vlm else None, + vision_tower_class=vision_tower_class, + multi_modal_projector_class=projector_class, + image_token_index=image_token_index, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, ) weight_files = glob.glob(str(model_path / "model*.safetensors")) @@ -322,6 +422,15 @@ def load( # loading only what we need. 
shard_weights = {} layer_key_prefix = "model.layers" # Common prefix + + # VLM weight prefixes to load on first shard + vlm_weight_prefixes = [ + "vision_tower.", + "vision_model.", + "visual.", # Some models use this prefix + "multi_modal_projector.", + "mm_projector.", # Alternative name + ] for file_idx, wf in enumerate(weight_files): logger.debug( @@ -361,6 +470,20 @@ def load( remapped_key = key.replace("model.", "", 1).replace( "embed_tokens", "lm_head" ) + + # VLM: Load vision tower and projector weights on first shard + if model_shard.is_first_shard and is_vlm: + for prefix in vlm_weight_prefixes: + if key.startswith(prefix): + is_needed = True + remapped_key = key + break + # Handle model.vision_tower.* style keys + if key.startswith(f"model.{prefix}"): + is_needed = True + remapped_key = key.replace("model.", "", 1) + break + if layer_key_prefix in key: try: parts = key.split(".") @@ -421,16 +544,25 @@ def class_predicate(p, m): model_shard.load_weights(list(shard_weights.items()), strict=strict) model_shard.shard_layers() + # Log VLM-specific weight loading info + if is_vlm and model_shard.is_first_shard: + vlm_weight_count = sum(1 for k in shard_weights.keys() + if any(k.startswith(p) for p in ["vision_tower", "vision_model", "visual", "multi_modal_projector", "mm_projector"])) + logger.info(f"Loaded {vlm_weight_count} VLM weights (vision_tower + projector)") + shard_weights.clear() mx.eval(model_shard.parameters()) # Synchronize processes to avoid timeout mx.eval(mx.distributed.all_sum(mx.array(1.0))) model_shard.eval() + + vlm_info = f", VLM={is_vlm}" if is_vlm else "" logger.info( - "Successfully loaded model shard (layers [%d-%d)), memory usage: %.3f GB", + "Successfully loaded model shard (layers [%d-%d)%s), memory usage: %.3f GB", current_start_layer, current_end_layer, + vlm_info, mx.get_active_memory() / 1024**3, ) return model_shard, config, tokenizer diff --git a/src/parallax/utils/vlm_utils.py b/src/parallax/utils/vlm_utils.py new file mode 
100644 index 00000000..0b2e9092 --- /dev/null +++ b/src/parallax/utils/vlm_utils.py @@ -0,0 +1,282 @@ +""" +VLM (Vision Language Model) utilities for MLX backend. + +Provides image loading and preprocessing functionality for multimodal models. +""" + +import base64 +from io import BytesIO +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + +# Lazy imports for optional dependencies +_PIL_AVAILABLE = None +_REQUESTS_AVAILABLE = None + + +def _check_pil(): + global _PIL_AVAILABLE + if _PIL_AVAILABLE is None: + try: + from PIL import Image + _PIL_AVAILABLE = True + except ImportError: + _PIL_AVAILABLE = False + return _PIL_AVAILABLE + + +def _check_requests(): + global _REQUESTS_AVAILABLE + if _REQUESTS_AVAILABLE is None: + try: + import requests + _REQUESTS_AVAILABLE = True + except ImportError: + _REQUESTS_AVAILABLE = False + return _REQUESTS_AVAILABLE + + +def load_image(source: Any) -> "Image.Image": + """ + Load an image from various sources. + + Supports: + - URL (http/https) + - Local file path + - Base64 encoded data URL + - PIL Image object (pass through) + - Dict with "url" key + + Args: + source: Image source (URL, path, base64, or PIL Image) + + Returns: + PIL Image in RGB format + + Raises: + ImportError: If PIL is not installed + ValueError: If source type is not supported + """ + if not _check_pil(): + raise ImportError("PIL (Pillow) is required for image loading. 
Install with: pip install Pillow") + + from PIL import Image + + # Handle dict with url key + if isinstance(source, dict): + source = source.get("url") + + # Pass through PIL Image + if isinstance(source, Image.Image): + return source.convert("RGB") + + if not isinstance(source, str): + raise ValueError(f"Unsupported image source type: {type(source)}") + + # Load from URL + if source.startswith("http://") or source.startswith("https://"): + if not _check_requests(): + raise ImportError("requests is required for URL loading. Install with: pip install requests") + import requests + response = requests.get(source, timeout=30) + response.raise_for_status() + return Image.open(BytesIO(response.content)).convert("RGB") + + # Load from base64 data URL + if source.startswith("data:image"): + # Format: data:image/png;base64, + header, encoded = source.split(",", 1) + image_data = base64.b64decode(encoded) + return Image.open(BytesIO(image_data)).convert("RGB") + + # Load from local file path + return Image.open(source).convert("RGB") + + +def process_images_for_vlm( + images: List[Any], + processor: Any, + text: str = "", + model_type: Optional[str] = None, +) -> Dict[str, Any]: + """ + Process images for VLM input using a HuggingFace processor. 
+ + Args: + images: List of image sources (URLs, paths, base64, or PIL Images) + processor: HuggingFace processor (e.g., Qwen2VLProcessor, LlavaProcessor) + text: Text prompt to process alongside images + model_type: Optional model type for model-specific processing + + Returns: + Dictionary containing: + - pixel_values: numpy array of processed images + - image_grid_thw: (optional) grid sizes for dynamic resolution models + - image_sizes: (optional) original image sizes + - input_ids: (optional) tokenized input with image tokens + """ + if not images: + return {} + + # Load all images + pil_images = [] + image_sizes = [] + for img_src in images: + try: + img = load_image(img_src) + pil_images.append(img) + image_sizes.append((img.height, img.width)) + except Exception as e: + logger.warning(f"Failed to load image: {e}") + continue + + if not pil_images: + logger.warning("No images were successfully loaded") + return {} + + # Process with HuggingFace processor + try: + inputs = processor( + text=text, + images=pil_images, + return_tensors="np", # Use numpy for MLX compatibility + ) + except Exception as e: + logger.error(f"Image processing failed: {e}") + return {} + + result = {"image_sizes": image_sizes} + + # Extract pixel_values + if hasattr(inputs, "pixel_values"): + pixel_values = inputs.pixel_values + if hasattr(pixel_values, "numpy"): + pixel_values = pixel_values.numpy() + result["pixel_values"] = pixel_values + elif "pixel_values" in inputs: + pixel_values = inputs["pixel_values"] + if hasattr(pixel_values, "numpy"): + pixel_values = pixel_values.numpy() + result["pixel_values"] = pixel_values + + # Extract image_grid_thw (for Qwen-VL style models) + if hasattr(inputs, "image_grid_thw"): + grid_thw = inputs.image_grid_thw + if hasattr(grid_thw, "numpy"): + grid_thw = grid_thw.numpy() + result["image_grid_thw"] = grid_thw + elif "image_grid_thw" in inputs: + grid_thw = inputs["image_grid_thw"] + if hasattr(grid_thw, "numpy"): + grid_thw = grid_thw.numpy() 
+ result["image_grid_thw"] = grid_thw + + # Extract input_ids if available (some processors expand image tokens) + if hasattr(inputs, "input_ids"): + input_ids = inputs.input_ids + if hasattr(input_ids, "numpy"): + input_ids = input_ids.numpy() + result["input_ids"] = input_ids.flatten().tolist() + elif "input_ids" in inputs: + input_ids = inputs["input_ids"] + if hasattr(input_ids, "numpy"): + input_ids = input_ids.numpy() + result["input_ids"] = input_ids.flatten().tolist() + + logger.debug( + f"Processed {len(pil_images)} images, " + f"pixel_values shape: {result.get('pixel_values', np.array([])).shape}" + ) + + return result + + +def create_vlm_inputs_from_request( + image_urls: List[str], + processor: Any, + text: str = "", + model_type: Optional[str] = None, +) -> Optional["VLMInputs"]: + """ + Create VLMInputs object from image URLs and processor. + + This is a convenience function that combines image processing + and VLMInputs creation. + + Args: + image_urls: List of image URLs or paths + processor: HuggingFace processor + text: Text prompt + model_type: Optional model type + + Returns: + VLMInputs object or None if processing fails + """ + from parallax.server.request import VLMInputs + + if not image_urls: + return None + + processed = process_images_for_vlm( + images=image_urls, + processor=processor, + text=text, + model_type=model_type, + ) + + if not processed or "pixel_values" not in processed: + return None + + return VLMInputs( + pixel_values=processed["pixel_values"], + image_grid_thw=processed.get("image_grid_thw"), + image_sizes=processed.get("image_sizes"), + images_processed=False, + ) + + +def get_image_token_count( + image_grid_thw: Optional[np.ndarray] = None, + image_size: Optional[Tuple[int, int]] = None, + patch_size: int = 14, + merge_size: int = 2, + temporal_patch_size: int = 2, +) -> int: + """ + Calculate the number of image tokens for a given image. 
+ + Different VLM models have different token counting strategies: + - LLaVA: Fixed number based on image size / patch_size + - Qwen-VL: Dynamic based on image_grid_thw (temporal * height * width) + + Args: + image_grid_thw: Grid sizes (temporal, height, width) for Qwen-VL style + image_size: Image size (height, width) for LLaVA style + patch_size: Size of each image patch + merge_size: Merge factor for Qwen-VL (reduces tokens by merge_size^2) + temporal_patch_size: Temporal patch size for video + + Returns: + Number of image tokens + """ + if image_grid_thw is not None: + # Qwen-VL style: t * h * w tokens + if isinstance(image_grid_thw, np.ndarray): + t, h, w = image_grid_thw.flatten()[:3] + else: + t, h, w = image_grid_thw + # After merge: (t * h * w) / (merge_size^2) + return int((t * h * w) // (merge_size ** 2)) + + if image_size is not None: + # LLaVA style: (H / patch_size) * (W / patch_size) + h, w = image_size + return (h // patch_size) * (w // patch_size) + + # Default fallback + return 576 # Common default for 336x336 image with 14x14 patches From 0a67dadd5b49135ef7216d2e6310f1a7ac533b69 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Tue, 3 Feb 2026 17:19:20 +0800 Subject: [PATCH 09/36] load success --- src/parallax/server/executor/mlx_executor.py | 24 ++++- src/parallax/server/model.py | 76 ++++++++------- src/parallax/server/shard_loader.py | 97 +++++++++++++------- src/parallax/utils/weight_filter_utils.py | 78 ++++++++++++++-- 4 files changed, 199 insertions(+), 76 deletions(-) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 3ce6d82b..9bff0cf7 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -32,6 +32,14 @@ logger = get_logger(__name__) +def _get_config_value(config: Dict[str, Any], key: str, default: Any = None) -> Any: + """Get config value, falling back to text_config for VLM models.""" + if key in config and config[key] is 
not None: + return config[key] + text_config = config.get("text_config", {}) + return text_config.get(key, default) + + class MLXExecutor(BaseExecutor): def __init__( self, @@ -122,7 +130,7 @@ def __init__( self.model_shard = self.shard_loader.load_lora(self.model_shard, adapters) logger.debug( - f"MLX sharded model loaded in {(time.time() - t0) * 1000:.1f} ms; num_layers={self.config.get('num_hidden_layers')}" + f"MLX sharded model loaded in {(time.time() - t0) * 1000:.1f} ms; num_layers={_get_config_value(self.config, 'num_hidden_layers')}" ) # Load VLM processor if this is a VLM model (first peer only) @@ -147,10 +155,16 @@ def __init__( ) # Calculate feature dimensions for kv cache - num_key_value_heads = self.config.get("num_key_value_heads") - head_dim = self.config.get("head_dim") or self.config.get("hidden_size") // self.config.get( - "num_attention_heads" - ) + # Use helper to handle VLM models where these are in text_config + num_key_value_heads = _get_config_value(self.config, "num_key_value_heads") + head_dim = _get_config_value(self.config, "head_dim") + if head_dim is None: + hidden_size = _get_config_value(self.config, "hidden_size") + num_attention_heads = _get_config_value(self.config, "num_attention_heads") + if hidden_size and num_attention_heads: + head_dim = hidden_size // num_attention_heads + else: + raise ValueError(f"Cannot determine head_dim: hidden_size={hidden_size}, num_attention_heads={num_attention_heads}") qk_nope_head_dim = self.config.get("qk_nope_head_dim", None) qk_rope_head_dim = self.config.get("qk_rope_head_dim", None) if qk_nope_head_dim is not None and qk_rope_head_dim is not None: diff --git a/src/parallax/server/model.py b/src/parallax/server/model.py index 4ca32b8b..c69e5d75 100644 --- a/src/parallax/server/model.py +++ b/src/parallax/server/model.py @@ -200,12 +200,14 @@ def get_input_embeddings( def _encode_images( self, pixel_values: mx.array, + image_grid_thw: Optional[mx.array] = None, **kwargs, ) -> mx.array: 
"""Encode images through vision tower and projector. Args: pixel_values: Image tensor, typically (batch, C, H, W) or (num_patches, C, H, W) + image_grid_thw: Grid size (T, H, W) for Qwen-VL models **kwargs: Additional model-specific arguments Returns: @@ -214,11 +216,9 @@ def _encode_images( if self.vision_tower is None: raise ValueError("Vision tower not initialized for this model") - # Convert to vision tower expected format (typically NHWC for MLX) - # pixel_values is usually in NCHW format from processor - if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]: - # NCHW -> NHWC - pixel_values = pixel_values.transpose(0, 2, 3, 1) + # Check if this is a Qwen-VL style model (needs grid_thw) + model_type = getattr(self.vision_config, 'model_type', '') if self.vision_config else '' + is_qwen_vl = 'qwen' in model_type.lower() and 'vl' in model_type.lower() # Ensure correct dtype if hasattr(self.vision_tower, 'patch_embed') and hasattr(self.vision_tower.patch_embed, 'proj'): @@ -228,37 +228,51 @@ def _encode_images( pixel_values = pixel_values.astype(self.dtype) # Get vision features from vision tower - # Different vision models have different output formats - vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True) - - # Handle different output formats - if isinstance(vision_outputs, tuple): - # CLIP/SigLIP style: (pooler_output, last_hidden_state, hidden_states) - if len(vision_outputs) >= 3: - hidden_states = vision_outputs[2] # All hidden states - if isinstance(self.vision_feature_layer, int): - selected_features = hidden_states[self.vision_feature_layer] - if self.vision_feature_select_strategy == "default": - # Remove CLS token - selected_features = selected_features[:, 1:] + if is_qwen_vl and image_grid_thw is not None: + # Qwen-VL style: VisionModel(pixel_values, grid_thw) -> (hidden_states, deepstack_features) + # No format conversion needed - Qwen-VL expects flat patches + vision_outputs = self.vision_tower(pixel_values, 
image_grid_thw) + if isinstance(vision_outputs, tuple): + # First element is the merged hidden states (already projected by merger) + selected_features = vision_outputs[0] + else: + selected_features = vision_outputs + else: + # Standard CLIP/SigLIP style + # Convert to vision tower expected format (typically NHWC for MLX) + if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]: + # NCHW -> NHWC + pixel_values = pixel_values.transpose(0, 2, 3, 1) + + vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True) + + # Handle different output formats + if isinstance(vision_outputs, tuple): + # CLIP/SigLIP style: (pooler_output, last_hidden_state, hidden_states) + if len(vision_outputs) >= 3: + hidden_states = vision_outputs[2] # All hidden states + if isinstance(self.vision_feature_layer, int): + selected_features = hidden_states[self.vision_feature_layer] + if self.vision_feature_select_strategy == "default": + # Remove CLS token + selected_features = selected_features[:, 1:] + else: + # Multiple layers + hs_pool = [hidden_states[idx] for idx in self.vision_feature_layer] + if self.vision_feature_select_strategy == "default": + hs_pool = [hs[:, 1:] for hs in hs_pool] + selected_features = mx.concatenate(hs_pool, axis=-1) else: - # Multiple layers - hs_pool = [hidden_states[idx] for idx in self.vision_feature_layer] + # Simple (pooler, hidden_state) output + selected_features = vision_outputs[1] if self.vision_feature_select_strategy == "default": - hs_pool = [hs[:, 1:] for hs in hs_pool] - selected_features = mx.concatenate(hs_pool, axis=-1) + selected_features = selected_features[:, 1:] else: - # Simple (pooler, hidden_state) output - selected_features = vision_outputs[1] - if self.vision_feature_select_strategy == "default": - selected_features = selected_features[:, 1:] - else: - # Direct hidden state output (Qwen-VL style) - # These models already output projected features - selected_features = vision_outputs + # Direct hidden state 
output + selected_features = vision_outputs # Project to language model dimension if projector exists - # Some VLMs (e.g., Qwen-VL) have projection built into VisionModel + # Qwen-VL models have projection built into VisionModel's merger if self.multi_modal_projector is not None: image_features = self.multi_modal_projector(selected_features) else: diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 8854ac7b..ebbc7f3f 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -45,6 +45,14 @@ } +def _get_config_value(config: Dict[str, Any], key: str, default: Any = None) -> Any: + """Get config value, falling back to text_config for VLM models.""" + if key in config: + return config[key] + text_config = config.get("text_config", {}) + return text_config.get(key, default) + + def _get_vlm_classes( model_type: str, config: Dict[str, Any] ) -> Tuple[Optional[Type[nn.Module]], Optional[Type[nn.Module]], Optional[Dict[str, Any]]]: @@ -413,6 +421,7 @@ def load( is_first_shard=model_shard.is_first_shard, is_last_shard=model_shard.is_last_shard, config=config, + is_vlm=is_vlm, ) if not weight_files and strict: @@ -421,16 +430,26 @@ def load( # Instead of loading all weights, we iterate through files and keys, # loading only what we need. 
shard_weights = {} - layer_key_prefix = "model.layers" # Common prefix + + # Layer key prefixes to check (in order of priority) + # VLM models like Qwen3-VL use model.language_model.layers.X + # Standard models use model.layers.X + layer_key_prefixes = [ + ("model.language_model.layers.", 3), # VLM style: parts[3] is layer index + ("model.layers.", 2), # Standard style: parts[2] is layer index + ] # VLM weight prefixes to load on first shard vlm_weight_prefixes = [ "vision_tower.", "vision_model.", - "visual.", # Some models use this prefix + "visual.", # Qwen-VL style "multi_modal_projector.", - "mm_projector.", # Alternative name + "mm_projector.", ] + + # Get tie_word_embeddings config (check both root and text_config for VLM) + tie_word_embeddings = _get_config_value(config, "tie_word_embeddings", False) for file_idx, wf in enumerate(weight_files): logger.debug( @@ -443,33 +462,38 @@ def load( remapped_key = None # Check if the key belongs to the shard and remap it - if ( - model_shard.is_first_shard - and "embed_tokens" in key - and key.startswith("model.") - ): + # Embeddings: model.embed_tokens or model.language_model.embed_tokens + if model_shard.is_first_shard and "embed_tokens" in key: is_needed = True - remapped_key = key.replace("model.", "", 1) - if model_shard.is_last_shard and config.get("tie_word_embeddings", False): - # Also add lm_head mapping for tied embeddings + # Remap to just embed_tokens.weight + if "language_model.embed_tokens" in key: + remapped_key = key.split("language_model.")[-1] + elif key.startswith("model."): + remapped_key = key.replace("model.", "", 1) + else: + remapped_key = key + if model_shard.is_last_shard and tie_word_embeddings: lm_head_key = remapped_key.replace("embed_tokens", "lm_head") shard_weights[lm_head_key] = f[key] + elif model_shard.is_last_shard: - if "model.norm" in key: - is_needed = True - remapped_key = key.replace("model.", "", 1) + # Final norm: model.norm or model.language_model.norm + if ".norm." 
in key or key.endswith(".norm.weight"): + if "language_model.norm" in key or (key.startswith("model.norm") and "layers" not in key): + is_needed = True + if "language_model.norm" in key: + remapped_key = key.split("language_model.")[-1] + else: + remapped_key = key.replace("model.", "", 1) if "lm_head" in key: is_needed = True remapped_key = key - elif ( - config.get("tie_word_embeddings", False) - and "embed_tokens" in key - and key.startswith("model.embed_tokens") - ): + elif tie_word_embeddings and "embed_tokens" in key: is_needed = True - remapped_key = key.replace("model.", "", 1).replace( - "embed_tokens", "lm_head" - ) + if "language_model.embed_tokens" in key: + remapped_key = key.split("language_model.")[-1].replace("embed_tokens", "lm_head") + else: + remapped_key = key.replace("model.", "", 1).replace("embed_tokens", "lm_head") # VLM: Load vision tower and projector weights on first shard if model_shard.is_first_shard and is_vlm: @@ -478,22 +502,29 @@ def load( is_needed = True remapped_key = key break - # Handle model.vision_tower.* style keys + # Handle model.vision_tower.*, model.visual.* style keys if key.startswith(f"model.{prefix}"): is_needed = True + # Keep as vision_tower.* or visual.* (remove model. 
prefix) remapped_key = key.replace("model.", "", 1) break - if layer_key_prefix in key: - try: - parts = key.split(".") - layer_idx = int(parts[2]) - if current_start_layer <= layer_idx < current_end_layer: - is_needed = True - local_layer_idx = layer_idx - current_start_layer - remapped_key = f"layers.{local_layer_idx}.{'.'.join(parts[3:])}" - except (ValueError, IndexError): - continue + # Check layer keys with multiple prefix patterns + if not is_needed: + for layer_prefix, layer_idx_pos in layer_key_prefixes: + if layer_prefix in key: + try: + parts = key.split(".") + layer_idx = int(parts[layer_idx_pos]) + if current_start_layer <= layer_idx < current_end_layer: + is_needed = True + local_layer_idx = layer_idx - current_start_layer + # Remap to layers.{local_idx}.{rest} + rest_parts = parts[layer_idx_pos + 1:] + remapped_key = f"layers.{local_layer_idx}.{'.'.join(rest_parts)}" + break + except (ValueError, IndexError): + continue # If the key is needed, load only that tensor from the file if is_needed: diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py index a8f21dbb..f46659e1 100644 --- a/src/parallax/utils/weight_filter_utils.py +++ b/src/parallax/utils/weight_filter_utils.py @@ -6,6 +6,20 @@ logger = logging.getLogger(__name__) +def _get_num_hidden_layers(config: Dict) -> int: + """Get num_hidden_layers from config, handling VLM models with text_config.""" + if "num_hidden_layers" in config: + return config["num_hidden_layers"] + # VLM models store this in text_config + text_config = config.get("text_config", {}) + return text_config.get("num_hidden_layers", 0) + + +def _is_vlm_model(config: Dict) -> bool: + """Check if config represents a VLM model.""" + return config.get("vision_config") is not None + + def should_include_weight_key( key: str, start_layer: int, @@ -13,22 +27,46 @@ def should_include_weight_key( is_first_shard: bool, is_last_shard: bool, tie_word_embeddings: bool = False, + is_vlm: bool = 
False, ) -> bool: - if is_first_shard and "embed_tokens" in key and key.startswith("model."): + # Embeddings on first shard + # Handles: model.embed_tokens, model.language_model.embed_tokens + if is_first_shard and "embed_tokens" in key: return True + + # VLM: Include vision components on first shard + # Handles various naming conventions: + # - vision_tower.*, model.vision_tower.* + # - model.visual.*, visual.* (Qwen-VL style) + # - multi_modal_projector.*, mm_projector.* + if is_first_shard and is_vlm: + vlm_prefixes = [ + "vision_tower", "vision_model", "visual", + "multi_modal_projector", "mm_projector", + ] + for prefix in vlm_prefixes: + if key.startswith(prefix) or key.startswith(f"model.{prefix}"): + return True + # Final norm and lm_head on last shard + # Handles: model.norm, model.language_model.norm, lm_head if is_last_shard: - if "model.norm" in key or "lm_head" in key: + if ".norm." in key or key.endswith(".norm.weight") or "lm_head" in key: return True - if tie_word_embeddings and "embed" in key and key.startswith("model.embed_tokens"): + if tie_word_embeddings and "embed_tokens" in key: return True + # Transformer layers - check layer index + # Handles: model.layers.X, model.language_model.layers.X if "layers." 
in key: parts = key.split(".") for i, part in enumerate(parts): if part == "layers" and i + 1 < len(parts): - layer_idx = int(parts[i + 1]) - return start_layer <= layer_idx < end_layer + try: + layer_idx = int(parts[i + 1]) + return start_layer <= layer_idx < end_layer + except ValueError: + continue return False @@ -41,6 +79,7 @@ def filter_weight_files_by_layer_range_for_load( is_first_shard: bool, is_last_shard: bool, config: Optional[Dict] = None, + is_vlm: bool = False, ) -> List[str]: index_file = model_path / "model.safetensors.index.json" @@ -59,12 +98,19 @@ def filter_weight_files_by_layer_range_for_load( tie_word_embeddings = False if config: tie_word_embeddings = config.get("tie_word_embeddings", False) + # Also check text_config for VLM models + if not tie_word_embeddings: + text_config = config.get("text_config", {}) + tie_word_embeddings = text_config.get("tie_word_embeddings", False) else: config_file = model_path / "config.json" if config_file.exists(): with open(config_file, "r") as f: cfg = json.load(f) tie_word_embeddings = cfg.get("tie_word_embeddings", False) + if not tie_word_embeddings: + text_config = cfg.get("text_config", {}) + tie_word_embeddings = text_config.get("tie_word_embeddings", False) needed_files: Set[str] = set() @@ -78,6 +124,7 @@ def filter_weight_files_by_layer_range_for_load( is_first_shard=is_first_shard, is_last_shard=is_last_shard, tie_word_embeddings=tie_word_embeddings, + is_vlm=is_vlm, ): needed_files.add(filename) @@ -97,6 +144,15 @@ def filter_weight_files_by_layer_range_for_load( f"Filtered weight files from {len(weight_files)} to {len(filtered_files)} " f"for layers [{start_layer}, {end_layer})" ) + + # If filtering resulted in no files but we had input files, + # fall back to original files (file naming may differ from index) + if not filtered_files and weight_files: + logger.debug( + f"Filtering resulted in no files, falling back to all {len(weight_files)} weight files. 
" + f"needed_files={needed_files}, input_files={[Path(wf).name for wf in weight_files]}" + ) + return weight_files return filtered_files @@ -110,16 +166,19 @@ def determine_needed_weight_files_for_download( is_first_shard = start_layer == 0 is_last_shard = False + is_vlm = False if config: - num_hidden_layers = config.get("num_hidden_layers", 0) + num_hidden_layers = _get_num_hidden_layers(config) is_last_shard = end_layer >= num_hidden_layers + is_vlm = _is_vlm_model(config) else: config_file = model_path / "config.json" if config_file.exists(): with open(config_file, "r") as f: cfg = json.load(f) - num_hidden_layers = cfg.get("num_hidden_layers", 0) + num_hidden_layers = _get_num_hidden_layers(cfg) is_last_shard = end_layer >= num_hidden_layers + is_vlm = _is_vlm_model(cfg) index_file = model_path / "model.safetensors.index.json" @@ -150,6 +209,10 @@ def determine_needed_weight_files_for_download( tie_word_embeddings = False if config: tie_word_embeddings = config.get("tie_word_embeddings", False) + # Also check text_config for VLM models + if not tie_word_embeddings: + text_config = config.get("text_config", {}) + tie_word_embeddings = text_config.get("tie_word_embeddings", False) needed_files: Set[str] = set() @@ -163,6 +226,7 @@ def determine_needed_weight_files_for_download( is_first_shard=is_first_shard, is_last_shard=is_last_shard, tie_word_embeddings=tie_word_embeddings, + is_vlm=is_vlm, ): needed_files.add(filename) From 8f7ce009b30843d4f1c0dc42a10cdc3766d256cc Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 4 Feb 2026 11:34:08 +0800 Subject: [PATCH 10/36] run success on mlx --- pyproject.toml | 1 + src/parallax/models/qwen3_vl.py | 9 +++ src/parallax/server/executor/base_executor.py | 39 ++++++------ src/parallax/server/executor/mlx_executor.py | 38 ++++++++++-- src/parallax/server/request.py | 6 +- src/parallax/server/shard_loader.py | 59 +++++++++++++++---- src/parallax/utils/weight_filter_utils.py | 12 ++-- 7 files changed, 123 insertions(+), 41 
deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 52bf2534..7faa53d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "dijkstar==2.6.0", "lattica==1.0.21", "orjson", + "torchvision==0.23.0" ] [project.scripts] diff --git a/src/parallax/models/qwen3_vl.py b/src/parallax/models/qwen3_vl.py index 0ee4ff87..75723129 100644 --- a/src/parallax/models/qwen3_vl.py +++ b/src/parallax/models/qwen3_vl.py @@ -78,6 +78,15 @@ def __call__( cos, sin = self.rotary_emb(values, pos_ids) queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) + # Ensure dtype consistency with cache (RoPE may output float32) + cache_dtype = key_cache_global.dtype + if keys.dtype != cache_dtype: + keys = keys.astype(cache_dtype) + if values.dtype != cache_dtype: + values = values.astype(cache_dtype) + if queries.dtype != cache_dtype: + queries = queries.astype(cache_dtype) + # Cache update with PagedAttention block_size = key_cache_global.shape[3] reshape_and_cache( diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 68b48201..adf16fd2 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -727,10 +727,12 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): ) # Use processor to handle text + images together + # Note: Qwen processors only support return_tensors="pt" + # We keep PyTorch tensors directly - mx.array() can convert them later processor_inputs = self.vlm_processor( text=text_prompt, images=images, - return_tensors="np", + return_tensors="pt", ) # Extract input_ids @@ -738,18 +740,16 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): if input_ids is None: raise ValueError("Processor did not return input_ids") - if hasattr(input_ids, "numpy"): - input_ids = input_ids.numpy() - prompt = input_ids.flatten().tolist() + # Convert to list for token IDs (need CPU 
for this) + if hasattr(input_ids, "tolist"): + prompt = input_ids.flatten().tolist() + else: + prompt = input_ids.flatten().tolist() - # Extract pixel_values and other vision inputs + # Extract pixel_values and image_grid_thw - keep as PyTorch tensors + # mx.array() can convert PyTorch tensors directly pixel_values = processor_inputs.get("pixel_values") - if pixel_values is not None and hasattr(pixel_values, "numpy"): - pixel_values = pixel_values.numpy() - image_grid_thw = processor_inputs.get("image_grid_thw") - if image_grid_thw is not None and hasattr(image_grid_thw, "numpy"): - image_grid_thw = image_grid_thw.numpy() # Create VLMInputs vlm_inputs = VLMInputs( @@ -775,27 +775,28 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): def _format_messages_for_vlm(self, messages: list) -> list: """ Format messages for VLM processing. - Converts image_url content parts to a format the processor understands. + Keep the original structure so chat template can handle image placeholders correctly. 
""" formatted = [] for msg in messages: content = msg.get("content") if isinstance(content, list): - # Convert content list to text with image placeholders - text_parts = [] + # Keep content as list format for chat template to process + # Chat templates like Qwen3-VL expect content list with image_url parts + new_content = [] for part in content: if isinstance(part, dict): if part.get("type") == "text": - text_parts.append(part.get("text", "")) + new_content.append({"type": "text", "text": part.get("text", "")}) elif part.get("type") == "image_url": - # Add image placeholder - processor will handle the actual insertion - # Different models use different placeholders - text_parts.append("") + # Mark as image for chat template + # Qwen3-VL chat template looks for 'image_url' or 'image' in content + new_content.append({"type": "image"}) elif isinstance(part, str): - text_parts.append(part) + new_content.append({"type": "text", "text": part}) formatted.append({ "role": msg.get("role", "user"), - "content": "".join(text_parts), + "content": new_content, }) else: formatted.append(msg) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 9bff0cf7..2db63676 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -137,15 +137,45 @@ def __init__( self.vlm_processor = None self.model_type = self.config.get("model_type") if hasattr(self.model_shard, 'is_vlm') and self.model_shard.is_vlm and start_layer == 0: + processor_path = self.shard_loader.model_path_str + logger.debug(f"Trying to load VLM processor from: {processor_path}") + + processor_loaded = False + + try: from transformers import AutoProcessor self.vlm_processor = AutoProcessor.from_pretrained( - model_repo, - trust_remote_code=True + processor_path, + trust_remote_code=True, ) - logger.info(f"Loaded VLM processor for {self.model_type}") + processor_type = type(self.vlm_processor).__name__ + # Verify it 
has image processing capability + if hasattr(self.vlm_processor, 'image_processor') and self.vlm_processor.image_processor is not None: + logger.info(f"Loaded VLM processor (AutoProcessor -> {processor_type}) for {self.model_type}") + processor_loaded = True + else: + logger.warning(f"AutoProcessor loaded {processor_type} but it doesn't have image_processor, skipping") + self.vlm_processor = None except Exception as e: - logger.warning(f"Failed to load VLM processor: {e}. VLM image processing will be disabled.") + import traceback + logger.debug(f"AutoProcessor failed: {e}") + if not processor_loaded: + try: + # Must import torch first to avoid flex_attention import errors in transformers + import torch + from transformers import Qwen2VLProcessor + self.vlm_processor = Qwen2VLProcessor.from_pretrained( + processor_path, + trust_remote_code=True + ) + logger.info(f"Loaded VLM processor (Qwen2VLProcessor) for {self.model_type}") + processor_loaded = True + except Exception as e: + logger.debug(f"Qwen2VLProcessor failed: {e}") + + if not processor_loaded: + logger.warning("VLM image processing will be disabled - no processor could be loaded.") # TODO: Duplicate code to BaseExecutor since num_shard_layers and dtype are needed for initializing kv cache self.num_shard_layers = end_layer - start_layer diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index 4735ee07..bd6b1e16 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -82,12 +82,14 @@ class VLMInputs: # Preprocessed image tensor, shape varies by model: # - LLaVA: (num_images, C, H, W) or (num_patches, C, patch_H, patch_W) # - Qwen-VL: (num_patches, C, patch_H, patch_W) with temporal dim for video - pixel_values: Optional[np.ndarray] = None + # Can be numpy array or PyTorch tensor - mx.array() can convert both + pixel_values: Optional[Any] = None # For models with dynamic resolution (e.g., Qwen2-VL): # Tuple of (temporal, height, width) grid sizes for each 
image # Shape: (num_images, 3) where each row is (t, h, w) - image_grid_thw: Optional[np.ndarray] = None + # Can be numpy array or PyTorch tensor + image_grid_thw: Optional[Any] = None # Number of image tokens per image (for variable-length image tokens) image_token_counts: Optional[List[int]] = None diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index ebbc7f3f..72f0235b 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -432,10 +432,13 @@ def load( shard_weights = {} # Layer key prefixes to check (in order of priority) - # VLM models like Qwen3-VL use model.language_model.layers.X - # Standard models use model.layers.X + # Different model formats use different key prefixes: + # - language_model.model.layers.X (mlx-vlm converted format) + # - model.language_model.layers.X (HuggingFace VLM format) + # - model.layers.X (Standard LLM format) layer_key_prefixes = [ - ("model.language_model.layers.", 3), # VLM style: parts[3] is layer index + ("language_model.model.layers.", 3), # mlx-vlm style: parts[3] is layer index + ("model.language_model.layers.", 3), # HF VLM style: parts[3] is layer index ("model.layers.", 2), # Standard style: parts[2] is layer index ] @@ -462,11 +465,17 @@ def load( remapped_key = None # Check if the key belongs to the shard and remap it - # Embeddings: model.embed_tokens or model.language_model.embed_tokens + # Embeddings: Various formats: + # - language_model.model.embed_tokens.* (mlx-vlm converted) + # - model.language_model.embed_tokens.* (HF VLM) + # - model.embed_tokens.* (standard) if model_shard.is_first_shard and "embed_tokens" in key: is_needed = True - # Remap to just embed_tokens.weight - if "language_model.embed_tokens" in key: + # Remap to just embed_tokens.* + if "language_model.model.embed_tokens" in key: + # mlx-vlm format: language_model.model.embed_tokens.weight -> embed_tokens.weight + remapped_key = key.replace("language_model.model.", "") + 
elif "language_model.embed_tokens" in key: remapped_key = key.split("language_model.")[-1] elif key.startswith("model."): remapped_key = key.replace("model.", "", 1) @@ -477,20 +486,36 @@ def load( shard_weights[lm_head_key] = f[key] elif model_shard.is_last_shard: - # Final norm: model.norm or model.language_model.norm + # Final norm: Various formats + # - language_model.model.norm.* (mlx-vlm converted) + # - model.language_model.norm.* (HF VLM) + # - model.norm.* (standard) if ".norm." in key or key.endswith(".norm.weight"): - if "language_model.norm" in key or (key.startswith("model.norm") and "layers" not in key): + is_final_norm = ( + "language_model.model.norm" in key or + "language_model.norm" in key or + (key.startswith("model.norm") and "layers" not in key) + ) + if is_final_norm: is_needed = True - if "language_model.norm" in key: + if "language_model.model.norm" in key: + remapped_key = key.replace("language_model.model.", "") + elif "language_model.norm" in key: remapped_key = key.split("language_model.")[-1] else: remapped_key = key.replace("model.", "", 1) if "lm_head" in key: is_needed = True - remapped_key = key + # Handle language_model.lm_head.* format + if key.startswith("language_model."): + remapped_key = key.replace("language_model.", "") + else: + remapped_key = key elif tie_word_embeddings and "embed_tokens" in key: is_needed = True - if "language_model.embed_tokens" in key: + if "language_model.model.embed_tokens" in key: + remapped_key = key.replace("language_model.model.", "").replace("embed_tokens", "lm_head") + elif "language_model.embed_tokens" in key: remapped_key = key.split("language_model.")[-1].replace("embed_tokens", "lm_head") else: remapped_key = key.replace("model.", "", 1).replace("embed_tokens", "lm_head") @@ -572,7 +597,15 @@ def class_predicate(p, m): class_predicate=class_predicate, ) - model_shard.load_weights(list(shard_weights.items()), strict=strict) + # Log weight keys before loading + logger.info(f"Loading 
{len(shard_weights)} weights. Sample keys: {list(shard_weights.keys())[:20]}") + + # Try strict mode first to catch any mismatch, then fall back to non-strict + try: + model_shard.load_weights(list(shard_weights.items()), strict=True) + except Exception as e: + logger.warning(f"Strict weight loading failed: {e}. Retrying with strict=False.") + model_shard.load_weights(list(shard_weights.items()), strict=False) model_shard.shard_layers() # Log VLM-specific weight loading info @@ -580,6 +613,8 @@ def class_predicate(p, m): vlm_weight_count = sum(1 for k in shard_weights.keys() if any(k.startswith(p) for p in ["vision_tower", "vision_model", "visual", "multi_modal_projector", "mm_projector"])) logger.info(f"Loaded {vlm_weight_count} VLM weights (vision_tower + projector)") + + logger.info(f"Total weights loaded: {len(shard_weights)}") shard_weights.clear() diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py index f46659e1..1e1d9edd 100644 --- a/src/parallax/utils/weight_filter_utils.py +++ b/src/parallax/utils/weight_filter_utils.py @@ -30,7 +30,7 @@ def should_include_weight_key( is_vlm: bool = False, ) -> bool: # Embeddings on first shard - # Handles: model.embed_tokens, model.language_model.embed_tokens + # Handles: model.embed_tokens, model.language_model.embed_tokens, language_model.model.embed_tokens if is_first_shard and "embed_tokens" in key: return True @@ -49,15 +49,19 @@ def should_include_weight_key( return True # Final norm and lm_head on last shard - # Handles: model.norm, model.language_model.norm, lm_head + # Handles: model.norm, model.language_model.norm, language_model.model.norm, lm_head if is_last_shard: - if ".norm." in key or key.endswith(".norm.weight") or "lm_head" in key: + # Check for final norm (not layer norms) + if ("lm_head" in key) or ( + (".norm." 
in key or key.endswith(".norm.weight")) and + "layers" not in key + ): return True if tie_word_embeddings and "embed_tokens" in key: return True # Transformer layers - check layer index - # Handles: model.layers.X, model.language_model.layers.X + # Handles: model.layers.X, model.language_model.layers.X, language_model.model.layers.X if "layers." in key: parts = key.split(".") for i, part in enumerate(parts): From 96fbf6a9110ebb791bb4b273caffd7a01e130029 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 4 Feb 2026 16:22:00 +0800 Subject: [PATCH 11/36] pre-commit --- pyproject.toml | 5 +- src/parallax/launch.py | 2 +- src/parallax/models/qwen3_vl.py | 51 ++++--- src/parallax/server/executor/base_executor.py | 64 +++++---- src/parallax/server/executor/mlx_executor.py | 49 ++++--- .../server/executor/sglang_executor.py | 6 +- src/parallax/server/model.py | 130 ++++++++++-------- src/parallax/server/request.py | 27 ++-- src/parallax/server/shard_loader.py | 102 +++++++++----- src/parallax/sglang/batch_info.py | 12 +- src/parallax/sglang/model_runner.py | 1 + src/parallax/sglang/multimodal_utils.py | 127 ++++++++--------- src/parallax/utils/vlm_utils.py | 83 +++++------ src/parallax/utils/weight_filter_utils.py | 14 +- 14 files changed, 361 insertions(+), 312 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7faa53d2..cd7e7f57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "dijkstar==2.6.0", "lattica==1.0.21", "orjson", - "torchvision==0.23.0" + ] [project.scripts] @@ -49,20 +49,19 @@ mac = [ "mlx-lm==0.30.5", "mlx==0.30.4", "mlx-vlm==0.3.10", + "torchvision==0.23.0" ] gpu = [ "sglang[all]==0.5.7", "mlx-lm==0.28.4", "mlx[cpu]==0.30.0", - "mlx-vlm==0.3.10", ] vllm = [ "vllm==0.14.0", "mlx-lm==0.28.4", "mlx[cpu]==0.30.0", - "mlx-vlm==0.3.10", ] benchmark = [ diff --git a/src/parallax/launch.py b/src/parallax/launch.py index a62892f3..6959ca1d 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -121,7 
+121,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr config = fetch_model_from_hf(args.model_path, local_files_only=args.use_hfcache) num_layers = config.get("num_hidden_layers") - + # If not found in top level, check text_config (common in multimodal models) if num_layers is None and "text_config" in config: text_config = config["text_config"] diff --git a/src/parallax/models/qwen3_vl.py b/src/parallax/models/qwen3_vl.py index 75723129..de294550 100644 --- a/src/parallax/models/qwen3_vl.py +++ b/src/parallax/models/qwen3_vl.py @@ -10,25 +10,21 @@ import mlx.core as mx from mlx import nn +# Import from mlx-vlm +from mlx_vlm.models.qwen3_vl.language import MLP +from mlx_vlm.models.qwen3_vl.language import Attention as MLXQwen3VLAttention +from mlx_vlm.models.qwen3_vl.language import apply_multimodal_rotary_pos_emb + from parallax.server.cache.base import BaseCache from parallax_extensions.ops import paged_attention_v1, reshape_and_cache from parallax_utils.logging_config import get_logger -# Import from mlx-vlm -from mlx_vlm.models.qwen3_vl.language import ( - Attention as MLXQwen3VLAttention, - MLP, - Qwen3VLDecoderLayer as MLXQwen3VLDecoderLayer, - Qwen3VLRotaryEmbedding, - apply_multimodal_rotary_pos_emb, -) - logger = get_logger(__name__) class ParallaxQwen3VLAttention(MLXQwen3VLAttention): """Qwen3VL Attention with PagedAttention support for Parallax.""" - + def __call__( self, x: mx.array, @@ -42,20 +38,18 @@ def __call__( **kwargs, ) -> mx.array: B, L, D = x.shape - + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - queries = self.q_norm( - queries.reshape(B, L, self.n_heads, self.head_dim) - ).transpose(0, 2, 1, 3) - keys = self.k_norm( - keys.reshape(B, L, self.n_kv_heads, self.head_dim) - ).transpose(0, 2, 1, 3) + + queries = self.q_norm(queries.reshape(B, L, self.n_heads, self.head_dim)).transpose( + 0, 2, 1, 3 + ) + keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, 
self.head_dim)).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, self.head_dim) - + # Get KV cache key_cache_global, value_cache_global = cache.get_cache() - + # Compute RoPE position if L == 1: # Decode phase: use context_lengths - 1 as offset @@ -74,10 +68,10 @@ def __call__( pos_ids = mx.arange(L)[None, :] pos_ids = mx.broadcast_to(pos_ids, (B, L)) pos_ids = mx.broadcast_to(pos_ids[None, :, :], (3, B, L)) - + cos, sin = self.rotary_emb(values, pos_ids) queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) - + # Ensure dtype consistency with cache (RoPE may output float32) cache_dtype = key_cache_global.dtype if keys.dtype != cache_dtype: @@ -86,7 +80,7 @@ def __call__( values = values.astype(cache_dtype) if queries.dtype != cache_dtype: queries = queries.astype(cache_dtype) - + # Cache update with PagedAttention block_size = key_cache_global.shape[3] reshape_and_cache( @@ -99,7 +93,7 @@ def __call__( block_size, slot_mapping=slot_mapping, ) - + # Compute attention if L == 1: # Decode: use PagedAttention @@ -117,6 +111,7 @@ def __call__( else: # Prefill: standard attention from mlx_lm.models.base import scaled_dot_product_attention + output = scaled_dot_product_attention( queries, keys, @@ -126,13 +121,13 @@ def __call__( cache=None, ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - + return self.o_proj(output) class ParallaxQwen3VLBlock(nn.Module): """Qwen3VL Transformer block with PagedAttention support.""" - + def __init__(self, args, layer_idx: int, local_layer_idx: int): super().__init__() self.hidden_size = args.hidden_size @@ -142,7 +137,7 @@ def __init__(self, args, layer_idx: int, local_layer_idx: int): self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.layer_idx = layer_idx self.local_layer_idx = local_layer_idx - + def __call__( self, x: mx.array, @@ -166,7 +161,7 @@ def __call__( r = self.mlp(self.post_attention_layernorm(h)) out = h + r return out - + @classmethod 
def get_architecture(cls): """Get the architecture name for the block.""" diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index adf16fd2..c7de2b8c 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -45,7 +45,6 @@ from parallax.server.scheduler import Scheduler from parallax.utils.shared_state import SharedState from parallax.utils.utils import get_current_device, get_device_dtype, get_zmq_socket -from parallax.utils.vlm_utils import create_vlm_inputs_from_request from parallax_utils.logging_config import get_logger logger = get_logger(__name__) @@ -53,6 +52,7 @@ class BaseExecutor: """High-level executor for managing model shards, scheduler, and cache pool on each Peer.""" + def __init__( self, # Model Configs @@ -126,7 +126,7 @@ def _config_get(key, default=None): num_hidden_layers = _config_get("n_layer") or _config_get("num_layers") if isinstance(text_config, dict): num_hidden_layers = text_config.get("num_hidden_layers") - + self.is_first_peer = start_layer == 0 self.is_last_peer = end_layer == num_hidden_layers self.tp_size = tp_size @@ -671,21 +671,21 @@ def _process_text_request(self, rid: str, messages: list, raw_request: Dict) -> else: prompt = convert_chat(messages, raw_request.get("role_mapping")) prompt = self.tokenizer.encode(prompt) - + return prompt - + def _process_vlm_request(self, rid: str, messages: list, image_urls: list): """ Process a VLM (multimodal) request using the VLM processor. - + The processor handles both text formatting and image preprocessing together, ensuring proper image token insertion and expansion. 
- + Returns: Tuple of (input_ids, VLMInputs) """ from parallax.utils.vlm_utils import load_image - + try: # Load images images = [] @@ -695,17 +695,19 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): images.append(img) except Exception as e: logger.warning(f"Failed to load image {url}: {e}") - + if not images: - logger.warning(f"No images loaded for VLM request {rid}, falling back to text processing") + logger.warning( + f"No images loaded for VLM request {rid}, falling back to text processing" + ) return self._process_text_request(rid, messages, {}), None - + # Format messages for processor # Most VLM processors expect a specific format with image placeholders formatted_messages = self._format_messages_for_vlm(messages) - + # Apply chat template to get the text prompt - if hasattr(self.vlm_processor, 'apply_chat_template'): + if hasattr(self.vlm_processor, "apply_chat_template"): # Some processors have their own chat template text_prompt = self.vlm_processor.apply_chat_template( formatted_messages, @@ -725,7 +727,7 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): f"{msg.get('role', 'user')}: {self._extract_text_from_content(msg.get('content', ''))}" for msg in formatted_messages ) - + # Use processor to handle text + images together # Note: Qwen processors only support return_tensors="pt" # We keep PyTorch tensors directly - mx.array() can convert them later @@ -734,23 +736,23 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): images=images, return_tensors="pt", ) - + # Extract input_ids input_ids = processor_inputs.get("input_ids") if input_ids is None: raise ValueError("Processor did not return input_ids") - + # Convert to list for token IDs (need CPU for this) if hasattr(input_ids, "tolist"): prompt = input_ids.flatten().tolist() else: prompt = input_ids.flatten().tolist() - + # Extract pixel_values and image_grid_thw - keep as PyTorch tensors # mx.array() can convert 
PyTorch tensors directly pixel_values = processor_inputs.get("pixel_values") image_grid_thw = processor_inputs.get("image_grid_thw") - + # Create VLMInputs vlm_inputs = VLMInputs( pixel_values=pixel_values, @@ -758,20 +760,20 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): image_sizes=[(img.height, img.width) for img in images], images_processed=False, ) - + logger.debug( f"VLM request {rid}: {len(images)} images, " f"input_ids length={len(prompt)}, " f"pixel_values shape={pixel_values.shape if pixel_values is not None else None}" ) - + return prompt, vlm_inputs - + except Exception as e: logger.error(f"Failed to process VLM request {rid}: {e}") # Fall back to text-only processing return self._process_text_request(rid, messages, {}), None - + def _format_messages_for_vlm(self, messages: list) -> list: """ Format messages for VLM processing. @@ -794,14 +796,16 @@ def _format_messages_for_vlm(self, messages: list) -> list: new_content.append({"type": "image"}) elif isinstance(part, str): new_content.append({"type": "text", "text": part}) - formatted.append({ - "role": msg.get("role", "user"), - "content": new_content, - }) + formatted.append( + { + "role": msg.get("role", "user"), + "content": new_content, + } + ) else: formatted.append(msg) return formatted - + def _extract_text_from_content(self, content) -> str: """Extract text from message content (handles both string and list formats).""" if isinstance(content, str): @@ -821,7 +825,7 @@ def _handle_raw_request(self, raw_request: Dict): rid = raw_request["rid"] messages = raw_request["messages"] - + # Extract image URLs first to determine if this is a multimodal request multimodal_params = None image_urls = [] @@ -839,11 +843,11 @@ def _handle_raw_request(self, raw_request: Dict): image_urls.append(image_url.get("url", image_url)) else: image_urls.append(image_url) - + # Process multimodal request with VLM processor vlm_inputs = None - has_vlm_processor = hasattr(self, 
'vlm_processor') and self.vlm_processor is not None - + has_vlm_processor = hasattr(self, "vlm_processor") and self.vlm_processor is not None + if image_urls and has_vlm_processor: # Use VLM processor to handle both text and images together prompt, vlm_inputs = self._process_vlm_request(rid, messages, image_urls) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index 2db63676..a66a3688 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -132,50 +132,57 @@ def __init__( logger.debug( f"MLX sharded model loaded in {(time.time() - t0) * 1000:.1f} ms; num_layers={_get_config_value(self.config, 'num_hidden_layers')}" ) - + # Load VLM processor if this is a VLM model (first peer only) self.vlm_processor = None self.model_type = self.config.get("model_type") - if hasattr(self.model_shard, 'is_vlm') and self.model_shard.is_vlm and start_layer == 0: + if hasattr(self.model_shard, "is_vlm") and self.model_shard.is_vlm and start_layer == 0: processor_path = self.shard_loader.model_path_str logger.debug(f"Trying to load VLM processor from: {processor_path}") - + processor_loaded = False - try: from transformers import AutoProcessor + self.vlm_processor = AutoProcessor.from_pretrained( - processor_path, + processor_path, trust_remote_code=True, ) processor_type = type(self.vlm_processor).__name__ # Verify it has image processing capability - if hasattr(self.vlm_processor, 'image_processor') and self.vlm_processor.image_processor is not None: - logger.info(f"Loaded VLM processor (AutoProcessor -> {processor_type}) for {self.model_type}") + if ( + hasattr(self.vlm_processor, "image_processor") + and self.vlm_processor.image_processor is not None + ): + logger.info( + f"Loaded VLM processor (AutoProcessor -> {processor_type}) for {self.model_type}" + ) processor_loaded = True else: - logger.warning(f"AutoProcessor loaded {processor_type} but it doesn't have 
image_processor, skipping") + logger.warning( + f"AutoProcessor loaded {processor_type} but it doesn't have image_processor, skipping" + ) self.vlm_processor = None except Exception as e: - import traceback logger.debug(f"AutoProcessor failed: {e}") if not processor_loaded: try: # Must import torch first to avoid flex_attention import errors in transformers - import torch from transformers import Qwen2VLProcessor + self.vlm_processor = Qwen2VLProcessor.from_pretrained( - processor_path, - trust_remote_code=True + processor_path, trust_remote_code=True ) logger.info(f"Loaded VLM processor (Qwen2VLProcessor) for {self.model_type}") processor_loaded = True except Exception as e: logger.debug(f"Qwen2VLProcessor failed: {e}") - + if not processor_loaded: - logger.warning("VLM image processing will be disabled - no processor could be loaded.") + logger.warning( + "VLM image processing will be disabled - no processor could be loaded." + ) # TODO: Duplicate code to BaseExecutor since num_shard_layers and dtype are needed for initializing kv cache self.num_shard_layers = end_layer - start_layer @@ -194,7 +201,9 @@ def __init__( if hidden_size and num_attention_heads: head_dim = hidden_size // num_attention_heads else: - raise ValueError(f"Cannot determine head_dim: hidden_size={hidden_size}, num_attention_heads={num_attention_heads}") + raise ValueError( + f"Cannot determine head_dim: hidden_size={hidden_size}, num_attention_heads={num_attention_heads}" + ) qk_nope_head_dim = self.config.get("qk_nope_head_dim", None) qk_rope_head_dim = self.config.get("qk_rope_head_dim", None) if qk_nope_head_dim is not None and qk_rope_head_dim is not None: @@ -719,13 +728,13 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A "prefix_lens": prefix_lens_tensor, # For RoPE offset calculation "actual_processed_lengths": actual_processed_lengths_tensor, # For correct logit selection } - + # VLM support: collect pixel_values and image metadata for first peer - 
if self.is_first_peer and hasattr(self.model_shard, 'is_vlm') and self.model_shard.is_vlm: + if self.is_first_peer and hasattr(self.model_shard, "is_vlm") and self.model_shard.is_vlm: pixel_values_list = [] image_grid_thw_list = [] has_vlm_inputs = False - + for req in batched_requests: if req.vlm_inputs is not None and req.vlm_inputs.has_images(): has_vlm_inputs = True @@ -734,7 +743,7 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A image_grid_thw_list.append(req.vlm_inputs.image_grid_thw) else: pixel_values_list.append(None) - + if has_vlm_inputs: # For now, we only support single-image batching where all requests # in the batch either have images or don't have images @@ -750,7 +759,7 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A [mx.array(g) for g in image_grid_thw_list], axis=0 ) logger.debug(f"VLM batch: pixel_values shape={ret['pixel_values'].shape}") - + logger.debug(f"Prepared MLX prefill batch (size={batch_size})") return ret diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index 6c94d372..e90c58fb 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -7,8 +7,8 @@ import torch from sglang.srt.environ import envs -from sglang.srt.managers.mm_utils import init_mm_embedding_cache from sglang.srt.lora.lora_registry import LoRARef +from sglang.srt.managers.mm_utils import init_mm_embedding_cache from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache @@ -142,8 +142,8 @@ def __init__( logger.debug( f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" ) - self.model_runner, self.config, self.tokenizer, self.processor = initialize_sgl_model_runner( - **model_runner_params + 
self.model_runner, self.config, self.tokenizer, self.processor = ( + initialize_sgl_model_runner(**model_runner_params) ) logger.debug( f"SGLang model runner initialized. num_layers={self.config.get('num_hidden_layers')}" diff --git a/src/parallax/server/model.py b/src/parallax/server/model.py index c69e5d75..aee512f3 100644 --- a/src/parallax/server/model.py +++ b/src/parallax/server/model.py @@ -3,10 +3,9 @@ """ from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Type import mlx.core as mx -import numpy as np from mlx import nn from mlx_lm.models.base import BaseModelArgs @@ -18,22 +17,22 @@ class VisionConfig: """Dynamic configuration for vision models in VLM. - + This class dynamically accepts all parameters from config.json's vision_config, making it compatible with different VLM architectures (Qwen-VL, LLaVA, etc.). """ - + def __init__(self, **kwargs): # Set all provided parameters as attributes for key, value in kwargs.items(): setattr(self, key, value) - + # Set common defaults if not provided - if not hasattr(self, 'model_type'): + if not hasattr(self, "model_type"): self.model_type = "clip_vision_model" - if not hasattr(self, 'hidden_size'): + if not hasattr(self, "hidden_size"): self.hidden_size = 1024 - + @classmethod def from_dict(cls, params: Dict[str, Any]) -> "VisionConfig": """Create VisionConfig from a dictionary, accepting all parameters.""" @@ -42,12 +41,13 @@ def from_dict(cls, params: Dict[str, Any]) -> "VisionConfig": return cls(**params) -@dataclass +@dataclass class InputEmbeddingsOutput: """Output from get_input_embeddings method.""" + inputs_embeds: mx.array attention_mask: Optional[mx.array] = None - + def to_dict(self) -> Dict[str, Any]: return { "inputs_embeds": self.inputs_embeds, @@ -61,7 +61,7 @@ class ShardedModel(nn.Module): Assumes self.layers are composed of modules (e.g., TransformerBlocks) that internally use 
ParallaxAttention and their __call__ method returns (hidden_state, new_k_for_layer, new_v_for_layer). - + Supports VLM (Vision Language Models) by optionally loading vision_tower and multi_modal_projector on the first shard. """ @@ -99,7 +99,7 @@ def __init__( self.is_first_shard = start_layer == 0 self.is_last_shard = end_layer == config.num_hidden_layers self.n_layers_in_shard = end_layer - start_layer - + # VLM configuration self.is_vlm = vision_config is not None and vision_tower_class is not None self.vision_config = VisionConfig.from_dict(vision_config) if vision_config else None @@ -111,10 +111,12 @@ def __init__( self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size) if has_norm_in: self.norm_in = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - + # Initialize vision components for VLM on first shard if self.is_vlm: - logger.info(f"Initializing VLM components: vision_tower ({self.vision_config.model_type})") + logger.info( + f"Initializing VLM components: vision_tower ({self.vision_config.model_type})" + ) self.vision_tower = vision_tower_class(self.vision_config) # Some VLMs (e.g., Qwen2-VL, Qwen3-VL) have the projector/merger built into VisionModel # In these cases, multi_modal_projector_class can be None @@ -122,7 +124,9 @@ def __init__( self.multi_modal_projector = multi_modal_projector_class(config) else: self.multi_modal_projector = None - logger.info("No separate projector class - projector is integrated into VisionModel") + logger.info( + "No separate projector class - projector is integrated into VisionModel" + ) else: self.vision_tower = None self.multi_modal_projector = None @@ -143,7 +147,7 @@ def __init__( else: self.norm = None self.lm_head = None - + def shard_layers(self): group = mx.distributed.init() tp_size = group.size() @@ -156,7 +160,7 @@ def shard_layers(self): f"Model {layer.__class__.__name__} does not have a shard method, does not support tensor parallelism" ) exit(1) - + def get_input_embeddings( self, 
input_ids: mx.array, @@ -164,39 +168,39 @@ def get_input_embeddings( **kwargs, ) -> InputEmbeddingsOutput: """Get input embeddings, optionally with vision features merged in. - + This method handles: 1. Text-only inputs: Simply embed tokens 2. VLM inputs: Embed tokens, encode images, and merge vision features - + Args: input_ids: (batch, seq_len) Token IDs pixel_values: (batch, C, H, W) or (num_patches, C, H, W) Image pixel values **kwargs: Additional arguments (e.g., image_grid_thw for Qwen2-VL) - + Returns: InputEmbeddingsOutput with merged embeddings """ if not self.is_first_shard: raise ValueError("get_input_embeddings should only be called on the first shard") - + # Get text embeddings inputs_embeds = self.embed_tokens(input_ids) - + # If no images or not a VLM, return text embeddings directly if pixel_values is None or not self.is_vlm: return InputEmbeddingsOutput(inputs_embeds=inputs_embeds) - + # Process vision features image_features = self._encode_images(pixel_values, **kwargs) - + # Merge image features with text embeddings final_embeds = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids ) - + return InputEmbeddingsOutput(inputs_embeds=final_embeds) - + def _encode_images( self, pixel_values: mx.array, @@ -204,29 +208,31 @@ def _encode_images( **kwargs, ) -> mx.array: """Encode images through vision tower and projector. 
- + Args: pixel_values: Image tensor, typically (batch, C, H, W) or (num_patches, C, H, W) image_grid_thw: Grid size (T, H, W) for Qwen-VL models **kwargs: Additional model-specific arguments - + Returns: Projected image features ready to be merged with text embeddings """ if self.vision_tower is None: raise ValueError("Vision tower not initialized for this model") - + # Check if this is a Qwen-VL style model (needs grid_thw) - model_type = getattr(self.vision_config, 'model_type', '') if self.vision_config else '' - is_qwen_vl = 'qwen' in model_type.lower() and 'vl' in model_type.lower() - + model_type = getattr(self.vision_config, "model_type", "") if self.vision_config else "" + is_qwen_vl = "qwen" in model_type.lower() and "vl" in model_type.lower() + # Ensure correct dtype - if hasattr(self.vision_tower, 'patch_embed') and hasattr(self.vision_tower.patch_embed, 'proj'): + if hasattr(self.vision_tower, "patch_embed") and hasattr( + self.vision_tower.patch_embed, "proj" + ): target_dtype = self.vision_tower.patch_embed.proj.weight.dtype pixel_values = pixel_values.astype(target_dtype) else: pixel_values = pixel_values.astype(self.dtype) - + # Get vision features from vision tower if is_qwen_vl and image_grid_thw is not None: # Qwen-VL style: VisionModel(pixel_values, grid_thw) -> (hidden_states, deepstack_features) @@ -243,9 +249,9 @@ def _encode_images( if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]: # NCHW -> NHWC pixel_values = pixel_values.transpose(0, 2, 3, 1) - + vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True) - + # Handle different output formats if isinstance(vision_outputs, tuple): # CLIP/SigLIP style: (pooler_output, last_hidden_state, hidden_states) @@ -270,7 +276,7 @@ def _encode_images( else: # Direct hidden state output selected_features = vision_outputs - + # Project to language model dimension if projector exists # Qwen-VL models have projection built into VisionModel's merger if 
self.multi_modal_projector is not None: @@ -278,9 +284,9 @@ def _encode_images( else: # VisionModel already outputs projected features image_features = selected_features - + return image_features - + def _merge_input_ids_with_image_features( self, image_features: mx.array, @@ -288,69 +294,71 @@ def _merge_input_ids_with_image_features( input_ids: mx.array, ) -> mx.array: """Merge image features into input embeddings at image token positions. - + This replaces placeholder tokens with actual image feature embeddings. - + Args: image_features: (num_images, num_patches, hidden_dim) or (total_patches, hidden_dim) inputs_embeds: (batch, seq_len, hidden_dim) Text embeddings input_ids: (batch, seq_len) Token IDs for finding image positions - + Returns: Merged embeddings with image features inserted at image token positions """ if self.image_token_index is None: logger.warning("image_token_index not set, cannot merge image features") return inputs_embeds - + batch_size, seq_len, hidden_dim = inputs_embeds.shape - + # Find positions of image tokens - image_positions = (input_ids == self.image_token_index) - + image_positions = input_ids == self.image_token_index + # Flatten image features if needed if image_features.ndim == 3: # (num_images, num_patches, dim) -> (total_patches, dim) image_features = image_features.reshape(-1, image_features.shape[-1]) - + # Cast image features to match embedding dtype image_features = image_features.astype(inputs_embeds.dtype) - + # Process each batch item batch_outputs = [] feature_start_idx = 0 - + for batch_idx in range(batch_size): batch_mask = image_positions[batch_idx] num_positions = int(mx.sum(batch_mask).item()) - + if num_positions > 0: # Extract features for this batch - batch_features = image_features[feature_start_idx:feature_start_idx + num_positions] - + batch_features = image_features[ + feature_start_idx : feature_start_idx + num_positions + ] + if batch_features.shape[0] != num_positions: raise ValueError( f"Number of image 
token positions ({num_positions}) does not match " f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}" ) - + # Create indices for gathering cumsum = mx.cumsum(batch_mask.astype(mx.int32)) feature_indices = mx.where(batch_mask, cumsum - 1, 0) - + # Gather features and create merged output gathered_features = batch_features[feature_indices] batch_mask_expanded = mx.expand_dims(batch_mask, axis=-1) batch_output = mx.where( batch_mask_expanded, gathered_features, inputs_embeds[batch_idx] ) - + feature_start_idx += num_positions else: batch_output = inputs_embeds[batch_idx] - + batch_outputs.append(batch_output) - + return mx.stack(batch_outputs, axis=0) def logits_to_tokens( @@ -405,7 +413,7 @@ def __call__( ) -> mx.array: """ Forward pass through the sharded model. - + Args: h_or_tokens: (batch, target_len_padded, D) or (batch, target_len_padded) for prefill, @@ -421,7 +429,7 @@ def __call__( inputs_embeds: (batch, seq_len, hidden_dim) Pre-computed embeddings. If provided, skips embedding and vision processing. **kwargs: Additional model-specific arguments (e.g., image_grid_thw). 
- + Returns: For last shard: logits (batch, seq_len, vocab_size) For other shards: hidden states (batch, seq_len, hidden_dim) @@ -436,7 +444,7 @@ def __call__( else: if self.embed_tokens is None: raise ValueError("embed_tokens is None for the first shard.") - + # Check if we need to process vision inputs if pixel_values is not None and self.is_vlm: # Use get_input_embeddings for VLM processing @@ -445,7 +453,7 @@ def __call__( else: # Standard text embedding h = self.embed_tokens(h) - + if self.has_norm_in and self.norm_in: h = self.norm_in(h) diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index bd6b1e16..2ecce2aa 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -59,11 +59,9 @@ """ import uuid -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Union - -import numpy as np +from typing import Any, Dict, List, Optional from parallax.server.sampling.sampling_params import SamplingParams from parallax_utils.logging_config import get_logger @@ -74,38 +72,39 @@ @dataclass class VLMInputs: """Container for Vision Language Model inputs. - + This is used to pass image/video data through the pipeline. Only the first peer needs to process pixel_values; subsequent peers receive pre-computed image embeddings merged into hidden_states. 
""" + # Preprocessed image tensor, shape varies by model: # - LLaVA: (num_images, C, H, W) or (num_patches, C, patch_H, patch_W) # - Qwen-VL: (num_patches, C, patch_H, patch_W) with temporal dim for video # Can be numpy array or PyTorch tensor - mx.array() can convert both pixel_values: Optional[Any] = None - + # For models with dynamic resolution (e.g., Qwen2-VL): # Tuple of (temporal, height, width) grid sizes for each image # Shape: (num_images, 3) where each row is (t, h, w) # Can be numpy array or PyTorch tensor image_grid_thw: Optional[Any] = None - + # Number of image tokens per image (for variable-length image tokens) image_token_counts: Optional[List[int]] = None - + # Original image sizes before preprocessing (height, width) # Useful for models that need aspect ratio information image_sizes: Optional[List[tuple]] = None - + # Whether images have been processed into embeddings # (set to True after first peer processes images) images_processed: bool = False - + def has_images(self) -> bool: """Check if this request contains image inputs.""" return self.pixel_values is not None and len(self.pixel_values) > 0 - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization.""" return { @@ -115,7 +114,7 @@ def to_dict(self) -> Dict[str, Any]: "image_sizes": self.image_sizes, "images_processed": self.images_processed, } - + @classmethod def from_dict(cls, data: Optional[Dict[str, Any]]) -> Optional["VLMInputs"]: """Create from dictionary.""" @@ -174,10 +173,10 @@ def __init__( self.lora_id: Optional[str] = None self.lora_path = lora_path self.multimodal_params = multimodal_params - + # VLM support: structured container for vision inputs self.vlm_inputs = vlm_inputs - + @property def has_images(self) -> bool: """Check if this request contains image inputs.""" diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 72f0235b..d0a1a696 100644 --- a/src/parallax/server/shard_loader.py +++ 
b/src/parallax/server/shard_loader.py @@ -58,11 +58,11 @@ def _get_vlm_classes( ) -> Tuple[Optional[Type[nn.Module]], Optional[Type[nn.Module]], Optional[Dict[str, Any]]]: """ Get VLM-specific classes for a given model type. - + Args: model_type: The model type from config.json config: Full model configuration - + Returns: Tuple of (vision_tower_class, projector_class, vision_config) Returns (None, None, None) if not a VLM model @@ -70,13 +70,13 @@ def _get_vlm_classes( vision_config = config.get("vision_config") if vision_config is None: return None, None, None - + try: # Default: load VisionModel from mlx_vlm.models.{model_type} vision_module_path = f"mlx_vlm.models.{model_type}" vision_module = importlib.import_module(vision_module_path) vision_tower_class = getattr(vision_module, "VisionModel") - + # Check if this model needs a separate projector projector_class = None if model_type in VLM_SPECIAL_PROJECTOR_MAP: @@ -86,9 +86,9 @@ def _get_vlm_classes( logger.info(f"Loaded VLM classes for {model_type}: VisionModel + {proj_class_name}") else: logger.info(f"Loaded VLM classes for {model_type}: VisionModel (projector integrated)") - + return vision_tower_class, projector_class, vision_config - + except (ImportError, AttributeError) as e: logger.warning(f"Failed to load VLM classes for {model_type}: {e}") return None, None, None @@ -156,7 +156,9 @@ def register_block_class(self): except ImportError as e: # Log more details for import errors (often missing dependencies) - logger.warning(f"Failed to import {model_file.name}: {e} (install required dependencies)") + logger.warning( + f"Failed to import {model_file.name}: {e} (install required dependencies)" + ) except Exception as e: logger.warning(f"Failed to load model from {model_file}: {e}") @@ -330,7 +332,7 @@ def load( # For VLM models, use text_config for ModelArgs and map to base model type config_for_args = config model_class_type = model_type - + if model_type in VLM_TEXT_CONFIG_MAP: # VLM models have 
text_config containing the language model config text_config = config.get("text_config", {}) @@ -341,8 +343,10 @@ def load( if "num_hidden_layers" not in config and "num_hidden_layers" in text_config: config["num_hidden_layers"] = text_config["num_hidden_layers"] model_class_type = VLM_TEXT_CONFIG_MAP[model_type] - logger.info(f"VLM model {model_type} using {model_class_type} ModelArgs with text_config") - + logger.info( + f"VLM model {model_type} using {model_class_type} ModelArgs with text_config" + ) + num_hidden_layers = config.get("num_hidden_layers", 0) current_start_layer = self.start_layer if self.start_layer is not None else 0 current_end_layer = self.end_layer if self.end_layer is not None else num_hidden_layers @@ -358,7 +362,9 @@ def load( model_args = model_args_class.from_dict(config_for_args) except (ImportError, AttributeError) as e: - raise ValueError(f"Failed to load architecture for model_type '{model_type}' (using {model_class}).") from e + raise ValueError( + f"Failed to load architecture for model_type '{model_type}' (using {model_class})." 
+ ) from e dtype = getattr(mx, config.get("torch_dtype", "bfloat16")) @@ -368,23 +374,23 @@ def load( model_id = model_id.split("/")[-1] else: # If it's already a clean name or a local path (take basename) model_id = pathlib.Path(model_id).name - + # Check for VLM model and get vision classes vision_tower_class, projector_class, vision_config = _get_vlm_classes(model_type, config) is_vlm = vision_config is not None and vision_tower_class is not None - + # Get VLM-specific config parameters image_token_index = config.get("image_token_index") or config.get("image_token_id") vision_feature_layer = config.get("vision_feature_layer", -2) vision_feature_select_strategy = config.get("vision_feature_select_strategy", "default") - + if is_vlm: logger.info( f"Detected VLM model: {model_type}, " f"image_token_index={image_token_index}, " f"vision_feature_layer={vision_feature_layer}" ) - + model_shard = ShardedModel( config=model_args, model_id=model_id, @@ -430,7 +436,7 @@ def load( # Instead of loading all weights, we iterate through files and keys, # loading only what we need. 
shard_weights = {} - + # Layer key prefixes to check (in order of priority) # Different model formats use different key prefixes: # - language_model.model.layers.X (mlx-vlm converted format) @@ -439,18 +445,18 @@ def load( layer_key_prefixes = [ ("language_model.model.layers.", 3), # mlx-vlm style: parts[3] is layer index ("model.language_model.layers.", 3), # HF VLM style: parts[3] is layer index - ("model.layers.", 2), # Standard style: parts[2] is layer index + ("model.layers.", 2), # Standard style: parts[2] is layer index ] - + # VLM weight prefixes to load on first shard vlm_weight_prefixes = [ "vision_tower.", "vision_model.", - "visual.", # Qwen-VL style + "visual.", # Qwen-VL style "multi_modal_projector.", "mm_projector.", ] - + # Get tie_word_embeddings config (check both root and text_config for VLM) tie_word_embeddings = _get_config_value(config, "tie_word_embeddings", False) @@ -484,7 +490,7 @@ def load( if model_shard.is_last_shard and tie_word_embeddings: lm_head_key = remapped_key.replace("embed_tokens", "lm_head") shard_weights[lm_head_key] = f[key] - + elif model_shard.is_last_shard: # Final norm: Various formats # - language_model.model.norm.* (mlx-vlm converted) @@ -492,9 +498,9 @@ def load( # - model.norm.* (standard) if ".norm." 
in key or key.endswith(".norm.weight"): is_final_norm = ( - "language_model.model.norm" in key or - "language_model.norm" in key or - (key.startswith("model.norm") and "layers" not in key) + "language_model.model.norm" in key + or "language_model.norm" in key + or (key.startswith("model.norm") and "layers" not in key) ) if is_final_norm: is_needed = True @@ -514,12 +520,18 @@ def load( elif tie_word_embeddings and "embed_tokens" in key: is_needed = True if "language_model.model.embed_tokens" in key: - remapped_key = key.replace("language_model.model.", "").replace("embed_tokens", "lm_head") + remapped_key = key.replace("language_model.model.", "").replace( + "embed_tokens", "lm_head" + ) elif "language_model.embed_tokens" in key: - remapped_key = key.split("language_model.")[-1].replace("embed_tokens", "lm_head") + remapped_key = key.split("language_model.")[-1].replace( + "embed_tokens", "lm_head" + ) else: - remapped_key = key.replace("model.", "", 1).replace("embed_tokens", "lm_head") - + remapped_key = key.replace("model.", "", 1).replace( + "embed_tokens", "lm_head" + ) + # VLM: Load vision tower and projector weights on first shard if model_shard.is_first_shard and is_vlm: for prefix in vlm_weight_prefixes: @@ -533,7 +545,7 @@ def load( # Keep as vision_tower.* or visual.* (remove model. 
prefix) remapped_key = key.replace("model.", "", 1) break - + # Check layer keys with multiple prefix patterns if not is_needed: for layer_prefix, layer_idx_pos in layer_key_prefixes: @@ -545,8 +557,10 @@ def load( is_needed = True local_layer_idx = layer_idx - current_start_layer # Remap to layers.{local_idx}.{rest} - rest_parts = parts[layer_idx_pos + 1:] - remapped_key = f"layers.{local_layer_idx}.{'.'.join(rest_parts)}" + rest_parts = parts[layer_idx_pos + 1 :] + remapped_key = ( + f"layers.{local_layer_idx}.{'.'.join(rest_parts)}" + ) break except (ValueError, IndexError): continue @@ -598,8 +612,10 @@ def class_predicate(p, m): ) # Log weight keys before loading - logger.info(f"Loading {len(shard_weights)} weights. Sample keys: {list(shard_weights.keys())[:20]}") - + logger.info( + f"Loading {len(shard_weights)} weights. Sample keys: {list(shard_weights.keys())[:20]}" + ) + # Try strict mode first to catch any mismatch, then fall back to non-strict try: model_shard.load_weights(list(shard_weights.items()), strict=True) @@ -610,10 +626,22 @@ def class_predicate(p, m): # Log VLM-specific weight loading info if is_vlm and model_shard.is_first_shard: - vlm_weight_count = sum(1 for k in shard_weights.keys() - if any(k.startswith(p) for p in ["vision_tower", "vision_model", "visual", "multi_modal_projector", "mm_projector"])) + vlm_weight_count = sum( + 1 + for k in shard_weights.keys() + if any( + k.startswith(p) + for p in [ + "vision_tower", + "vision_model", + "visual", + "multi_modal_projector", + "mm_projector", + ] + ) + ) logger.info(f"Loaded {vlm_weight_count} VLM weights (vision_tower + projector)") - + logger.info(f"Total weights loaded: {len(shard_weights)}") shard_weights.clear() @@ -622,7 +650,7 @@ def class_predicate(p, m): # Synchronize processes to avoid timeout mx.eval(mx.distributed.all_sum(mx.array(1.0))) model_shard.eval() - + vlm_info = f", VLM={is_vlm}" if is_vlm else "" logger.info( "Successfully loaded model shard (layers [%d-%d)%s), memory 
usage: %.3f GB", diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 1369180d..ba133fad 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -3,10 +3,10 @@ """ from types import SimpleNamespace -from typing import List, Optional, Any +from typing import Any, List, Optional import torch -from sglang.srt.managers.schedule_batch import Req, ScheduleBatch, MultimodalInputs +from sglang.srt.managers.schedule_batch import MultimodalInputs, Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import ( @@ -54,7 +54,7 @@ def transform_requests_to_sglang( ) -> List[Req]: reqs = [] mm_config = hf_config or {} - + for old_req in old_requests: sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) req = Req( @@ -64,12 +64,12 @@ def transform_requests_to_sglang( sampling_params=sampling_params, lora_id=old_req.lora_id, ) - + if old_req.multimodal_params is not None: try: if "mm_items" in old_req.multimodal_params: req.multimodal_inputs = MultimodalInputs.from_dict(old_req.multimodal_params) - + elif "images" in old_req.multimodal_params and processor is not None: image_urls = old_req.multimodal_params["images"] multimodal_inputs, padded_input_ids = process_multimodal_request( @@ -82,7 +82,7 @@ def transform_requests_to_sglang( if multimodal_inputs is not None: req.multimodal_inputs = multimodal_inputs req.origin_input_ids = padded_input_ids - + else: logger.warning( f"Assigning raw multimodal_params to req.multimodal_inputs. 
" diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index ee06ca7e..99c80d9f 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -307,6 +307,7 @@ def initialize_sgl_model_runner( processor = None try: from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) logger.info(f"Loaded processor for model {model_path}") except Exception as e: diff --git a/src/parallax/sglang/multimodal_utils.py b/src/parallax/sglang/multimodal_utils.py index 160e2d96..0840fe97 100644 --- a/src/parallax/sglang/multimodal_utils.py +++ b/src/parallax/sglang/multimodal_utils.py @@ -2,27 +2,31 @@ Multimodal utilities for SGLang backend. """ -from typing import Any, List, Optional, Tuple -from io import BytesIO import logging +from io import BytesIO +from typing import Any, List, Optional, Tuple import requests import torch from PIL import Image -from sglang.srt.managers.schedule_batch import MultimodalInputs, MultimodalDataItem, Modality from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) logger = logging.getLogger(__name__) def load_image(url: Any) -> Image.Image: import base64 - + if isinstance(url, dict): url = url.get("url") if not isinstance(url, str): raise ValueError(f"Unsupported image url type: {type(url)}") - + if url.startswith("http"): response = requests.get(url, timeout=10) response.raise_for_status() @@ -37,17 +41,17 @@ def load_image(url: Any) -> Image.Image: def process_images( - image_urls: List[Any], - processor: Any, + image_urls: List[Any], + processor: Any, input_text: str = "", mm_config: Optional[dict] = None, ) -> tuple[List[MultimodalDataItem], Optional[torch.Tensor], List[int]]: if not image_urls: return [], None, [] - + mm_items = [] images = [] - + for url in image_urls: try: image = 
load_image(url) @@ -55,42 +59,42 @@ def process_images( except Exception as e: logger.exception(f"Failed to load image {url}: {e}") continue - + if not images: return [], None, [] - + try: inputs = processor(text=input_text, images=images, return_tensors="pt") - + if inputs is None: logger.error("Processor returned None") return [], None, [] - + pixel_values = inputs.get("pixel_values") if pixel_values is None: logger.error("Processor output missing pixel_values") return [], None, [] - + expanded_input_ids = inputs.get("input_ids") if expanded_input_ids is not None: expanded_input_ids = expanded_input_ids.flatten().tolist() else: expanded_input_ids = [] - + logger.debug(f"Processor expanded input_ids length: {len(expanded_input_ids)}") - + image_grid_thw = inputs.get("image_grid_thw") image_sizes = inputs.get("image_sizes") - + model_specific_data = {} if image_grid_thw is not None: model_specific_data["image_grid_thw"] = image_grid_thw if image_sizes is not None: model_specific_data["image_sizes"] = image_sizes - + if image_grid_thw is not None and len(image_grid_thw) == len(images): num_images = len(images) - + patches_per_image = [] for i in range(num_images): grid = image_grid_thw[i] @@ -99,13 +103,13 @@ def process_images( else: num_patches = int(torch.prod(torch.tensor(grid)).item()) patches_per_image.append(num_patches) - + patch_start = 0 for i in range(num_images): num_patches = patches_per_image[i] - item_pixel_values = pixel_values[patch_start:patch_start + num_patches] - item_grid_thw = image_grid_thw[i:i+1] - + item_pixel_values = pixel_values[patch_start : patch_start + num_patches] + item_grid_thw = image_grid_thw[i : i + 1] + item = MultimodalDataItem( modality=Modality.IMAGE, feature=item_pixel_values, @@ -122,26 +126,26 @@ def process_images( model_specific_data=model_specific_data, ) mm_items.append(item) - + return mm_items, image_grid_thw, expanded_input_ids - + except Exception as e: logger.exception(f"Failed to process images: {e}") return 
[], None, [] def get_image_token_offsets( - input_ids: List[int], + input_ids: List[int], image_token_id: Optional[int], vision_start_id: Optional[int] = None, vision_end_id: Optional[int] = None, ) -> List[Tuple[int, int]]: offsets = [] - + if vision_start_id is not None and vision_end_id is not None: start_indices = [i for i, tok in enumerate(input_ids) if tok == vision_start_id] end_indices = [i for i, tok in enumerate(input_ids) if tok == vision_end_id] - + for start, end in zip(start_indices, end_indices): if start < end: offsets.append((start + 1, end - 1)) @@ -156,7 +160,7 @@ def get_image_token_offsets( start = None if start is not None: offsets.append((start, len(input_ids) - 1)) - + return offsets @@ -167,24 +171,25 @@ def compute_mrope_positions( ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: import logging + logger = logging.getLogger(__name__) - + seq_len = len(input_ids) - + def get_default_positions(): positions = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(3, 1) delta = torch.zeros(1, dtype=torch.long) return positions, delta - + if image_grid_thw is None: return get_default_positions() - + image_token_id = mm_config.get("image_token_id") vision_start_id = mm_config.get("vision_start_token_id") video_token_id = mm_config.get("video_token_id") model_type = mm_config.get("model_type") vision_config = mm_config.get("vision_config", {}) - + spatial_merge_size = ( vision_config.get("spatial_merge_size") if isinstance(vision_config, dict) else None ) @@ -198,12 +203,12 @@ def get_default_positions(): f"spatial_merge_size={spatial_merge_size}. Using default positions." 
) return get_default_positions() - + try: from sglang.srt.layers.rotary_embedding import MRotaryEmbedding - + input_ids_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) - + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( spatial_merge_size=spatial_merge_size, image_token_id=image_token_id, @@ -214,10 +219,10 @@ def get_default_positions(): image_grid_thw=image_grid_thw, tokens_per_second=tokens_per_second, ) - - mrope_positions = mrope_positions.squeeze(1) + + mrope_positions = mrope_positions.squeeze(1) return mrope_positions, mrope_position_delta - + except Exception as e: logger.warning(f"Failed to compute mrope_positions: {e}. Using default positions.") return get_default_positions() @@ -230,27 +235,26 @@ def prepare_sglang_multimodal_inputs( input_ids: Optional[List[int]] = None, ) -> MultimodalInputs: mm_config = mm_config or {} - + # Extract token IDs from config image_token_id = mm_config.get("image_token_id") vision_start_id = mm_config.get("vision_start_token_id") vision_end_id = mm_config.get("vision_end_token_id") video_token_id = mm_config.get("video_token_id") audio_token_id = mm_config.get("audio_token_id") - for item in mm_items: if item.pad_value is None: item.set_pad_value() - + mrope_positions = None mrope_position_delta = None - + if input_ids is not None: mrope_positions, mrope_position_delta = compute_mrope_positions( input_ids, image_grid_thw, mm_config ) - + return MultimodalInputs( mm_items=mm_items, im_token_id=image_token_id, @@ -277,44 +281,44 @@ def process_multimodal_request( logger.debug(f"Decoded input text (length={len(input_text)}): {input_text[:100]}...") except Exception as e: logger.warning(f"Failed to decode input_ids: {e}") - + mm_items, image_grid_thw, expanded_input_ids = process_images( - image_urls, - processor, + image_urls, + processor, input_text=input_text, mm_config=mm_config, ) - + if not mm_items: return None, list(input_ids) - + if expanded_input_ids and 
len(expanded_input_ids) > len(input_ids): logger.debug(f"Using expanded input_ids: {len(input_ids)} -> {len(expanded_input_ids)}") input_ids_for_offsets = expanded_input_ids else: input_ids_for_offsets = list(input_ids) - + image_token_id = mm_config.get("image_token_id") vision_start_id = mm_config.get("vision_start_token_id") vision_end_id = mm_config.get("vision_end_token_id") - + offsets = get_image_token_offsets( input_ids_for_offsets, image_token_id, vision_start_id, vision_end_id, ) - + if len(offsets) == len(mm_items): for item, offset in zip(mm_items, offsets): item.offsets = [offset] elif len(offsets) > 0: for item in mm_items: item.offsets = offsets - + for item in mm_items: item.set_pad_value() - + combined_grid_thw = None if image_grid_thw is not None: combined_grid_thw = image_grid_thw @@ -326,23 +330,20 @@ def process_multimodal_request( grids.append(grid) if grids: combined_grid_thw = torch.cat(grids, dim=0) - + multimodal_inputs = prepare_sglang_multimodal_inputs( mm_items=mm_items, image_grid_thw=combined_grid_thw, mm_config=mm_config, input_ids=input_ids_for_offsets, ) - + padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens() - padded_input_ids = padding_pattern.pad_input_tokens( - input_ids_for_offsets, - multimodal_inputs - ) - + padded_input_ids = padding_pattern.pad_input_tokens(input_ids_for_offsets, multimodal_inputs) + logger.debug( f"Successfully processed {len(mm_items)} images, " f"offsets={offsets}, padded_input_ids_len={len(padded_input_ids)}" ) - + return multimodal_inputs, padded_input_ids diff --git a/src/parallax/utils/vlm_utils.py b/src/parallax/utils/vlm_utils.py index 0b2e9092..87ca2bf2 100644 --- a/src/parallax/utils/vlm_utils.py +++ b/src/parallax/utils/vlm_utils.py @@ -6,7 +6,7 @@ import base64 from io import BytesIO -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -23,7 +23,6 @@ def _check_pil(): global _PIL_AVAILABLE if 
_PIL_AVAILABLE is None: try: - from PIL import Image _PIL_AVAILABLE = True except ImportError: _PIL_AVAILABLE = False @@ -34,7 +33,6 @@ def _check_requests(): global _REQUESTS_AVAILABLE if _REQUESTS_AVAILABLE is None: try: - import requests _REQUESTS_AVAILABLE = True except ImportError: _REQUESTS_AVAILABLE = False @@ -44,56 +42,61 @@ def _check_requests(): def load_image(source: Any) -> "Image.Image": """ Load an image from various sources. - + Supports: - URL (http/https) - Local file path - Base64 encoded data URL - PIL Image object (pass through) - Dict with "url" key - + Args: source: Image source (URL, path, base64, or PIL Image) - + Returns: PIL Image in RGB format - + Raises: ImportError: If PIL is not installed ValueError: If source type is not supported """ if not _check_pil(): - raise ImportError("PIL (Pillow) is required for image loading. Install with: pip install Pillow") - + raise ImportError( + "PIL (Pillow) is required for image loading. Install with: pip install Pillow" + ) + from PIL import Image - + # Handle dict with url key if isinstance(source, dict): source = source.get("url") - + # Pass through PIL Image if isinstance(source, Image.Image): return source.convert("RGB") - + if not isinstance(source, str): raise ValueError(f"Unsupported image source type: {type(source)}") - + # Load from URL if source.startswith("http://") or source.startswith("https://"): if not _check_requests(): - raise ImportError("requests is required for URL loading. Install with: pip install requests") + raise ImportError( + "requests is required for URL loading. 
Install with: pip install requests" + ) import requests + response = requests.get(source, timeout=30) response.raise_for_status() return Image.open(BytesIO(response.content)).convert("RGB") - + # Load from base64 data URL if source.startswith("data:image"): # Format: data:image/png;base64, header, encoded = source.split(",", 1) image_data = base64.b64decode(encoded) return Image.open(BytesIO(image_data)).convert("RGB") - + # Load from local file path return Image.open(source).convert("RGB") @@ -106,13 +109,13 @@ def process_images_for_vlm( ) -> Dict[str, Any]: """ Process images for VLM input using a HuggingFace processor. - + Args: images: List of image sources (URLs, paths, base64, or PIL Images) processor: HuggingFace processor (e.g., Qwen2VLProcessor, LlavaProcessor) text: Text prompt to process alongside images model_type: Optional model type for model-specific processing - + Returns: Dictionary containing: - pixel_values: numpy array of processed images @@ -122,7 +125,7 @@ def process_images_for_vlm( """ if not images: return {} - + # Load all images pil_images = [] image_sizes = [] @@ -134,11 +137,11 @@ def process_images_for_vlm( except Exception as e: logger.warning(f"Failed to load image: {e}") continue - + if not pil_images: logger.warning("No images were successfully loaded") return {} - + # Process with HuggingFace processor try: inputs = processor( @@ -149,9 +152,9 @@ def process_images_for_vlm( except Exception as e: logger.error(f"Image processing failed: {e}") return {} - + result = {"image_sizes": image_sizes} - + # Extract pixel_values if hasattr(inputs, "pixel_values"): pixel_values = inputs.pixel_values @@ -163,7 +166,7 @@ def process_images_for_vlm( if hasattr(pixel_values, "numpy"): pixel_values = pixel_values.numpy() result["pixel_values"] = pixel_values - + # Extract image_grid_thw (for Qwen-VL style models) if hasattr(inputs, "image_grid_thw"): grid_thw = inputs.image_grid_thw @@ -175,7 +178,7 @@ def process_images_for_vlm( if 
hasattr(grid_thw, "numpy"): grid_thw = grid_thw.numpy() result["image_grid_thw"] = grid_thw - + # Extract input_ids if available (some processors expand image tokens) if hasattr(inputs, "input_ids"): input_ids = inputs.input_ids @@ -187,12 +190,12 @@ def process_images_for_vlm( if hasattr(input_ids, "numpy"): input_ids = input_ids.numpy() result["input_ids"] = input_ids.flatten().tolist() - + logger.debug( f"Processed {len(pil_images)} images, " f"pixel_values shape: {result.get('pixel_values', np.array([])).shape}" ) - + return result @@ -204,34 +207,34 @@ def create_vlm_inputs_from_request( ) -> Optional["VLMInputs"]: """ Create VLMInputs object from image URLs and processor. - + This is a convenience function that combines image processing and VLMInputs creation. - + Args: image_urls: List of image URLs or paths processor: HuggingFace processor text: Text prompt model_type: Optional model type - + Returns: VLMInputs object or None if processing fails """ from parallax.server.request import VLMInputs - + if not image_urls: return None - + processed = process_images_for_vlm( images=image_urls, processor=processor, text=text, model_type=model_type, ) - + if not processed or "pixel_values" not in processed: return None - + return VLMInputs( pixel_values=processed["pixel_values"], image_grid_thw=processed.get("image_grid_thw"), @@ -249,18 +252,18 @@ def get_image_token_count( ) -> int: """ Calculate the number of image tokens for a given image. 
- + Different VLM models have different token counting strategies: - LLaVA: Fixed number based on image size / patch_size - Qwen-VL: Dynamic based on image_grid_thw (temporal * height * width) - + Args: image_grid_thw: Grid sizes (temporal, height, width) for Qwen-VL style image_size: Image size (height, width) for LLaVA style patch_size: Size of each image patch merge_size: Merge factor for Qwen-VL (reduces tokens by merge_size^2) temporal_patch_size: Temporal patch size for video - + Returns: Number of image tokens """ @@ -271,12 +274,12 @@ def get_image_token_count( else: t, h, w = image_grid_thw # After merge: (t * h * w) / (merge_size^2) - return int((t * h * w) // (merge_size ** 2)) - + return int((t * h * w) // (merge_size**2)) + if image_size is not None: # LLaVA style: (H / patch_size) * (W / patch_size) h, w = image_size return (h // patch_size) * (w // patch_size) - + # Default fallback return 576 # Common default for 336x336 image with 14x14 patches diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py index 1e1d9edd..bfd8a592 100644 --- a/src/parallax/utils/weight_filter_utils.py +++ b/src/parallax/utils/weight_filter_utils.py @@ -33,7 +33,7 @@ def should_include_weight_key( # Handles: model.embed_tokens, model.language_model.embed_tokens, language_model.model.embed_tokens if is_first_shard and "embed_tokens" in key: return True - + # VLM: Include vision components on first shard # Handles various naming conventions: # - vision_tower.*, model.vision_tower.* @@ -41,8 +41,11 @@ def should_include_weight_key( # - multi_modal_projector.*, mm_projector.* if is_first_shard and is_vlm: vlm_prefixes = [ - "vision_tower", "vision_model", "visual", - "multi_modal_projector", "mm_projector", + "vision_tower", + "vision_model", + "visual", + "multi_modal_projector", + "mm_projector", ] for prefix in vlm_prefixes: if key.startswith(prefix) or key.startswith(f"model.{prefix}"): @@ -53,8 +56,7 @@ def 
should_include_weight_key( if is_last_shard: # Check for final norm (not layer norms) if ("lm_head" in key) or ( - (".norm." in key or key.endswith(".norm.weight")) and - "layers" not in key + (".norm." in key or key.endswith(".norm.weight")) and "layers" not in key ): return True if tie_word_embeddings and "embed_tokens" in key: @@ -148,7 +150,7 @@ def filter_weight_files_by_layer_range_for_load( f"Filtered weight files from {len(weight_files)} to {len(filtered_files)} " f"for layers [{start_layer}, {end_layer})" ) - + # If filtering resulted in no files but we had input files, # fall back to original files (file naming may differ from index) if not filtered_files and weight_files: From 8b2ed0b6d43e74908c095a67eb63f8edf98e685f Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 4 Feb 2026 16:39:11 +0800 Subject: [PATCH 12/36] add config get utils --- src/parallax/server/executor/base_executor.py | 44 ++--- src/parallax/utils/config_utils.py | 168 ++++++++++++++++++ 2 files changed, 180 insertions(+), 32 deletions(-) create mode 100644 src/parallax/utils/config_utils.py diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index c7de2b8c..a78eda10 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -43,6 +43,7 @@ ) from parallax.server.sampling.sampling_params import SamplingParams from parallax.server.scheduler import Scheduler +from parallax.utils.config_utils import ModelConfigAccessor from parallax.utils.shared_state import SharedState from parallax.utils.utils import get_current_device, get_device_dtype, get_zmq_socket from parallax_utils.logging_config import get_logger @@ -113,19 +114,12 @@ def __init__( # Pipe communication self.conn = conn - def _config_get(key, default=None): - if isinstance(self.config, dict): - return self.config.get(key, default) - return getattr(self.config, key, default) + # Use VLM-aware config accessor for unified 
config access + self._config_accessor = ModelConfigAccessor(self.config) + self.is_vlm = self._config_accessor.is_vlm - num_hidden_layers = _config_get("num_hidden_layers") - text_config = _config_get("text_config") - if num_hidden_layers is None and isinstance(text_config, dict): - num_hidden_layers = text_config.get("num_hidden_layers") - if num_hidden_layers is None: - num_hidden_layers = _config_get("n_layer") or _config_get("num_layers") - if isinstance(text_config, dict): - num_hidden_layers = text_config.get("num_hidden_layers") + # Get num_hidden_layers using unified config accessor + num_hidden_layers = self._config_accessor.get_num_hidden_layers() self.is_first_peer = start_layer == 0 self.is_last_peer = end_layer == num_hidden_layers @@ -158,25 +152,10 @@ def _config_get(key, default=None): else: self.pad_token_id = self.tokenizer.pad_token_id - self.eos_token_id = _config_get("eos_token_id") - if self.eos_token_id is None and isinstance(text_config, dict): - self.eos_token_id = text_config.get("eos_token_id") - - vision_config = _config_get("vision_config", {}) - if not isinstance(vision_config, dict): - vision_config = { - "spatial_merge_size": getattr(vision_config, "spatial_merge_size", None), - "tokens_per_second": getattr(vision_config, "tokens_per_second", None), - } - self.mm_config = { - "model_type": _config_get("model_type"), - "image_token_id": _config_get("image_token_id"), - "vision_start_token_id": _config_get("vision_start_token_id"), - "vision_end_token_id": _config_get("vision_end_token_id"), - "video_token_id": _config_get("video_token_id"), - "audio_token_id": _config_get("audio_token_id"), - "vision_config": vision_config, - } + self.eos_token_id = self._config_accessor.get_eos_token_id() + + # Build multimodal config (only meaningful for VLM models) + self.mm_config = self._config_accessor.build_mm_config() # Scheduler: derive final max_batch_size with KV constraints # Remove this for now as it's not working on gpu devices @@ -240,7 
+219,8 @@ def _config_get(key, default=None): f"(layers [{self.start_layer}, {self.end_layer}), " f"tp_rank={self.tp_rank}/{self.tp_size}, " f"device={self.device}, " - f"num_shard_layers={self.num_shard_layers})" + f"num_shard_layers={self.num_shard_layers}, " + f"is_vlm={self.is_vlm})" ) @abstractmethod diff --git a/src/parallax/utils/config_utils.py b/src/parallax/utils/config_utils.py new file mode 100644 index 00000000..b13db907 --- /dev/null +++ b/src/parallax/utils/config_utils.py @@ -0,0 +1,168 @@ +""" +Configuration utilities for handling model configs. + +Provides VLM-aware config access for models with text_config/vision_config structure. +""" + +from typing import Any, List, Optional + +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +class ModelConfigAccessor: + """ + VLM-aware model configuration accessor. + + VLM (Vision-Language Models) typically have a nested config structure: + - text_config: Contains text model parameters (num_hidden_layers, eos_token_id, etc.) + - vision_config: Contains vision encoder parameters + + This class provides unified access to config values, automatically handling + the nested structure for VLM models. + """ + + def __init__(self, config: Any): + """ + Initialize the config accessor. 
+ + Args: + config: Model configuration (dict or object with attributes) + """ + self._config = config + self._is_vlm: Optional[bool] = None + + @property + def is_vlm(self) -> bool: + """Check if the model is a VLM (has both text_config and vision_config).""" + if self._is_vlm is None: + text_config = self._raw_get("text_config") + vision_config = self._raw_get("vision_config") + self._is_vlm = text_config is not None and vision_config is not None + return self._is_vlm + + @property + def text_config(self) -> Optional[Any]: + """Get the text_config if available.""" + return self._raw_get("text_config") + + @property + def vision_config(self) -> Optional[Any]: + """Get the vision_config if available.""" + return self._raw_get("vision_config") + + def _raw_get(self, key: str, default: Any = None) -> Any: + """ + Low-level config access without VLM-aware logic. + + Args: + key: Configuration key + default: Default value if key not found + + Returns: + Config value or default + """ + if isinstance(self._config, dict): + return self._config.get(key, default) + return getattr(self._config, key, default) + + def _get_from_subconfig(self, subconfig: Any, key: str) -> Optional[Any]: + """Get a value from a subconfig (dict or object).""" + if subconfig is None: + return None + if isinstance(subconfig, dict): + return subconfig.get(key) + return getattr(subconfig, key, None) + + def get( + self, + key: str, + default: Any = None, + fallback_keys: Optional[List[str]] = None, + prefer_text_config: bool = False, + ) -> Any: + """ + Get a configuration value with VLM-aware logic. + + For VLM models, text-related parameters (num_hidden_layers, eos_token_id, etc.) + are typically stored in 'text_config' rather than at the root level. + + Args: + key: The primary key to look up. + default: Default value if key is not found anywhere. + fallback_keys: Alternative keys to try if the primary key is not found. 
+ prefer_text_config: If True and this is a VLM model, look in text_config first. + + Returns: + The configuration value or default. + """ + # For VLM models, prefer text_config for text-related parameters + if prefer_text_config and self.is_vlm: + value = self._get_from_subconfig(self.text_config, key) + if value is not None: + return value + + # Try primary key at root level + value = self._raw_get(key) + if value is not None: + return value + + # Try fallback keys at root level + if fallback_keys: + for fallback_key in fallback_keys: + value = self._raw_get(fallback_key) + if value is not None: + return value + + # For non-VLM models, also check text_config as fallback + # (some models might have this structure without being full VLMs) + if not self.is_vlm and prefer_text_config: + value = self._get_from_subconfig(self.text_config, key) + if value is not None: + return value + + return default + + def get_num_hidden_layers(self) -> Optional[int]: + """Get the number of hidden layers (common parameter).""" + return self.get( + "num_hidden_layers", + fallback_keys=["n_layer", "num_layers"], + prefer_text_config=True, + ) + + def get_eos_token_id(self) -> Optional[int]: + """Get the EOS token ID.""" + return self.get("eos_token_id", prefer_text_config=True) + + def build_mm_config(self) -> dict: + """ + Build the multimodal configuration dictionary. + + Returns: + Dictionary containing multimodal-related config values. 
+ """ + vision_config_raw = self._raw_get("vision_config", {}) + + # Normalize vision_config to dict format + if vision_config_raw is None: + vision_config = {} + elif isinstance(vision_config_raw, dict): + vision_config = vision_config_raw + else: + # Convert object-style config to dict + vision_config = { + "spatial_merge_size": getattr(vision_config_raw, "spatial_merge_size", None), + "tokens_per_second": getattr(vision_config_raw, "tokens_per_second", None), + } + + return { + "model_type": self._raw_get("model_type"), + "image_token_id": self._raw_get("image_token_id"), + "vision_start_token_id": self._raw_get("vision_start_token_id"), + "vision_end_token_id": self._raw_get("vision_end_token_id"), + "video_token_id": self._raw_get("video_token_id"), + "audio_token_id": self._raw_get("audio_token_id"), + "vision_config": vision_config, + } From 4a68fd30d3fcd475da22cfe3d4739636c741d9f2 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 4 Feb 2026 16:49:20 +0800 Subject: [PATCH 13/36] refactor baseexecutor --- src/parallax/server/executor/base_executor.py | 39 ++----------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index a78eda10..4a15e81b 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -655,19 +655,10 @@ def _process_text_request(self, rid: str, messages: list, raw_request: Dict) -> return prompt def _process_vlm_request(self, rid: str, messages: list, image_urls: list): - """ - Process a VLM (multimodal) request using the VLM processor. - - The processor handles both text formatting and image preprocessing together, - ensuring proper image token insertion and expansion. 
- - Returns: - Tuple of (input_ids, VLMInputs) - """ + """Process a VLM (multimodal) request using the VLM processor.""" from parallax.utils.vlm_utils import load_image try: - # Load images images = [] for url in image_urls: try: @@ -681,59 +672,39 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): f"No images loaded for VLM request {rid}, falling back to text processing" ) return self._process_text_request(rid, messages, {}), None - - # Format messages for processor - # Most VLM processors expect a specific format with image placeholders formatted_messages = self._format_messages_for_vlm(messages) - # Apply chat template to get the text prompt if hasattr(self.vlm_processor, "apply_chat_template"): - # Some processors have their own chat template text_prompt = self.vlm_processor.apply_chat_template( formatted_messages, tokenize=False, add_generation_prompt=True, ) elif self.tokenizer.chat_template: - # Fall back to tokenizer's chat template text_prompt = self.tokenizer.apply_chat_template( formatted_messages, tokenize=False, add_generation_prompt=True, ) else: - # Simple fallback text_prompt = "\n".join( f"{msg.get('role', 'user')}: {self._extract_text_from_content(msg.get('content', ''))}" for msg in formatted_messages ) - # Use processor to handle text + images together - # Note: Qwen processors only support return_tensors="pt" - # We keep PyTorch tensors directly - mx.array() can convert them later processor_inputs = self.vlm_processor( text=text_prompt, images=images, return_tensors="pt", ) - - # Extract input_ids input_ids = processor_inputs.get("input_ids") if input_ids is None: raise ValueError("Processor did not return input_ids") + prompt = input_ids.flatten().tolist() - # Convert to list for token IDs (need CPU for this) - if hasattr(input_ids, "tolist"): - prompt = input_ids.flatten().tolist() - else: - prompt = input_ids.flatten().tolist() - - # Extract pixel_values and image_grid_thw - keep as PyTorch tensors - # mx.array() 
can convert PyTorch tensors directly pixel_values = processor_inputs.get("pixel_values") image_grid_thw = processor_inputs.get("image_grid_thw") - # Create VLMInputs vlm_inputs = VLMInputs( pixel_values=pixel_values, image_grid_thw=image_grid_thw, @@ -751,7 +722,6 @@ def _process_vlm_request(self, rid: str, messages: list, image_urls: list): except Exception as e: logger.error(f"Failed to process VLM request {rid}: {e}") - # Fall back to text-only processing return self._process_text_request(rid, messages, {}), None def _format_messages_for_vlm(self, messages: list) -> list: @@ -763,16 +733,13 @@ def _format_messages_for_vlm(self, messages: list) -> list: for msg in messages: content = msg.get("content") if isinstance(content, list): - # Keep content as list format for chat template to process - # Chat templates like Qwen3-VL expect content list with image_url parts new_content = [] for part in content: if isinstance(part, dict): if part.get("type") == "text": new_content.append({"type": "text", "text": part.get("text", "")}) elif part.get("type") == "image_url": - # Mark as image for chat template - # Qwen3-VL chat template looks for 'image_url' or 'image' in content + new_content.append({"type": "image"}) elif isinstance(part, str): new_content.append({"type": "text", "text": part}) From c5f43619b25fa020ef097f028e28efb434ba400f Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 5 Feb 2026 11:23:19 +0800 Subject: [PATCH 14/36] update --- src/parallax/launch.py | 9 +- src/parallax/server/executor/base_executor.py | 137 +++------------- src/parallax/server/executor/mlx_executor.py | 152 ++++++++++++++++-- src/parallax/server/shard_loader.py | 11 +- src/parallax/utils/config_utils.py | 65 ++++++++ src/parallax/utils/weight_filter_utils.py | 47 ++---- 6 files changed, 238 insertions(+), 183 deletions(-) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index 6959ca1d..9c52cc4f 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -26,6 +26,7 
@@ from parallax.server.executor.factory import run_executor_process, stop_executor_process from parallax.server.http_server import launch_http_server, stop_http_server from parallax.server.server_args import parse_args +from parallax.utils.config_utils import get_config_value from parallax.utils.shared_state import SharedState from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port from parallax_utils.ascii_anime import display_parallax_join @@ -120,13 +121,7 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr check_latest_release() config = fetch_model_from_hf(args.model_path, local_files_only=args.use_hfcache) - num_layers = config.get("num_hidden_layers") - - # If not found in top level, check text_config (common in multimodal models) - if num_layers is None and "text_config" in config: - text_config = config["text_config"] - if isinstance(text_config, dict): - num_layers = text_config.get("num_hidden_layers") + num_layers = get_config_value(config, "num_hidden_layers") if args.start_layer is None: args.start_layer = 0 diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 4a15e81b..afa73af9 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -654,118 +654,27 @@ def _process_text_request(self, rid: str, messages: list, raw_request: Dict) -> return prompt - def _process_vlm_request(self, rid: str, messages: list, image_urls: list): - """Process a VLM (multimodal) request using the VLM processor.""" - from parallax.utils.vlm_utils import load_image - - try: - images = [] - for url in image_urls: - try: - img = load_image(url) - images.append(img) - except Exception as e: - logger.warning(f"Failed to load image {url}: {e}") - - if not images: - logger.warning( - f"No images loaded for VLM request {rid}, falling back to text processing" - ) - return self._process_text_request(rid, messages, {}), 
None - formatted_messages = self._format_messages_for_vlm(messages) - - if hasattr(self.vlm_processor, "apply_chat_template"): - text_prompt = self.vlm_processor.apply_chat_template( - formatted_messages, - tokenize=False, - add_generation_prompt=True, - ) - elif self.tokenizer.chat_template: - text_prompt = self.tokenizer.apply_chat_template( - formatted_messages, - tokenize=False, - add_generation_prompt=True, - ) - else: - text_prompt = "\n".join( - f"{msg.get('role', 'user')}: {self._extract_text_from_content(msg.get('content', ''))}" - for msg in formatted_messages - ) - - processor_inputs = self.vlm_processor( - text=text_prompt, - images=images, - return_tensors="pt", - ) - input_ids = processor_inputs.get("input_ids") - if input_ids is None: - raise ValueError("Processor did not return input_ids") - prompt = input_ids.flatten().tolist() - - pixel_values = processor_inputs.get("pixel_values") - image_grid_thw = processor_inputs.get("image_grid_thw") - - vlm_inputs = VLMInputs( - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - image_sizes=[(img.height, img.width) for img in images], - images_processed=False, - ) - - logger.debug( - f"VLM request {rid}: {len(images)} images, " - f"input_ids length={len(prompt)}, " - f"pixel_values shape={pixel_values.shape if pixel_values is not None else None}" - ) + def _process_request_prompt( + self, rid: str, messages: list, image_urls: list, raw_request: Dict + ) -> Tuple[list, Optional[VLMInputs]]: + """ + Process request messages and return (input_ids, vlm_inputs). - return prompt, vlm_inputs + Subclasses can override this method to implement custom VLM processing. + Default implementation only handles text requests. 
- except Exception as e: - logger.error(f"Failed to process VLM request {rid}: {e}") - return self._process_text_request(rid, messages, {}), None + Args: + rid: Request ID + messages: List of message dicts + image_urls: List of image URLs extracted from messages + raw_request: Original raw request dict - def _format_messages_for_vlm(self, messages: list) -> list: - """ - Format messages for VLM processing. - Keep the original structure so chat template can handle image placeholders correctly. + Returns: + Tuple of (input_ids, vlm_inputs). vlm_inputs is None for text-only requests. """ - formatted = [] - for msg in messages: - content = msg.get("content") - if isinstance(content, list): - new_content = [] - for part in content: - if isinstance(part, dict): - if part.get("type") == "text": - new_content.append({"type": "text", "text": part.get("text", "")}) - elif part.get("type") == "image_url": - - new_content.append({"type": "image"}) - elif isinstance(part, str): - new_content.append({"type": "text", "text": part}) - formatted.append( - { - "role": msg.get("role", "user"), - "content": new_content, - } - ) - else: - formatted.append(msg) - return formatted - - def _extract_text_from_content(self, content) -> str: - """Extract text from message content (handles both string and list formats).""" - if isinstance(content, str): - return content - if isinstance(content, list): - texts = [] - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - texts.append(part.get("text", "")) - elif isinstance(part, str): - texts.append(part) - return " ".join(texts) - return str(content) + # Default: text-only processing, subclasses override for VLM support + prompt = self._process_text_request(rid, messages, raw_request) + return prompt, None def _handle_raw_request(self, raw_request: Dict): assert "messages" in raw_request, "Request did not contain messages" @@ -791,16 +700,8 @@ def _handle_raw_request(self, raw_request: Dict): else: 
image_urls.append(image_url) - # Process multimodal request with VLM processor - vlm_inputs = None - has_vlm_processor = hasattr(self, "vlm_processor") and self.vlm_processor is not None - - if image_urls and has_vlm_processor: - # Use VLM processor to handle both text and images together - prompt, vlm_inputs = self._process_vlm_request(rid, messages, image_urls) - else: - # Standard text-only processing - prompt = self._process_text_request(rid, messages, raw_request) + # Process request prompt (subclasses can override for VLM support) + prompt, vlm_inputs = self._process_request_prompt(rid, messages, image_urls, raw_request) max_seq_len = self.max_sequence_length if self.max_sequence_length is not None else 4096 max_seq_len = max(max_seq_len, 4096) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index a66a3688..fc19a7dc 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -20,6 +20,7 @@ ) from parallax.server.sampling.sampler import SamplingBatchInfo from parallax.server.shard_loader import MLXModelLoader +from parallax.utils.config_utils import get_config_value from parallax.utils.utils import ( combine_padding_and_causal_masks, create_causal_mask, @@ -32,14 +33,6 @@ logger = get_logger(__name__) -def _get_config_value(config: Dict[str, Any], key: str, default: Any = None) -> Any: - """Get config value, falling back to text_config for VLM models.""" - if key in config and config[key] is not None: - return config[key] - text_config = config.get("text_config", {}) - return text_config.get(key, default) - - class MLXExecutor(BaseExecutor): def __init__( self, @@ -130,7 +123,7 @@ def __init__( self.model_shard = self.shard_loader.load_lora(self.model_shard, adapters) logger.debug( - f"MLX sharded model loaded in {(time.time() - t0) * 1000:.1f} ms; num_layers={_get_config_value(self.config, 'num_hidden_layers')}" + f"MLX sharded model loaded in 
{(time.time() - t0) * 1000:.1f} ms; num_layers={get_config_value(self.config, 'num_hidden_layers')}" ) # Load VLM processor if this is a VLM model (first peer only) @@ -193,11 +186,11 @@ def __init__( # Calculate feature dimensions for kv cache # Use helper to handle VLM models where these are in text_config - num_key_value_heads = _get_config_value(self.config, "num_key_value_heads") - head_dim = _get_config_value(self.config, "head_dim") + num_key_value_heads = get_config_value(self.config, "num_key_value_heads") + head_dim = get_config_value(self.config, "head_dim") if head_dim is None: - hidden_size = _get_config_value(self.config, "hidden_size") - num_attention_heads = _get_config_value(self.config, "num_attention_heads") + hidden_size = get_config_value(self.config, "hidden_size") + num_attention_heads = get_config_value(self.config, "num_attention_heads") if hidden_size and num_attention_heads: head_dim = hidden_size // num_attention_heads else: @@ -311,6 +304,139 @@ def __init__( f"mlx_executor initialized; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}, total memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" ) + # ========== VLM Processing Methods ========== + + def _process_request_prompt( + self, rid: str, messages: list, image_urls: list, raw_request: Dict + ) -> Tuple[list, Optional[Any]]: + """ + Override base class method to handle VLM requests for MLX. + + If VLM processor is available and images are present, use VLM processing. + Otherwise, fall back to text-only processing. 
+ """ + if image_urls and self.vlm_processor is not None: + return self._process_vlm_request(rid, messages, image_urls) + else: + prompt = self._process_text_request(rid, messages, raw_request) + return prompt, None + + def _process_vlm_request(self, rid: str, messages: list, image_urls: list): + """Process a VLM (multimodal) request using the VLM processor.""" + from parallax.server.request import VLMInputs + from parallax.utils.vlm_utils import load_image + + try: + images = [] + for url in image_urls: + try: + img = load_image(url) + images.append(img) + except Exception as e: + logger.warning(f"Failed to load image {url}: {e}") + + if not images: + logger.warning( + f"No images loaded for VLM request {rid}, falling back to text processing" + ) + return self._process_text_request(rid, messages, {}), None + + formatted_messages = self._format_messages_for_vlm(messages) + + if hasattr(self.vlm_processor, "apply_chat_template"): + text_prompt = self.vlm_processor.apply_chat_template( + formatted_messages, + tokenize=False, + add_generation_prompt=True, + ) + elif self.tokenizer.chat_template: + text_prompt = self.tokenizer.apply_chat_template( + formatted_messages, + tokenize=False, + add_generation_prompt=True, + ) + else: + text_prompt = "\n".join( + f"{msg.get('role', 'user')}: {self._extract_text_from_content(msg.get('content', ''))}" + for msg in formatted_messages + ) + + processor_inputs = self.vlm_processor( + text=text_prompt, + images=images, + return_tensors="pt", + ) + input_ids = processor_inputs.get("input_ids") + if input_ids is None: + raise ValueError("Processor did not return input_ids") + prompt = input_ids.flatten().tolist() + + pixel_values = processor_inputs.get("pixel_values") + image_grid_thw = processor_inputs.get("image_grid_thw") + + vlm_inputs = VLMInputs( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_sizes=[(img.height, img.width) for img in images], + images_processed=False, + ) + + logger.debug( + f"VLM request 
{rid}: {len(images)} images, " + f"input_ids length={len(prompt)}, " + f"pixel_values shape={pixel_values.shape if pixel_values is not None else None}" + ) + + return prompt, vlm_inputs + + except Exception as e: + logger.error(f"Failed to process VLM request {rid}: {e}") + return self._process_text_request(rid, messages, {}), None + + def _format_messages_for_vlm(self, messages: list) -> list: + """ + Format messages for VLM processing. + Keep the original structure so chat template can handle image placeholders correctly. + """ + formatted = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + new_content = [] + for part in content: + if isinstance(part, dict): + if part.get("type") == "text": + new_content.append({"type": "text", "text": part.get("text", "")}) + elif part.get("type") == "image_url": + new_content.append({"type": "image"}) + elif isinstance(part, str): + new_content.append({"type": "text", "text": part}) + formatted.append( + { + "role": msg.get("role", "user"), + "content": new_content, + } + ) + else: + formatted.append(msg) + return formatted + + def _extract_text_from_content(self, content) -> str: + """Extract text from message content (handles both string and list formats).""" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + texts.append(part.get("text", "")) + elif isinstance(part, str): + texts.append(part) + return " ".join(texts) + return str(content) + + # ========== End VLM Processing Methods ========== + def _tensor_parallel_broadcast_pyobj(self, broadcast_obj): """Wrapper for broadcast pyobject in TP group using send/recv with explicit sync""" if self.tp_size <= 1: diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index d0a1a696..1b26f654 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ 
-19,6 +19,7 @@ from mlx_lm.utils import _download, load_config from parallax.server.model import ShardedModel +from parallax.utils.config_utils import get_config_value from parallax.utils.tokenizer_utils import load_tokenizer from parallax_utils.logging_config import get_logger @@ -45,14 +46,6 @@ } -def _get_config_value(config: Dict[str, Any], key: str, default: Any = None) -> Any: - """Get config value, falling back to text_config for VLM models.""" - if key in config: - return config[key] - text_config = config.get("text_config", {}) - return text_config.get(key, default) - - def _get_vlm_classes( model_type: str, config: Dict[str, Any] ) -> Tuple[Optional[Type[nn.Module]], Optional[Type[nn.Module]], Optional[Dict[str, Any]]]: @@ -458,7 +451,7 @@ def load( ] # Get tie_word_embeddings config (check both root and text_config for VLM) - tie_word_embeddings = _get_config_value(config, "tie_word_embeddings", False) + tie_word_embeddings = get_config_value(config, "tie_word_embeddings", False) for file_idx, wf in enumerate(weight_files): logger.debug( diff --git a/src/parallax/utils/config_utils.py b/src/parallax/utils/config_utils.py index b13db907..558b03cd 100644 --- a/src/parallax/utils/config_utils.py +++ b/src/parallax/utils/config_utils.py @@ -166,3 +166,68 @@ def build_mm_config(self) -> dict: "audio_token_id": self._raw_get("audio_token_id"), "vision_config": vision_config, } + + +# ============================================================================ +# Convenience functions for simple use cases +# ============================================================================ + + +def is_vlm_model(config: Any) -> bool: + """ + Check if the config represents a VLM (Vision-Language Model). + + VLM models have both text_config and vision_config. 
+ + Args: + config: Model configuration (dict or object) + + Returns: + True if the model is a VLM + """ + if isinstance(config, dict): + text_config = config.get("text_config") + vision_config = config.get("vision_config") + else: + text_config = getattr(config, "text_config", None) + vision_config = getattr(config, "vision_config", None) + + return text_config is not None and vision_config is not None + + +def get_config_value(config: Any, key: str, default: Any = None) -> Any: + """ + Get config value with text_config fallback for VLM models. + + This is a simple function interface for cases where you don't need + the full ModelConfigAccessor functionality. + + Args: + config: Model configuration (dict or object) + key: Configuration key to look up + default: Default value if not found + + Returns: + Configuration value or default + """ + # Try root level first + if isinstance(config, dict): + value = config.get(key) + else: + value = getattr(config, key, None) + + if value is not None: + return value + + # Fallback to text_config (for VLM models) + if isinstance(config, dict): + text_config = config.get("text_config", {}) + else: + text_config = getattr(config, "text_config", {}) + + if isinstance(text_config, dict): + return text_config.get(key, default) + elif text_config is not None: + return getattr(text_config, key, default) + + return default diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py index bfd8a592..a008afb2 100644 --- a/src/parallax/utils/weight_filter_utils.py +++ b/src/parallax/utils/weight_filter_utils.py @@ -3,21 +3,9 @@ from pathlib import Path from typing import Dict, List, Optional, Set -logger = logging.getLogger(__name__) - - -def _get_num_hidden_layers(config: Dict) -> int: - """Get num_hidden_layers from config, handling VLM models with text_config.""" - if "num_hidden_layers" in config: - return config["num_hidden_layers"] - # VLM models store this in text_config - text_config = 
config.get("text_config", {}) - return text_config.get("num_hidden_layers", 0) - +from parallax.utils.config_utils import get_config_value, is_vlm_model -def _is_vlm_model(config: Dict) -> bool: - """Check if config represents a VLM model.""" - return config.get("vision_config") is not None +logger = logging.getLogger(__name__) def should_include_weight_key( @@ -103,20 +91,13 @@ def filter_weight_files_by_layer_range_for_load( tie_word_embeddings = False if config: - tie_word_embeddings = config.get("tie_word_embeddings", False) - # Also check text_config for VLM models - if not tie_word_embeddings: - text_config = config.get("text_config", {}) - tie_word_embeddings = text_config.get("tie_word_embeddings", False) + tie_word_embeddings = get_config_value(config, "tie_word_embeddings", False) else: config_file = model_path / "config.json" if config_file.exists(): with open(config_file, "r") as f: cfg = json.load(f) - tie_word_embeddings = cfg.get("tie_word_embeddings", False) - if not tie_word_embeddings: - text_config = cfg.get("text_config", {}) - tie_word_embeddings = text_config.get("tie_word_embeddings", False) + tie_word_embeddings = get_config_value(cfg, "tie_word_embeddings", False) needed_files: Set[str] = set() @@ -172,19 +153,19 @@ def determine_needed_weight_files_for_download( is_first_shard = start_layer == 0 is_last_shard = False - is_vlm = False + is_vlm_flag = False if config: - num_hidden_layers = _get_num_hidden_layers(config) + num_hidden_layers = get_config_value(config, "num_hidden_layers", 0) is_last_shard = end_layer >= num_hidden_layers - is_vlm = _is_vlm_model(config) + is_vlm_flag = is_vlm_model(config) else: config_file = model_path / "config.json" if config_file.exists(): with open(config_file, "r") as f: cfg = json.load(f) - num_hidden_layers = _get_num_hidden_layers(cfg) + num_hidden_layers = get_config_value(cfg, "num_hidden_layers", 0) is_last_shard = end_layer >= num_hidden_layers - is_vlm = _is_vlm_model(cfg) + is_vlm_flag = 
is_vlm_model(cfg) index_file = model_path / "model.safetensors.index.json" @@ -212,13 +193,7 @@ def determine_needed_weight_files_for_download( logger.debug("weight_map is empty in index file") return [] - tie_word_embeddings = False - if config: - tie_word_embeddings = config.get("tie_word_embeddings", False) - # Also check text_config for VLM models - if not tie_word_embeddings: - text_config = config.get("text_config", {}) - tie_word_embeddings = text_config.get("tie_word_embeddings", False) + tie_word_embeddings = get_config_value(config, "tie_word_embeddings", False) if config else False needed_files: Set[str] = set() @@ -232,7 +207,7 @@ def determine_needed_weight_files_for_download( is_first_shard=is_first_shard, is_last_shard=is_last_shard, tie_word_embeddings=tie_word_embeddings, - is_vlm=is_vlm, + is_vlm=is_vlm_flag, ): needed_files.add(filename) From ba5d4569c8fec0108672aa02d0739821b15cb33f Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 5 Feb 2026 03:53:39 +0000 Subject: [PATCH 15/36] update sglang version --- pyproject.toml | 2 +- src/parallax/sglang/batch_info.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cd7e7f57..a7f8a8bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ mac = [ ] gpu = [ - "sglang[all]==0.5.7", + "sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python", "mlx-lm==0.28.4", "mlx[cpu]==0.30.0", ] diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index ba133fad..6eabeee7 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -134,7 +134,7 @@ def form_sgl_batch_prefill( requests, page_tree_cache, processor, hf_config, tokenizer ) - def dummy_evict(*args): + def dummy_function(*args): pass dummy_tree_cache = SimpleNamespace( @@ -142,8 +142,12 @@ def dummy_evict(*args): device=model_runner.device, 
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, evictable_size=0, + sliding_window_size=0, ) - dummy_tree_cache.evict = dummy_evict + dummy_tree_cache.evict = dummy_function + dummy_tree_cache.supports_swa = dummy_function + dummy_tree_cache.is_chunk_cache = dummy_function + dummy_tree_cache.supports_mamba = dummy_function schedule_batch = ScheduleBatch.init_new( reqs=sgl_reqs, req_to_token_pool=model_runner.req_to_token_pool, From 9afa8293a44381b68ff5758aedfb5ddce3620649 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 5 Feb 2026 14:33:31 +0800 Subject: [PATCH 16/36] update config layers read --- src/parallax/sglang/model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 99c80d9f..8a191ade 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -40,6 +40,7 @@ from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( set_layer_range_for_filtering, ) +from parallax.utils.config_utils import ModelConfigAccessor from parallax.utils.tokenizer_utils import load_tokenizer logger = logging.getLogger(__name__) @@ -74,7 +75,9 @@ def __init__( """Add pp_start_layer and pp_end_layer for decentralized model""" self.pp_start_layer = pp_start_layer self.pp_end_layer = pp_end_layer - num_hidden_layers = model_config.hf_config.num_hidden_layers + num_hidden_layers = ModelConfigAccessor(model_config.hf_config).get_num_hidden_layers() + if num_hidden_layers is None: + raise ValueError("num_hidden_layers is required but not found in model config") set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers) super().__init__( From ab53736bbf9297356f0b5765a025342fa9ce56b7 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 5 Feb 2026 06:47:37 +0000 Subject: [PATCH 17/36] specfic for kimi --- src/parallax/sglang/batch_info.py | 10 ++++++---- src/parallax/sglang/multimodal_utils.py | 13 
+++++++++++-- src/parallax/utils/config_utils.py | 8 +++++++- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 6eabeee7..819b9d1f 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -85,15 +85,17 @@ def transform_requests_to_sglang( else: logger.warning( - f"Assigning raw multimodal_params to req.multimodal_inputs. " - f"SGLang might expect MultimodalInputs object with Tensors. " + f"Cannot process multimodal_params: no 'mm_items' or 'images' key, " + f"or processor not available. " f"Params keys: {old_req.multimodal_params.keys()}, Processor: {processor is not None}" ) - req.multimodal_inputs = old_req.multimodal_params + # Don't assign raw dict - SGLang expects MultimodalInputs object + req.multimodal_inputs = None except Exception as e: logger.exception(f"Failed to construct MultimodalInputs: {e}") - req.multimodal_inputs = old_req.multimodal_params + # Don't assign raw dict - SGLang expects MultimodalInputs object + req.multimodal_inputs = None # Debug: Log before cache lookup if page_tree_cache is not None: diff --git a/src/parallax/sglang/multimodal_utils.py b/src/parallax/sglang/multimodal_utils.py index 0840fe97..3b7a0a23 100644 --- a/src/parallax/sglang/multimodal_utils.py +++ b/src/parallax/sglang/multimodal_utils.py @@ -64,7 +64,15 @@ def process_images( return [], None, [] try: - inputs = processor(text=input_text, images=images, return_tensors="pt") + # Check if this is a Kimi K2.5 processor (has different interface) + processor_class_name = processor.__class__.__name__ + if processor_class_name == "KimiK25Processor": + # Kimi K2.5 requires 'medias' parameter with specific format + medias = [{"type": "image", "image": img} for img in images] + inputs = processor(medias=medias, text=input_text, return_tensors="pt") + else: + # Standard HuggingFace processor interface + inputs = processor(text=input_text, images=images, 
return_tensors="pt") if inputs is None: logger.error("Processor returned None") @@ -83,7 +91,8 @@ def process_images( logger.debug(f"Processor expanded input_ids length: {len(expanded_input_ids)}") - image_grid_thw = inputs.get("image_grid_thw") + # Handle different field names: Kimi K2.5 uses 'grid_thws', others use 'image_grid_thw' + image_grid_thw = inputs.get("image_grid_thw") or inputs.get("grid_thws") image_sizes = inputs.get("image_sizes") model_specific_data = {} diff --git a/src/parallax/utils/config_utils.py b/src/parallax/utils/config_utils.py index 558b03cd..7e65fe50 100644 --- a/src/parallax/utils/config_utils.py +++ b/src/parallax/utils/config_utils.py @@ -157,9 +157,15 @@ def build_mm_config(self) -> dict: "tokens_per_second": getattr(vision_config_raw, "tokens_per_second", None), } + # Get image_token_id with fallbacks for different models + # Kimi K2.5 uses 'media_placeholder_token_id' instead of 'image_token_id' + image_token_id = self._raw_get("image_token_id") + if image_token_id is None: + image_token_id = self._raw_get("media_placeholder_token_id") + return { "model_type": self._raw_get("model_type"), - "image_token_id": self._raw_get("image_token_id"), + "image_token_id": image_token_id, "vision_start_token_id": self._raw_get("vision_start_token_id"), "vision_end_token_id": self._raw_get("vision_end_token_id"), "video_token_id": self._raw_get("video_token_id"), From 28001b8b449312f5fc75728e1c7fad8ad9f8e228 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 5 Feb 2026 07:06:39 +0000 Subject: [PATCH 18/36] only print rank0 log and fix kimi token expand --- src/parallax/server/executor/base_executor.py | 6 +++- src/parallax/sglang/multimodal_utils.py | 34 ++++++++++++++++-- src/parallax_utils/logging_config.py | 35 ++++++++++++++++++- 3 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index afa73af9..0c01f3e5 100755 --- 
a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -46,7 +46,7 @@ from parallax.utils.config_utils import ModelConfigAccessor from parallax.utils.shared_state import SharedState from parallax.utils.utils import get_current_device, get_device_dtype, get_zmq_socket -from parallax_utils.logging_config import get_logger +from parallax_utils.logging_config import get_logger, set_rank logger = get_logger(__name__) @@ -128,6 +128,10 @@ def __init__( self.dp_size = dp_size self.dp_rank = dp_rank + # Configure logging to only print on rank 0 when using multiple GPUs + if tp_size > 1: + set_rank(tp_rank, enable_filter=True) + # Runtime weight refit for RL self.enable_weight_refit = enable_weight_refit self.weight_version = 0 diff --git a/src/parallax/sglang/multimodal_utils.py b/src/parallax/sglang/multimodal_utils.py index 3b7a0a23..d5e7062c 100644 --- a/src/parallax/sglang/multimodal_utils.py +++ b/src/parallax/sglang/multimodal_utils.py @@ -67,6 +67,31 @@ def process_images( # Check if this is a Kimi K2.5 processor (has different interface) processor_class_name = processor.__class__.__name__ if processor_class_name == "KimiK25Processor": + # Kimi K2.5 requires special handling: + # 1. Expand image tokens based on actual image size + # 2. 
Use 'medias' parameter instead of 'images' + + # The image token for Kimi K2.5 is <|media_pad|> + image_token = "<|media_pad|>" + + # Expand single image tokens to the correct number of tokens + if image_token in input_text and hasattr(processor, 'media_processor'): + parts = input_text.split(image_token) + result = [parts[0]] + for i, (image, part) in enumerate(zip(images, parts[1:])): + try: + # Calculate how many tokens this image needs + num_tokens = processor.media_processor.media_tokens_calculator( + {"type": "image", "image": image} + ) + logger.debug(f"Kimi K2.5: Image {i} expanded to {num_tokens} tokens") + except Exception as e: + logger.warning(f"Failed to calculate media tokens for image {i}: {e}, using 1") + num_tokens = 1 + result.append(image_token * num_tokens + part) + input_text = "".join(result) + logger.debug(f"Kimi K2.5: Expanded input_text length: {len(input_text)}") + # Kimi K2.5 requires 'medias' parameter with specific format medias = [{"type": "image", "image": img} for img in images] inputs = processor(medias=medias, text=input_text, return_tensors="pt") @@ -92,12 +117,17 @@ def process_images( logger.debug(f"Processor expanded input_ids length: {len(expanded_input_ids)}") # Handle different field names: Kimi K2.5 uses 'grid_thws', others use 'image_grid_thw' + is_kimi_k25 = processor_class_name == "KimiK25Processor" image_grid_thw = inputs.get("image_grid_thw") or inputs.get("grid_thws") image_sizes = inputs.get("image_sizes") + # Determine the correct field name for grid data + # Kimi K2.5 expects 'grid_thws', others expect 'image_grid_thw' + grid_field_name = "grid_thws" if is_kimi_k25 else "image_grid_thw" + model_specific_data = {} if image_grid_thw is not None: - model_specific_data["image_grid_thw"] = image_grid_thw + model_specific_data[grid_field_name] = image_grid_thw if image_sizes is not None: model_specific_data["image_sizes"] = image_sizes @@ -123,7 +153,7 @@ def process_images( modality=Modality.IMAGE, 
feature=item_pixel_values, model_specific_data={ - "image_grid_thw": item_grid_thw, + grid_field_name: item_grid_thw, }, ) mm_items.append(item) diff --git a/src/parallax_utils/logging_config.py b/src/parallax_utils/logging_config.py index 13a91937..81962a43 100644 --- a/src/parallax_utils/logging_config.py +++ b/src/parallax_utils/logging_config.py @@ -6,10 +6,12 @@ import threading from typing import Optional -__all__ = ["get_logger", "use_parallax_log_handler", "set_log_level"] +__all__ = ["get_logger", "use_parallax_log_handler", "set_log_level", "set_rank"] _init_lock = threading.Lock() _default_handler: logging.Handler | None = None +_current_rank: int = 0 # Default to rank 0 (will print logs) +_rank_filter_enabled: bool = False # Whether to filter logs by rank class _Ansi: @@ -138,3 +140,34 @@ def use_parallax_log_handler(for_root: bool = True): root = logging.getLogger() if _default_handler not in root.handlers: root.addHandler(_default_handler) + + +def set_rank(rank: int, enable_filter: bool = True): + """ + Set the current process rank for log filtering. + + When rank filtering is enabled, only rank 0 will print logs. + This is useful for multi-GPU (TP/DP) scenarios where you want + to avoid duplicate log messages from all processes. + + Args: + rank: The rank of the current process (0 for master). + enable_filter: If True, only rank 0 will print logs. 
+ """ + global _current_rank, _rank_filter_enabled + _current_rank = rank + _rank_filter_enabled = enable_filter + + if enable_filter and rank != 0: + # Disable logging for non-zero ranks by setting to a high level + logging.getLogger().setLevel(logging.CRITICAL + 1) + + +def get_rank() -> int: + """Get the current process rank.""" + return _current_rank + + +def is_rank_zero() -> bool: + """Check if the current process is rank 0.""" + return _current_rank == 0 From be6d1f16c0c7962242b53a71c92963fcc2d7c241 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Thu, 5 Feb 2026 07:34:38 +0000 Subject: [PATCH 19/36] fix weight fiter --- src/parallax/server/executor/factory.py | 17 ++++++++++++++++- src/parallax/sglang/model_runner.py | 6 ++++-- .../monkey_patch_utils/weight_loader_filter.py | 10 +++++++++- src/parallax/sglang/multimodal_utils.py | 18 ++++++++++++++++++ src/parallax/utils/weight_filter_utils.py | 2 +- src/parallax/vllm/model_runner.py | 6 ++++-- 6 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/parallax/server/executor/factory.py b/src/parallax/server/executor/factory.py index 20ed93e4..ac8c5418 100755 --- a/src/parallax/server/executor/factory.py +++ b/src/parallax/server/executor/factory.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional from parallax.utils.utils import get_current_device -from parallax_utils.logging_config import get_logger, set_log_level +from parallax_utils.logging_config import get_logger, set_log_level, set_rank logger = get_logger(__name__) @@ -109,7 +109,22 @@ def create_from_args( def run_executor_process(args, shared_state=None, conn=None): """Run executor as a subprocess""" + # Set rank to suppress logs on non-zero ranks + # Must be called AFTER set_log_level to override the level + tp_rank = getattr(args, 'tp_rank', 0) + tp_size = getattr(args, 'tp_size', 1) + + # For non-zero ranks, suppress logs before any imports + if tp_size > 1 and tp_rank != 0: + import logging + 
logging.getLogger().setLevel(logging.CRITICAL + 1) + set_log_level(args.log_level) + + # Now set rank properly (will re-suppress for non-zero ranks) + if tp_size > 1: + set_rank(tp_rank, enable_filter=True) + executor = None try: executor = create_from_args(args, shared_state, conn) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 8a191ade..fd5a6cb9 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -75,10 +75,12 @@ def __init__( """Add pp_start_layer and pp_end_layer for decentralized model""" self.pp_start_layer = pp_start_layer self.pp_end_layer = pp_end_layer - num_hidden_layers = ModelConfigAccessor(model_config.hf_config).get_num_hidden_layers() + config_accessor = ModelConfigAccessor(model_config.hf_config) + num_hidden_layers = config_accessor.get_num_hidden_layers() if num_hidden_layers is None: raise ValueError("num_hidden_layers is required but not found in model config") - set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers) + is_vlm = config_accessor.is_vlm + set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers, is_vlm=is_vlm) super().__init__( model_config=model_config, diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 7bd48082..f1d9e318 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -11,11 +11,17 @@ _layer_range_cache = {} -def set_layer_range_for_filtering(pp_start_layer: int, pp_end_layer: int, num_hidden_layers: int): +def set_layer_range_for_filtering( + pp_start_layer: int, + pp_end_layer: int, + num_hidden_layers: int, + is_vlm: bool = False, +): global _layer_range_cache _layer_range_cache["pp_start_layer"] = pp_start_layer _layer_range_cache["pp_end_layer"] = pp_end_layer _layer_range_cache["num_hidden_layers"] = 
num_hidden_layers + _layer_range_cache["is_vlm"] = is_vlm def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: @@ -24,6 +30,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: pp_start_layer = _layer_range_cache.get("pp_start_layer") pp_end_layer = _layer_range_cache.get("pp_end_layer") num_hidden_layers = _layer_range_cache.get("num_hidden_layers") + is_vlm = _layer_range_cache.get("is_vlm", False) if pp_start_layer is None or pp_end_layer is None: logger.debug("No layer range set, loading all weight files") @@ -43,6 +50,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: end_layer=pp_end_layer, is_first_shard=is_first_shard, is_last_shard=is_last_shard, + is_vlm=is_vlm, ) return filtered_files diff --git a/src/parallax/sglang/multimodal_utils.py b/src/parallax/sglang/multimodal_utils.py index d5e7062c..2cccc1dc 100644 --- a/src/parallax/sglang/multimodal_utils.py +++ b/src/parallax/sglang/multimodal_utils.py @@ -34,6 +34,8 @@ def load_image(url: Any) -> Image.Image: return Image.open(image_data).convert("RGB") elif url.startswith("data:image"): header, encoded = url.split(",", 1) + # Strip whitespace from base64 encoded data (some clients add spaces after comma) + encoded = encoded.strip() image_data = BytesIO(base64.b64decode(encoded)) return Image.open(image_data).convert("RGB") else: @@ -56,6 +58,7 @@ def process_images( try: image = load_image(url) images.append(image) + logger.debug(f"Loaded image: size={image.size}, mode={image.mode}") except Exception as e: logger.exception(f"Failed to load image {url}: {e}") continue @@ -103,10 +106,23 @@ def process_images( logger.error("Processor returned None") return [], None, [] + # Debug: Log all keys returned by the processor + logger.debug(f"Processor output keys: {list(inputs.keys())}") + for key, value in inputs.items(): + if hasattr(value, 'shape'): + logger.debug(f" {key}: shape={value.shape}, dtype={value.dtype}") + elif 
hasattr(value, '__len__'): + logger.debug(f" {key}: len={len(value)}, type={type(value)}") + else: + logger.debug(f" {key}: {type(value)}") + pixel_values = inputs.get("pixel_values") if pixel_values is None: logger.error("Processor output missing pixel_values") return [], None, [] + + logger.debug(f"pixel_values shape: {pixel_values.shape}, dtype: {pixel_values.dtype}, device: {pixel_values.device}") + logger.debug(f"pixel_values stats: min={pixel_values.min().item():.4f}, max={pixel_values.max().item():.4f}, mean={pixel_values.mean().item():.4f}") expanded_input_ids = inputs.get("input_ids") if expanded_input_ids is not None: @@ -120,6 +136,8 @@ def process_images( is_kimi_k25 = processor_class_name == "KimiK25Processor" image_grid_thw = inputs.get("image_grid_thw") or inputs.get("grid_thws") image_sizes = inputs.get("image_sizes") + + logger.debug(f"image_grid_thw: {image_grid_thw}, is_kimi_k25: {is_kimi_k25}") # Determine the correct field name for grid data # Kimi K2.5 expects 'grid_thws', others expect 'image_grid_thw' diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py index a008afb2..b0265498 100644 --- a/src/parallax/utils/weight_filter_utils.py +++ b/src/parallax/utils/weight_filter_utils.py @@ -129,7 +129,7 @@ def filter_weight_files_by_layer_range_for_load( logger.debug( f"Filtered weight files from {len(weight_files)} to {len(filtered_files)} " - f"for layers [{start_layer}, {end_layer})" + f"for layers [{start_layer}, {end_layer}), is_vlm={is_vlm}, is_first_shard={is_first_shard}" ) # If filtering resulted in no files but we had input files, diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index c7963fd9..43a507e9 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -401,10 +401,12 @@ def initialize_vllm_model_runner( # Reuse the generic monkey patch used by sglang implementation to reduce # local weight file reads when loading a 
partial layer shard. try: - set_layer_range_for_filtering(start_layer, end_layer, num_hidden_layers) + from parallax.utils.config_utils import is_vlm_model + is_vlm = is_vlm_model(config) + set_layer_range_for_filtering(start_layer, end_layer, num_hidden_layers, is_vlm=is_vlm) apply_weight_loader_filter_patch() logger.debug( - f"Applied weight loader filter monkey patch for layers [{start_layer}, {end_layer})" + f"Applied weight loader filter monkey patch for layers [{start_layer}, {end_layer}), is_vlm={is_vlm}" ) except Exception as e: logger.warning("Failed to apply weight loader filter patch for vLLM loading: %s", e) From 4e2b1d97d69b6bece0ddd2441da4a0c4122f2c9f Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Fri, 6 Feb 2026 16:12:31 +0800 Subject: [PATCH 20/36] update mlx-lm --- Makefile | 14 ++++++++++++++ docs/user_guide/install.md | 6 +++--- pyproject.toml | 4 +--- 3 files changed, 18 insertions(+), 6 deletions(-) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..1f47f2f9 --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ +.PHONY: install-gpu install-mac install-vllm install-dev + +install-gpu: + pip install -e ".[gpu]" + pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps + +install-mac: + pip install -e ".[mac]" + +install-vllm: + pip install -e ".[vllm]" + +install-dev: + pip install -e ".[dev]" diff --git a/docs/user_guide/install.md b/docs/user_guide/install.md index 02a30343..dd7b1667 100644 --- a/docs/user_guide/install.md +++ b/docs/user_guide/install.md @@ -19,7 +19,7 @@ Note: If you are using DGX Spark, please refer to the Docker installation sectio ```sh git clone https://github.com/GradientHQ/parallax.git cd parallax -pip install -e '.[gpu]' +make install-gpu ``` #### For macOS (Apple silicon): @@ -34,14 +34,14 @@ cd parallax python3 -m venv ./venv source ./venv/bin/activate -pip install -e '.[mac]' +make install-mac ``` Next time to re-activate this virtual environment, run ```source 
./venv/bin/activate```. #### Extra step for development: ```sh -pip install -e '.[dev]' +make install-dev ``` ### Windows Application diff --git a/pyproject.toml b/pyproject.toml index a7f8a8bc..5d13754c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ parallax = "parallax.cli:main" mac = [ "nanobind==2.10.2", "torch==2.8.0", - "mlx-lm==0.30.5", + "mlx-lm==0.30.6", "mlx==0.30.4", "mlx-vlm==0.3.10", "torchvision==0.23.0" @@ -54,8 +54,6 @@ mac = [ gpu = [ "sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python", - "mlx-lm==0.28.4", - "mlx[cpu]==0.30.0", ] vllm = [ From fbbde634a58465bbe8e25ebb60f9e4e03ca7f310 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Fri, 6 Feb 2026 16:25:42 +0800 Subject: [PATCH 21/36] add tool test --- tests/test_tool_call.py | 503 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 503 insertions(+) create mode 100644 tests/test_tool_call.py diff --git a/tests/test_tool_call.py b/tests/test_tool_call.py new file mode 100644 index 00000000..fbec2c82 --- /dev/null +++ b/tests/test_tool_call.py @@ -0,0 +1,503 @@ +""" +Tests for ToolCallState parsing in parallax. + +Tests that ToolCallState correctly detects, extracts, and formats +tool calls from model output segments. +""" + +import json +import unittest +from unittest.mock import MagicMock + +from parallax.utils.tokenizer_utils import ToolCallState + + +def make_tool_state( + tool_call_start: str, + tool_call_end: str, + tool_parser, + tools=None, + stream=False, +) -> ToolCallState: + """Helper to create a ToolCallState with the given parser config.""" + return ToolCallState( + has_tool_calling=True, + tool_call_start=tool_call_start, + tool_call_end=tool_call_end, + tool_parser=tool_parser, + tools=tools, + stream=stream, + ) + + +# ---- Parsers mimicking common model formats ---- + +def json_tool_parser(tool_text: str, tools): + """Parser for JSON-formatted tool calls (e.g. 
Qwen, GLM).""" + parsed = json.loads(tool_text) + return { + "name": parsed["name"], + "arguments": parsed["arguments"], + } + + +def xml_tool_parser(tool_text: str, tools): + """Parser for XML-formatted tool calls (e.g. minimax style).""" + import re + + name_match = re.search(r'', tool_text) + if not name_match: + raise ValueError("No invoke name found") + name = name_match.group(1) + params = dict(re.findall(r'([^<]+)', tool_text)) + # Try to convert numeric values + for k, v in params.items(): + try: + params[k] = int(v) + except ValueError: + try: + params[k] = float(v) + except ValueError: + pass + return {"name": name, "arguments": params} + + +def multi_tool_parser(tool_text: str, tools): + """Parser that returns a list of tool calls (e.g. kimi-k2 style).""" + calls = [] + for part in tool_text.split("}{"): + part = part.strip() + if not part.startswith("{"): + part = "{" + part + if not part.endswith("}"): + part = part + "}" + parsed = json.loads(part) + calls.append({ + "id": parsed.get("id", "call_0"), + "name": parsed["name"], + "arguments": parsed["arguments"], + }) + return calls + + +# ---- Sample tools definition ---- + +SAMPLE_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location.", + "parameters": { + "type": "object", + "required": ["location"], + "properties": { + "location": {"type": "string", "description": "The city name."}, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "multiply", + "description": "Multiply two numbers.", + "parameters": { + "type": "object", + "required": ["a", "b"], + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + }, + }, + }, +] + + +class TestToolCallStateInit(unittest.TestCase): + """Test ToolCallState initialization via from_tokenizer.""" + + def test_from_tokenizer_with_tool_support(self): + """Tokenizer with full tool 
calling support should produce active state.""" + tokenizer = MagicMock() + tokenizer.has_tool_calling = True + tokenizer.tool_parser = json_tool_parser + tokenizer.tool_call_start = "" + tokenizer.tool_call_end = "" + + state = ToolCallState.from_tokenizer(tokenizer, SAMPLE_TOOLS, stream=False) + + self.assertTrue(state.has_tool_calling) + self.assertEqual(state.tool_call_start, "") + self.assertEqual(state.tool_call_end, "") + self.assertIs(state.tool_parser, json_tool_parser) + self.assertEqual(state.tools, SAMPLE_TOOLS) + + def test_from_tokenizer_without_tool_support(self): + """Tokenizer without tool calling attributes should produce inactive state.""" + tokenizer = MagicMock(spec=[]) # No attributes + + state = ToolCallState.from_tokenizer(tokenizer, SAMPLE_TOOLS, stream=False) + + self.assertFalse(state.has_tool_calling) + self.assertIsNone(state.tool_call_start) + self.assertIsNone(state.tool_call_end) + self.assertIsNone(state.tool_parser) + + def test_from_tokenizer_partial_support(self): + """Tokenizer with partial attributes (missing tool_call_end) should be inactive.""" + tokenizer = MagicMock() + tokenizer.has_tool_calling = True + tokenizer.tool_parser = json_tool_parser + tokenizer.tool_call_start = "" + tokenizer.tool_call_end = None # Missing end marker + + state = ToolCallState.from_tokenizer(tokenizer, SAMPLE_TOOLS, stream=False) + + self.assertFalse(state.has_tool_calling) + + +class TestToolCallExtraction(unittest.TestCase): + """Test extract_from_segment with various tool call formats.""" + + def test_json_tool_call(self): + """Test JSON-formatted tool call extraction.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "") + self.assertEqual(len(tool_calls), 1) + 
self.assertEqual(tool_calls[0]["function"]["name"], "get_weather") + self.assertEqual( + json.loads(tool_calls[0]["function"]["arguments"]), + {"location": "Beijing"}, + ) + self.assertEqual(tool_calls[0]["type"], "function") + self.assertIn("id", tool_calls[0]) + + def test_xml_tool_call(self): + """Test XML-formatted tool call extraction.""" + state = make_tool_state( + tool_call_start="<|tool_start|>", + tool_call_end="<|tool_end|>", + tool_parser=xml_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = ( + '<|tool_start|>' + '' + '12345' + '67890' + '' + '<|tool_end|>' + ) + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "") + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["function"]["name"], "multiply") + self.assertEqual( + json.loads(tool_calls[0]["function"]["arguments"]), + {"a": 12345, "b": 67890}, + ) + + def test_text_before_tool_call(self): + """Tool call preceded by regular text should preserve the text.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = 'Let me check the weather for you.{"name": "get_weather", "arguments": {"location": "Shanghai"}}' + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "Let me check the weather for you.") + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["function"]["name"], "get_weather") + + def test_text_after_tool_call(self): + """Text after tool call end marker should be preserved.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = '{"name": "get_weather", "arguments": {"location": "Tokyo"}} Done!' 
+ text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, " Done!") + self.assertEqual(len(tool_calls), 1) + + def test_no_tool_call(self): + """Segment without tool call markers should pass through unchanged.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = "The weather in Beijing is sunny today." + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, segment) + self.assertEqual(tool_calls, []) + + def test_inactive_state_passes_through(self): + """Inactive ToolCallState (no tool support) should pass through text unchanged.""" + state = ToolCallState( + has_tool_calling=False, + tool_call_start=None, + tool_call_end=None, + tool_parser=None, + tools=None, + stream=False, + ) + + segment = '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, segment) + self.assertEqual(tool_calls, []) + + def test_empty_segment(self): + """Empty segment should return empty text and no tool calls.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + text, tool_calls = state.extract_from_segment("") + self.assertEqual(text, "") + self.assertEqual(tool_calls, []) + + +class TestToolCallStreaming(unittest.TestCase): + """Test streaming scenarios where tool calls arrive across multiple segments.""" + + def test_tool_call_split_across_segments(self): + """Tool call split across two segments should accumulate and parse correctly.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + # First segment: start marker + partial content + text1, calls1 = state.extract_from_segment( + '{"name": "get_weather",' + ) + self.assertEqual(text1, "") + self.assertEqual(calls1, []) + 
self.assertTrue(state.in_tool_call) + + # Second segment: rest of content + end marker + text2, calls2 = state.extract_from_segment( + ' "arguments": {"location": "Beijing"}}' + ) + self.assertEqual(text2, "") + self.assertEqual(len(calls2), 1) + self.assertEqual(calls2[0]["function"]["name"], "get_weather") + self.assertFalse(state.in_tool_call) + + def test_tool_call_split_three_segments(self): + """Tool call split across three segments.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + text1, calls1 = state.extract_from_segment("") + self.assertEqual(calls1, []) + self.assertTrue(state.in_tool_call) + + text2, calls2 = state.extract_from_segment( + '{"name": "multiply", "arguments": {"a": 3, "b": 7}}' + ) + self.assertEqual(calls2, []) + + text3, calls3 = state.extract_from_segment("") + self.assertEqual(len(calls3), 1) + self.assertEqual(calls3[0]["function"]["name"], "multiply") + + def test_stream_mode_has_index(self): + """In stream mode, tool calls should have an 'index' field.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + stream=True, + ) + + segment = '{"name": "get_weather", "arguments": {"location": "NYC"}}' + _, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(len(tool_calls), 1) + self.assertIn("index", tool_calls[0]) + self.assertEqual(tool_calls[0]["index"], 0) + + def test_stream_mode_index_increments(self): + """In stream mode, tool call index should increment for each call.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + stream=True, + ) + + # First tool call + segment1 = '{"name": "get_weather", "arguments": {"location": "NYC"}}' + _, calls1 = state.extract_from_segment(segment1) + self.assertEqual(calls1[0]["index"], 0) + + # Second tool call + segment2 = '{"name": "multiply", 
"arguments": {"a": 2, "b": 3}}' + _, calls2 = state.extract_from_segment(segment2) + self.assertEqual(calls2[0]["index"], 1) + + +class TestToolCallMultiple(unittest.TestCase): + """Test multiple tool calls in a single segment.""" + + def test_two_tool_calls_in_one_segment(self): + """Two consecutive tool calls in one segment should both be extracted.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = ( + '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + '{"name": "multiply", "arguments": {"a": 3, "b": 5}}' + ) + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "") + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["function"]["name"], "get_weather") + self.assertEqual(tool_calls[1]["function"]["name"], "multiply") + + def test_multi_return_parser(self): + """Parser that returns a list of tool calls should all be formatted.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=multi_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = '{"id": "call_1", "name": "get_weather", "arguments": {"location": "London"}}' + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "") + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["function"]["name"], "get_weather") + self.assertEqual(tool_calls[0]["id"], "call_1") + + +class TestToolCallEdgeCases(unittest.TestCase): + """Test edge cases and error handling.""" + + def test_malformed_json_falls_back(self): + """Malformed JSON inside tool call markers should fall back to raw text.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = "not valid json at all" + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "not valid json at all") + self.assertEqual(tool_calls, []) 
+ + def test_made_tool_call_flag(self): + """made_tool_call flag should be set after a tool call is detected.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + self.assertFalse(state.made_tool_call) + + segment = '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + state.extract_from_segment(segment) + + self.assertTrue(state.made_tool_call) + + def test_made_tool_call_flag_set_even_on_parse_failure(self): + """made_tool_call should be True even if parsing fails (marker was still entered).""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = "invalid" + state.extract_from_segment(segment) + + self.assertTrue(state.made_tool_call) + + def test_string_arguments_serialized_to_json(self): + """Tool call arguments dict should be serialized to JSON string in output.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = '{"name": "get_weather", "arguments": {"location": "北京"}}' + _, tool_calls = state.extract_from_segment(segment) + + # arguments should be a JSON string (not a dict) in the output + args_str = tool_calls[0]["function"]["arguments"] + self.assertIsInstance(args_str, str) + self.assertEqual(json.loads(args_str), {"location": "北京"}) + + def test_unicode_in_arguments(self): + """Unicode characters in arguments should be preserved (ensure_ascii=False).""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = '{"name": "get_weather", "arguments": {"location": "東京"}}' + _, tool_calls = state.extract_from_segment(segment) + + args_str = tool_calls[0]["function"]["arguments"] + # ensure_ascii=False means the Chinese/Japanese chars should appear directly + self.assertIn("東京", args_str) + + +if __name__ == 
"__main__": + unittest.main() From a410ed4df53e15a18b67c5cd04d41a43206fb1f8 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Fri, 6 Feb 2026 16:54:51 +0800 Subject: [PATCH 22/36] rebase pip install --- Makefile | 14 -- docs/user_guide/install.md | 6 +- pyproject.toml | 2 + src/parallax/models/kimi_vl.py | 203 ++++++++++++++++++++++++++++ src/parallax/server/model.py | 40 +++++- src/parallax/server/shard_loader.py | 8 +- 6 files changed, 249 insertions(+), 24 deletions(-) delete mode 100644 Makefile create mode 100644 src/parallax/models/kimi_vl.py diff --git a/Makefile b/Makefile deleted file mode 100644 index 1f47f2f9..00000000 --- a/Makefile +++ /dev/null @@ -1,14 +0,0 @@ -.PHONY: install-gpu install-mac install-vllm install-dev - -install-gpu: - pip install -e ".[gpu]" - pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps - -install-mac: - pip install -e ".[mac]" - -install-vllm: - pip install -e ".[vllm]" - -install-dev: - pip install -e ".[dev]" diff --git a/docs/user_guide/install.md b/docs/user_guide/install.md index dd7b1667..7a671ec8 100644 --- a/docs/user_guide/install.md +++ b/docs/user_guide/install.md @@ -19,7 +19,7 @@ Note: If you are using DGX Spark, please refer to the Docker installation sectio ```sh git clone https://github.com/GradientHQ/parallax.git cd parallax -make install-gpu +pip install -e ".[gpu]" ``` #### For macOS (Apple silicon): @@ -34,14 +34,14 @@ cd parallax python3 -m venv ./venv source ./venv/bin/activate -make install-mac +pip install -e ".[mac]" ``` Next time to re-activate this virtual environment, run ```source ./venv/bin/activate```. 
#### Extra step for development: ```sh -make install-dev +pip install -e ".[dev]" ``` ### Windows Application diff --git a/pyproject.toml b/pyproject.toml index 5d13754c..38a1d8d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ mac = [ gpu = [ "sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python", + # due to transformers version conflict, we need to install mlx-lm and mlx separately + # pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps ] vllm = [ diff --git a/src/parallax/models/kimi_vl.py b/src/parallax/models/kimi_vl.py new file mode 100644 index 00000000..b6c2057f --- /dev/null +++ b/src/parallax/models/kimi_vl.py @@ -0,0 +1,203 @@ +""" +Defines the KimiVL model for Parallax. + +KimiVL uses a DeepSeek-V3 based language model with MoE and a MoonViT vision encoder. +This module reuses components from mlx-vlm and adds PagedAttention support +for distributed inference. +""" + +from typing import Any, List, Optional + +import mlx.core as mx +from mlx import nn +from mlx_lm.models.base import scaled_dot_product_attention + +# Import from mlx-vlm kimi_vl language module +from mlx_vlm.models.kimi_vl.language import DeepseekV3Attention as MLXKimiVLAttention +from mlx_vlm.models.kimi_vl.language import DeepseekV3DecoderLayer as MLXKimiVLDecoderLayer +from mlx_vlm.models.kimi_vl.language import DeepseekV3MLP, DeepseekV3MoE + +from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache +from parallax.utils.prefix_cache_utils import compute_attention_with_prefix_cache +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +class ParallaxKimiVLAttention(MLXKimiVLAttention): + """KimiVL (DeepSeek-V3) Attention with PagedAttention support for Parallax. 
+ + This extends the MLX-VLM KimiVL attention (DeepseekV3Attention) with: + - Paged KV cache support for efficient memory management + - Block-table based attention for decode phase + - Prefix cache support for prefill phase + """ + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[BaseCache] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + prefix_lens: Optional[mx.array] = None, + **kwargs, + ) -> mx.array: + batch, target_len, _ = x.shape + + # Q projection (with optional LoRA) + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(batch, target_len, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + + # KV projection (with MQA compression) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(batch, target_len, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + + kv = kv.reshape(batch, target_len, self.num_heads, -1) + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + k_nope = k_nope.transpose(0, 2, 1, 3) + + # Get KV cache + key_cache_global, value_cache_global = cache.get_cache() + + # Compute RoPE offsets + if target_len == 1: + # Decode phase: position is context_length - 1 + current_pos = context_lengths - 1 + elif prefix_lens is not None: + # Prefill phase with prefix cache + current_pos = prefix_lens + else: + # Prefill phase without prefix cache + current_pos = 0 + + # Apply RoPE + q_pe = self.rope(q_pe, offset=current_pos) + k_pe = self.rope(k_pe, offset=current_pos) + + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + queries = mx.concatenate([q_nope, 
q_pe], axis=-1) + keys = mx.concatenate([k_nope, k_pe], axis=-1) + + # Cache update with PagedAttention + block_size = key_cache_global.shape[3] + + reshape_and_cache( + keys.transpose(0, 2, 1, 3), + values, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + slot_mapping=slot_mapping, + ) + + if target_len == 1: + # Decode phase: Use Paged Attention + output = paged_attention( + queries, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + self.scale, + self.num_heads, + v_head_dim=values.shape[-1], + ) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + else: + # Prefill phase + has_prefix_cache = prefix_lens is not None and bool(mx.any(prefix_lens > 0)) + + if has_prefix_cache: + k_new = keys + v_new = values.transpose(0, 2, 1, 3) + output = compute_attention_with_prefix_cache( + queries, + k_new, + v_new, + cache, + block_tables, + prefix_lens, + target_len, + self.scale, + self.num_heads, + mask=mask, + ) + else: + # Standard self-attention + if mask is not None: + mask = mx.array(mask, dtype=queries.dtype) + + output = scaled_dot_product_attention( + queries, + keys, + values.transpose(0, 2, 1, 3), + scale=self.scale, + mask=mask, + cache=None, + ) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + + return self.o_proj(output) + + +class ParallaxKimiVLBlock(MLXKimiVLDecoderLayer): + """KimiVL Transformer block with PagedAttention support. + + Extends the MLX-VLM KimiVL decoder layer to use ParallaxKimiVLAttention + and pass through paged attention arguments. 
+ """ + + def __init__(self, args, layer_idx: int, local_layer_idx: int): + super().__init__(args, layer_idx=layer_idx) + # Replace attention with Parallax version + self.self_attn = ParallaxKimiVLAttention(args) + self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[List[Any]] = None, + lengths: Optional[mx.array] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + **kwargs, + ): + r = self.self_attn( + self.input_layernorm(x), + mask, + cache[self.local_layer_idx], + block_tables=block_tables, + context_lengths=context_lengths, + slot_mapping=slot_mapping, + **kwargs, + ) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "KimiVLForConditionalGeneration" + + +EntryClass = ParallaxKimiVLBlock diff --git a/src/parallax/server/model.py b/src/parallax/server/model.py index aee512f3..caa349e0 100644 --- a/src/parallax/server/model.py +++ b/src/parallax/server/model.py @@ -121,7 +121,20 @@ def __init__( # Some VLMs (e.g., Qwen2-VL, Qwen3-VL) have the projector/merger built into VisionModel # In these cases, multi_modal_projector_class can be None if multi_modal_projector_class is not None: - self.multi_modal_projector = multi_modal_projector_class(config) + # Some projectors (e.g., KimiVL) need both vision_config and text_config + # Create a combined config object if the projector expects it + try: + self.multi_modal_projector = multi_modal_projector_class(config) + except (TypeError, AttributeError): + # Projector expects a combined config with vision_config + text_config + combined_config = type("CombinedConfig", (), { + "vision_config": self.vision_config, + "text_config": config, + })() + self.multi_modal_projector = 
multi_modal_projector_class(combined_config) + logger.info( + "Initialized projector with combined vision+text config" + ) else: self.multi_modal_projector = None logger.info( @@ -220,9 +233,11 @@ def _encode_images( if self.vision_tower is None: raise ValueError("Vision tower not initialized for this model") - # Check if this is a Qwen-VL style model (needs grid_thw) + # Check if this is a model that uses grid_thw for vision encoding model_type = getattr(self.vision_config, "model_type", "") if self.vision_config else "" is_qwen_vl = "qwen" in model_type.lower() and "vl" in model_type.lower() + is_moonvit = model_type.lower() == "moonvit" + uses_grid_thw = is_qwen_vl or is_moonvit # Ensure correct dtype if hasattr(self.vision_tower, "patch_embed") and hasattr( @@ -234,13 +249,26 @@ def _encode_images( pixel_values = pixel_values.astype(self.dtype) # Get vision features from vision tower - if is_qwen_vl and image_grid_thw is not None: - # Qwen-VL style: VisionModel(pixel_values, grid_thw) -> (hidden_states, deepstack_features) - # No format conversion needed - Qwen-VL expects flat patches - vision_outputs = self.vision_tower(pixel_values, image_grid_thw) + if uses_grid_thw and image_grid_thw is not None: + if is_moonvit: + # KimiVL/MoonViT style: VisionModel expects NHWC input and grid_thw + # pixel_values may be NCHW from processor, convert to NHWC + if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]: + pixel_values = pixel_values.transpose(0, 2, 3, 1) + vision_outputs = self.vision_tower( + pixel_values, grid_thw=image_grid_thw, output_hidden_states=True + ) + else: + # Qwen-VL style: VisionModel(pixel_values, grid_thw) -> hidden_states + # No format conversion needed - Qwen-VL expects flat patches + vision_outputs = self.vision_tower(pixel_values, image_grid_thw) + if isinstance(vision_outputs, tuple): # First element is the merged hidden states (already projected by merger) selected_features = vision_outputs[0] + elif isinstance(vision_outputs, 
list): + # KimiVL patch_merger returns a list of arrays + selected_features = vision_outputs else: selected_features = vision_outputs else: diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 1b26f654..e820df02 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -35,6 +35,7 @@ "qwen3_vl": "qwen3", "qwen2_vl": "qwen2", "qwen2_5_vl": "qwen2", + "kimi_vl": "deepseek_v3", } # VLM models that need special handling (have separate projector class) @@ -43,6 +44,7 @@ VLM_SPECIAL_PROJECTOR_MAP = { "llava": ("mlx_vlm.models.llava.llava", "LlavaMultiModalProjector"), "llava_next": ("mlx_vlm.models.llava_next.llava_next", "LlavaMultiModalProjector"), + "kimi_vl": ("mlx_vlm.models.kimi_vl.kimi_vl", "KimiVLMultiModalProjector"), } @@ -373,7 +375,11 @@ def load( is_vlm = vision_config is not None and vision_tower_class is not None # Get VLM-specific config parameters - image_token_index = config.get("image_token_index") or config.get("image_token_id") + image_token_index = ( + config.get("image_token_index") + or config.get("image_token_id") + or config.get("media_placeholder_token_id") # KimiVL uses this name + ) vision_feature_layer = config.get("vision_feature_layer", -2) vision_feature_select_strategy = config.get("vision_feature_select_strategy", "default") From f353ff7a2e1a26fafdfa5b24dcca2187a1876631 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Fri, 6 Feb 2026 17:03:17 +0800 Subject: [PATCH 23/36] update pyprojection --- docs/user_guide/install.md | 2 +- pyproject.toml | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/user_guide/install.md b/docs/user_guide/install.md index 7a671ec8..ecfbf16d 100644 --- a/docs/user_guide/install.md +++ b/docs/user_guide/install.md @@ -19,7 +19,7 @@ Note: If you are using DGX Spark, please refer to the Docker installation sectio ```sh git clone https://github.com/GradientHQ/parallax.git cd parallax -pip install -e ".[gpu]" +pip install 
-e ".[gpu]" && pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps ``` #### For macOS (Apple silicon): diff --git a/pyproject.toml b/pyproject.toml index 38a1d8d8..1ba0f035 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ mac = [ gpu = [ "sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python", + "mlx-lm==0.28.4", + "mlx[cpu]==0.30.0", # due to transformers version conflict, we need to install mlx-lm and mlx separately # pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps ] From 2dea44bd07958cf215dbf4e009401df4ecfbb7b9 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Fri, 6 Feb 2026 19:15:18 +0800 Subject: [PATCH 24/36] fix pre commit --- src/parallax/models/kimi_vl.py | 6 +- src/parallax/server/executor/factory.py | 13 +-- src/parallax/server/executor/mlx_executor.py | 12 --- src/parallax/server/model.py | 96 ++++--------------- src/parallax/server/request.py | 23 +---- src/parallax/server/shard_loader.py | 38 +------- src/parallax/sglang/model_runner.py | 4 +- .../weight_loader_filter.py | 4 +- src/parallax/sglang/multimodal_utils.py | 28 +++--- src/parallax/utils/weight_filter_utils.py | 14 +-- src/parallax/vllm/model_runner.py | 1 + src/parallax_utils/logging_config.py | 6 +- tests/test_tool_call.py | 47 +++++---- 13 files changed, 91 insertions(+), 201 deletions(-) diff --git a/src/parallax/models/kimi_vl.py b/src/parallax/models/kimi_vl.py index b6c2057f..f17434e5 100644 --- a/src/parallax/models/kimi_vl.py +++ b/src/parallax/models/kimi_vl.py @@ -9,13 +9,13 @@ from typing import Any, List, Optional import mlx.core as mx -from mlx import nn from mlx_lm.models.base import scaled_dot_product_attention # Import from mlx-vlm kimi_vl language module from mlx_vlm.models.kimi_vl.language import DeepseekV3Attention as MLXKimiVLAttention -from mlx_vlm.models.kimi_vl.language import DeepseekV3DecoderLayer as MLXKimiVLDecoderLayer -from mlx_vlm.models.kimi_vl.language import 
DeepseekV3MLP, DeepseekV3MoE +from mlx_vlm.models.kimi_vl.language import ( + DeepseekV3DecoderLayer as MLXKimiVLDecoderLayer, +) from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache from parallax.server.cache.base import BaseCache diff --git a/src/parallax/server/executor/factory.py b/src/parallax/server/executor/factory.py index ac8c5418..8209d9fd 100755 --- a/src/parallax/server/executor/factory.py +++ b/src/parallax/server/executor/factory.py @@ -111,20 +111,21 @@ def run_executor_process(args, shared_state=None, conn=None): """Run executor as a subprocess""" # Set rank to suppress logs on non-zero ranks # Must be called AFTER set_log_level to override the level - tp_rank = getattr(args, 'tp_rank', 0) - tp_size = getattr(args, 'tp_size', 1) - + tp_rank = getattr(args, "tp_rank", 0) + tp_size = getattr(args, "tp_size", 1) + # For non-zero ranks, suppress logs before any imports if tp_size > 1 and tp_rank != 0: import logging + logging.getLogger().setLevel(logging.CRITICAL + 1) - + set_log_level(args.log_level) - + # Now set rank properly (will re-suppress for non-zero ranks) if tp_size > 1: set_rank(tp_rank, enable_filter=True) - + executor = None try: executor = create_from_args(args, shared_state, conn) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index fc19a7dc..4090cdb8 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -159,18 +159,6 @@ def __init__( self.vlm_processor = None except Exception as e: logger.debug(f"AutoProcessor failed: {e}") - if not processor_loaded: - try: - # Must import torch first to avoid flex_attention import errors in transformers - from transformers import Qwen2VLProcessor - - self.vlm_processor = Qwen2VLProcessor.from_pretrained( - processor_path, trust_remote_code=True - ) - logger.info(f"Loaded VLM processor (Qwen2VLProcessor) for {self.model_type}") - processor_loaded = True - 
except Exception as e: - logger.debug(f"Qwen2VLProcessor failed: {e}") if not processor_loaded: logger.warning( diff --git a/src/parallax/server/model.py b/src/parallax/server/model.py index caa349e0..06d8d970 100644 --- a/src/parallax/server/model.py +++ b/src/parallax/server/model.py @@ -112,29 +112,25 @@ def __init__( if has_norm_in: self.norm_in = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps) - # Initialize vision components for VLM on first shard if self.is_vlm: logger.info( f"Initializing VLM components: vision_tower ({self.vision_config.model_type})" ) self.vision_tower = vision_tower_class(self.vision_config) - # Some VLMs (e.g., Qwen2-VL, Qwen3-VL) have the projector/merger built into VisionModel - # In these cases, multi_modal_projector_class can be None if multi_modal_projector_class is not None: - # Some projectors (e.g., KimiVL) need both vision_config and text_config - # Create a combined config object if the projector expects it try: self.multi_modal_projector = multi_modal_projector_class(config) except (TypeError, AttributeError): - # Projector expects a combined config with vision_config + text_config - combined_config = type("CombinedConfig", (), { - "vision_config": self.vision_config, - "text_config": config, - })() + combined_config = type( + "CombinedConfig", + (), + { + "vision_config": self.vision_config, + "text_config": config, + }, + )() self.multi_modal_projector = multi_modal_projector_class(combined_config) - logger.info( - "Initialized projector with combined vision+text config" - ) + logger.info("Initialized projector with combined vision+text config") else: self.multi_modal_projector = None logger.info( @@ -197,17 +193,10 @@ def get_input_embeddings( if not self.is_first_shard: raise ValueError("get_input_embeddings should only be called on the first shard") - # Get text embeddings inputs_embeds = self.embed_tokens(input_ids) - - # If no images or not a VLM, return text embeddings directly if pixel_values is None or not 
self.is_vlm: return InputEmbeddingsOutput(inputs_embeds=inputs_embeds) - - # Process vision features image_features = self._encode_images(pixel_values, **kwargs) - - # Merge image features with text embeddings final_embeds = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids ) @@ -220,20 +209,10 @@ def _encode_images( image_grid_thw: Optional[mx.array] = None, **kwargs, ) -> mx.array: - """Encode images through vision tower and projector. - - Args: - pixel_values: Image tensor, typically (batch, C, H, W) or (num_patches, C, H, W) - image_grid_thw: Grid size (T, H, W) for Qwen-VL models - **kwargs: Additional model-specific arguments - - Returns: - Projected image features ready to be merged with text embeddings - """ + """Encode images through vision tower and projector.""" if self.vision_tower is None: raise ValueError("Vision tower not initialized for this model") - # Check if this is a model that uses grid_thw for vision encoding model_type = getattr(self.vision_config, "model_type", "") if self.vision_config else "" is_qwen_vl = "qwen" in model_type.lower() and "vl" in model_type.lower() is_moonvit = model_type.lower() == "moonvit" @@ -243,74 +222,57 @@ def _encode_images( if hasattr(self.vision_tower, "patch_embed") and hasattr( self.vision_tower.patch_embed, "proj" ): - target_dtype = self.vision_tower.patch_embed.proj.weight.dtype - pixel_values = pixel_values.astype(target_dtype) + pixel_values = pixel_values.astype(self.vision_tower.patch_embed.proj.weight.dtype) else: pixel_values = pixel_values.astype(self.dtype) - # Get vision features from vision tower if uses_grid_thw and image_grid_thw is not None: if is_moonvit: - # KimiVL/MoonViT style: VisionModel expects NHWC input and grid_thw - # pixel_values may be NCHW from processor, convert to NHWC + # MoonViT (KimiVL) expects NHWC input if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]: pixel_values = pixel_values.transpose(0, 2, 3, 1) vision_outputs = 
self.vision_tower( pixel_values, grid_thw=image_grid_thw, output_hidden_states=True ) else: - # Qwen-VL style: VisionModel(pixel_values, grid_thw) -> hidden_states - # No format conversion needed - Qwen-VL expects flat patches + # Qwen-VL expects flat patches vision_outputs = self.vision_tower(pixel_values, image_grid_thw) if isinstance(vision_outputs, tuple): - # First element is the merged hidden states (already projected by merger) selected_features = vision_outputs[0] elif isinstance(vision_outputs, list): - # KimiVL patch_merger returns a list of arrays selected_features = vision_outputs else: selected_features = vision_outputs else: - # Standard CLIP/SigLIP style - # Convert to vision tower expected format (typically NHWC for MLX) + # CLIP/SigLIP style: NCHW -> NHWC if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]: - # NCHW -> NHWC pixel_values = pixel_values.transpose(0, 2, 3, 1) vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True) - # Handle different output formats if isinstance(vision_outputs, tuple): - # CLIP/SigLIP style: (pooler_output, last_hidden_state, hidden_states) if len(vision_outputs) >= 3: - hidden_states = vision_outputs[2] # All hidden states + hidden_states = vision_outputs[2] if isinstance(self.vision_feature_layer, int): selected_features = hidden_states[self.vision_feature_layer] if self.vision_feature_select_strategy == "default": - # Remove CLS token selected_features = selected_features[:, 1:] else: - # Multiple layers hs_pool = [hidden_states[idx] for idx in self.vision_feature_layer] if self.vision_feature_select_strategy == "default": hs_pool = [hs[:, 1:] for hs in hs_pool] selected_features = mx.concatenate(hs_pool, axis=-1) else: - # Simple (pooler, hidden_state) output selected_features = vision_outputs[1] if self.vision_feature_select_strategy == "default": selected_features = selected_features[:, 1:] else: - # Direct hidden state output selected_features = vision_outputs - # Project to 
language model dimension if projector exists - # Qwen-VL models have projection built into VisionModel's merger if self.multi_modal_projector is not None: image_features = self.multi_modal_projector(selected_features) else: - # VisionModel already outputs projected features image_features = selected_features return image_features @@ -321,36 +283,19 @@ def _merge_input_ids_with_image_features( inputs_embeds: mx.array, input_ids: mx.array, ) -> mx.array: - """Merge image features into input embeddings at image token positions. - - This replaces placeholder tokens with actual image feature embeddings. - - Args: - image_features: (num_images, num_patches, hidden_dim) or (total_patches, hidden_dim) - inputs_embeds: (batch, seq_len, hidden_dim) Text embeddings - input_ids: (batch, seq_len) Token IDs for finding image positions - - Returns: - Merged embeddings with image features inserted at image token positions - """ + """Replace placeholder tokens with actual image feature embeddings.""" if self.image_token_index is None: logger.warning("image_token_index not set, cannot merge image features") return inputs_embeds batch_size, seq_len, hidden_dim = inputs_embeds.shape - - # Find positions of image tokens image_positions = input_ids == self.image_token_index - # Flatten image features if needed if image_features.ndim == 3: - # (num_images, num_patches, dim) -> (total_patches, dim) image_features = image_features.reshape(-1, image_features.shape[-1]) - # Cast image features to match embedding dtype image_features = image_features.astype(inputs_embeds.dtype) - # Process each batch item batch_outputs = [] feature_start_idx = 0 @@ -359,28 +304,23 @@ def _merge_input_ids_with_image_features( num_positions = int(mx.sum(batch_mask).item()) if num_positions > 0: - # Extract features for this batch batch_features = image_features[ feature_start_idx : feature_start_idx + num_positions ] if batch_features.shape[0] != num_positions: raise ValueError( - f"Number of image token 
positions ({num_positions}) does not match " - f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}" + f"Image token positions ({num_positions}) does not match " + f"image features ({batch_features.shape[0]}) for batch {batch_idx}" ) - # Create indices for gathering cumsum = mx.cumsum(batch_mask.astype(mx.int32)) feature_indices = mx.where(batch_mask, cumsum - 1, 0) - - # Gather features and create merged output gathered_features = batch_features[feature_indices] batch_mask_expanded = mx.expand_dims(batch_mask, axis=-1) batch_output = mx.where( batch_mask_expanded, gathered_features, inputs_embeds[batch_idx] ) - feature_start_idx += num_positions else: batch_output = inputs_embeds[batch_idx] diff --git a/src/parallax/server/request.py b/src/parallax/server/request.py index 2ecce2aa..434163bc 100644 --- a/src/parallax/server/request.py +++ b/src/parallax/server/request.py @@ -78,27 +78,10 @@ class VLMInputs: receive pre-computed image embeddings merged into hidden_states. 
""" - # Preprocessed image tensor, shape varies by model: - # - LLaVA: (num_images, C, H, W) or (num_patches, C, patch_H, patch_W) - # - Qwen-VL: (num_patches, C, patch_H, patch_W) with temporal dim for video - # Can be numpy array or PyTorch tensor - mx.array() can convert both pixel_values: Optional[Any] = None - - # For models with dynamic resolution (e.g., Qwen2-VL): - # Tuple of (temporal, height, width) grid sizes for each image - # Shape: (num_images, 3) where each row is (t, h, w) - # Can be numpy array or PyTorch tensor image_grid_thw: Optional[Any] = None - - # Number of image tokens per image (for variable-length image tokens) image_token_counts: Optional[List[int]] = None - - # Original image sizes before preprocessing (height, width) - # Useful for models that need aspect ratio information image_sizes: Optional[List[tuple]] = None - - # Whether images have been processed into embeddings - # (set to True after first peer processes images) images_processed: bool = False def has_images(self) -> bool: @@ -410,16 +393,14 @@ def from_initial_request( else: next_token_id = initial_request.output_ids[-1] - # For VLM: after first peer processes images, mark as processed - # and don't pass pixel_values to subsequent peers (only metadata) vlm_inputs = None if initial_request.vlm_inputs is not None: vlm_inputs = VLMInputs( - pixel_values=None, # Don't pass raw pixels to next peers + pixel_values=None, image_grid_thw=initial_request.vlm_inputs.image_grid_thw, image_token_counts=initial_request.vlm_inputs.image_token_counts, image_sizes=initial_request.vlm_inputs.image_sizes, - images_processed=True, # Mark as processed by first peer + images_processed=True, ) return IntermediateRequest( diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index e820df02..ca359d24 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -29,8 +29,7 @@ "kimi_k2": "mlx_lm.models.deepseek_v3", } -# VLM models that 
need to use text_config for ModelArgs -# Format: model_type -> base_model_type (for loading ModelArgs from mlx_lm) + VLM_TEXT_CONFIG_MAP = { "qwen3_vl": "qwen3", "qwen2_vl": "qwen2", @@ -38,9 +37,6 @@ "kimi_vl": "deepseek_v3", } -# VLM models that need special handling (have separate projector class) -# Format: model_type -> (projector_module_path, projector_class_name) -# Default: VisionModel from mlx_vlm.models.{model_type}, no separate projector VLM_SPECIAL_PROJECTOR_MAP = { "llava": ("mlx_vlm.models.llava.llava", "LlavaMultiModalProjector"), "llava_next": ("mlx_vlm.models.llava_next.llava_next", "LlavaMultiModalProjector"), @@ -67,12 +63,10 @@ def _get_vlm_classes( return None, None, None try: - # Default: load VisionModel from mlx_vlm.models.{model_type} vision_module_path = f"mlx_vlm.models.{model_type}" vision_module = importlib.import_module(vision_module_path) vision_tower_class = getattr(vision_module, "VisionModel") - # Check if this model needs a separate projector projector_class = None if model_type in VLM_SPECIAL_PROJECTOR_MAP: proj_module_path, proj_class_name = VLM_SPECIAL_PROJECTOR_MAP[model_type] @@ -324,17 +318,13 @@ def load( if not model_type: raise ValueError("model_type not found in config.json") - # For VLM models, use text_config for ModelArgs and map to base model type config_for_args = config model_class_type = model_type if model_type in VLM_TEXT_CONFIG_MAP: - # VLM models have text_config containing the language model config text_config = config.get("text_config", {}) if text_config: - # Merge text_config into a flat config for ModelArgs config_for_args = {**config, **text_config} - # Also get num_hidden_layers from text_config if not in root if "num_hidden_layers" not in config and "num_hidden_layers" in text_config: config["num_hidden_layers"] = text_config["num_hidden_layers"] model_class_type = VLM_TEXT_CONFIG_MAP[model_type] @@ -370,15 +360,13 @@ def load( else: # If it's already a clean name or a local path (take basename) 
model_id = pathlib.Path(model_id).name - # Check for VLM model and get vision classes vision_tower_class, projector_class, vision_config = _get_vlm_classes(model_type, config) is_vlm = vision_config is not None and vision_tower_class is not None - # Get VLM-specific config parameters image_token_index = ( config.get("image_token_index") or config.get("image_token_id") - or config.get("media_placeholder_token_id") # KimiVL uses this name + or config.get("media_placeholder_token_id") ) vision_feature_layer = config.get("vision_feature_layer", -2) vision_feature_select_strategy = config.get("vision_feature_select_strategy", "default") @@ -436,18 +424,12 @@ def load( # loading only what we need. shard_weights = {} - # Layer key prefixes to check (in order of priority) - # Different model formats use different key prefixes: - # - language_model.model.layers.X (mlx-vlm converted format) - # - model.language_model.layers.X (HuggingFace VLM format) - # - model.layers.X (Standard LLM format) layer_key_prefixes = [ ("language_model.model.layers.", 3), # mlx-vlm style: parts[3] is layer index ("model.language_model.layers.", 3), # HF VLM style: parts[3] is layer index ("model.layers.", 2), # Standard style: parts[2] is layer index ] - # VLM weight prefixes to load on first shard vlm_weight_prefixes = [ "vision_tower.", "vision_model.", @@ -456,7 +438,6 @@ def load( "mm_projector.", ] - # Get tie_word_embeddings config (check both root and text_config for VLM) tie_word_embeddings = get_config_value(config, "tie_word_embeddings", False) for file_idx, wf in enumerate(weight_files): @@ -469,16 +450,9 @@ def load( is_needed = False remapped_key = None - # Check if the key belongs to the shard and remap it - # Embeddings: Various formats: - # - language_model.model.embed_tokens.* (mlx-vlm converted) - # - model.language_model.embed_tokens.* (HF VLM) - # - model.embed_tokens.* (standard) if model_shard.is_first_shard and "embed_tokens" in key: is_needed = True - # Remap to just 
embed_tokens.* if "language_model.model.embed_tokens" in key: - # mlx-vlm format: language_model.model.embed_tokens.weight -> embed_tokens.weight remapped_key = key.replace("language_model.model.", "") elif "language_model.embed_tokens" in key: remapped_key = key.split("language_model.")[-1] @@ -491,10 +465,6 @@ def load( shard_weights[lm_head_key] = f[key] elif model_shard.is_last_shard: - # Final norm: Various formats - # - language_model.model.norm.* (mlx-vlm converted) - # - model.language_model.norm.* (HF VLM) - # - model.norm.* (standard) if ".norm." in key or key.endswith(".norm.weight"): is_final_norm = ( "language_model.model.norm" in key @@ -511,7 +481,6 @@ def load( remapped_key = key.replace("model.", "", 1) if "lm_head" in key: is_needed = True - # Handle language_model.lm_head.* format if key.startswith("language_model."): remapped_key = key.replace("language_model.", "") else: @@ -538,14 +507,11 @@ def load( is_needed = True remapped_key = key break - # Handle model.vision_tower.*, model.visual.* style keys if key.startswith(f"model.{prefix}"): is_needed = True - # Keep as vision_tower.* or visual.* (remove model. 
prefix) remapped_key = key.replace("model.", "", 1) break - # Check layer keys with multiple prefix patterns if not is_needed: for layer_prefix, layer_idx_pos in layer_key_prefixes: if layer_prefix in key: diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index fd5a6cb9..f0edc61c 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -80,7 +80,9 @@ def __init__( if num_hidden_layers is None: raise ValueError("num_hidden_layers is required but not found in model config") is_vlm = config_accessor.is_vlm - set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers, is_vlm=is_vlm) + set_layer_range_for_filtering( + pp_start_layer, pp_end_layer, num_hidden_layers, is_vlm=is_vlm + ) super().__init__( model_config=model_config, diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index f1d9e318..261bbc20 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -12,8 +12,8 @@ def set_layer_range_for_filtering( - pp_start_layer: int, - pp_end_layer: int, + pp_start_layer: int, + pp_end_layer: int, num_hidden_layers: int, is_vlm: bool = False, ): diff --git a/src/parallax/sglang/multimodal_utils.py b/src/parallax/sglang/multimodal_utils.py index 2cccc1dc..8bc28e04 100644 --- a/src/parallax/sglang/multimodal_utils.py +++ b/src/parallax/sglang/multimodal_utils.py @@ -73,12 +73,12 @@ def process_images( # Kimi K2.5 requires special handling: # 1. Expand image tokens based on actual image size # 2. 
Use 'medias' parameter instead of 'images' - + # The image token for Kimi K2.5 is <|media_pad|> image_token = "<|media_pad|>" - + # Expand single image tokens to the correct number of tokens - if image_token in input_text and hasattr(processor, 'media_processor'): + if image_token in input_text and hasattr(processor, "media_processor"): parts = input_text.split(image_token) result = [parts[0]] for i, (image, part) in enumerate(zip(images, parts[1:])): @@ -89,12 +89,14 @@ def process_images( ) logger.debug(f"Kimi K2.5: Image {i} expanded to {num_tokens} tokens") except Exception as e: - logger.warning(f"Failed to calculate media tokens for image {i}: {e}, using 1") + logger.warning( + f"Failed to calculate media tokens for image {i}: {e}, using 1" + ) num_tokens = 1 result.append(image_token * num_tokens + part) input_text = "".join(result) logger.debug(f"Kimi K2.5: Expanded input_text length: {len(input_text)}") - + # Kimi K2.5 requires 'medias' parameter with specific format medias = [{"type": "image", "image": img} for img in images] inputs = processor(medias=medias, text=input_text, return_tensors="pt") @@ -109,9 +111,9 @@ def process_images( # Debug: Log all keys returned by the processor logger.debug(f"Processor output keys: {list(inputs.keys())}") for key, value in inputs.items(): - if hasattr(value, 'shape'): + if hasattr(value, "shape"): logger.debug(f" {key}: shape={value.shape}, dtype={value.dtype}") - elif hasattr(value, '__len__'): + elif hasattr(value, "__len__"): logger.debug(f" {key}: len={len(value)}, type={type(value)}") else: logger.debug(f" {key}: {type(value)}") @@ -120,9 +122,13 @@ def process_images( if pixel_values is None: logger.error("Processor output missing pixel_values") return [], None, [] - - logger.debug(f"pixel_values shape: {pixel_values.shape}, dtype: {pixel_values.dtype}, device: {pixel_values.device}") - logger.debug(f"pixel_values stats: min={pixel_values.min().item():.4f}, max={pixel_values.max().item():.4f}, 
mean={pixel_values.mean().item():.4f}") + + logger.debug( + f"pixel_values shape: {pixel_values.shape}, dtype: {pixel_values.dtype}, device: {pixel_values.device}" + ) + logger.debug( + f"pixel_values stats: min={pixel_values.min().item():.4f}, max={pixel_values.max().item():.4f}, mean={pixel_values.mean().item():.4f}" + ) expanded_input_ids = inputs.get("input_ids") if expanded_input_ids is not None: @@ -136,7 +142,7 @@ def process_images( is_kimi_k25 = processor_class_name == "KimiK25Processor" image_grid_thw = inputs.get("image_grid_thw") or inputs.get("grid_thws") image_sizes = inputs.get("image_sizes") - + logger.debug(f"image_grid_thw: {image_grid_thw}, is_kimi_k25: {is_kimi_k25}") # Determine the correct field name for grid data diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py index b0265498..f01a7bc3 100644 --- a/src/parallax/utils/weight_filter_utils.py +++ b/src/parallax/utils/weight_filter_utils.py @@ -22,11 +22,6 @@ def should_include_weight_key( if is_first_shard and "embed_tokens" in key: return True - # VLM: Include vision components on first shard - # Handles various naming conventions: - # - vision_tower.*, model.vision_tower.* - # - model.visual.*, visual.* (Qwen-VL style) - # - multi_modal_projector.*, mm_projector.* if is_first_shard and is_vlm: vlm_prefixes = [ "vision_tower", @@ -39,10 +34,7 @@ def should_include_weight_key( if key.startswith(prefix) or key.startswith(f"model.{prefix}"): return True - # Final norm and lm_head on last shard - # Handles: model.norm, model.language_model.norm, language_model.model.norm, lm_head if is_last_shard: - # Check for final norm (not layer norms) if ("lm_head" in key) or ( (".norm." 
in key or key.endswith(".norm.weight")) and "layers" not in key ): @@ -50,8 +42,6 @@ def should_include_weight_key( if tie_word_embeddings and "embed_tokens" in key: return True - # Transformer layers - check layer index - # Handles: model.layers.X, model.language_model.layers.X, language_model.model.layers.X if "layers." in key: parts = key.split(".") for i, part in enumerate(parts): @@ -193,7 +183,9 @@ def determine_needed_weight_files_for_download( logger.debug("weight_map is empty in index file") return [] - tie_word_embeddings = get_config_value(config, "tie_word_embeddings", False) if config else False + tie_word_embeddings = ( + get_config_value(config, "tie_word_embeddings", False) if config else False + ) needed_files: Set[str] = set() diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index a713328a..e210cfcc 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -409,6 +409,7 @@ def initialize_vllm_model_runner( # local weight file reads when loading a partial layer shard. try: from parallax.utils.config_utils import is_vlm_model + is_vlm = is_vlm_model(config) set_layer_range_for_filtering(start_layer, end_layer, num_hidden_layers, is_vlm=is_vlm) apply_weight_loader_filter_patch() diff --git a/src/parallax_utils/logging_config.py b/src/parallax_utils/logging_config.py index 81962a43..ba1c1fd8 100644 --- a/src/parallax_utils/logging_config.py +++ b/src/parallax_utils/logging_config.py @@ -145,11 +145,11 @@ def use_parallax_log_handler(for_root: bool = True): def set_rank(rank: int, enable_filter: bool = True): """ Set the current process rank for log filtering. - + When rank filtering is enabled, only rank 0 will print logs. This is useful for multi-GPU (TP/DP) scenarios where you want to avoid duplicate log messages from all processes. - + Args: rank: The rank of the current process (0 for master). enable_filter: If True, only rank 0 will print logs. 
@@ -157,7 +157,7 @@ def set_rank(rank: int, enable_filter: bool = True): global _current_rank, _rank_filter_enabled _current_rank = rank _rank_filter_enabled = enable_filter - + if enable_filter and rank != 0: # Disable logging for non-zero ranks by setting to a high level logging.getLogger().setLevel(logging.CRITICAL + 1) diff --git a/tests/test_tool_call.py b/tests/test_tool_call.py index fbec2c82..b7bbbc33 100644 --- a/tests/test_tool_call.py +++ b/tests/test_tool_call.py @@ -32,6 +32,7 @@ def make_tool_state( # ---- Parsers mimicking common model formats ---- + def json_tool_parser(tool_text: str, tools): """Parser for JSON-formatted tool calls (e.g. Qwen, GLM).""" parsed = json.loads(tool_text) @@ -72,11 +73,13 @@ def multi_tool_parser(tool_text: str, tools): if not part.endswith("}"): part = part + "}" parsed = json.loads(part) - calls.append({ - "id": parsed.get("id", "call_0"), - "name": parsed["name"], - "arguments": parsed["arguments"], - }) + calls.append( + { + "id": parsed.get("id", "call_0"), + "name": parsed["name"], + "arguments": parsed["arguments"], + } + ) return calls @@ -170,7 +173,9 @@ def test_json_tool_call(self): tools=SAMPLE_TOOLS, ) - segment = '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + segment = ( + '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + ) text, tool_calls = state.extract_from_segment(segment) self.assertEqual(text, "") @@ -193,12 +198,12 @@ def test_xml_tool_call(self): ) segment = ( - '<|tool_start|>' + "<|tool_start|>" '' '12345' '67890' - '' - '<|tool_end|>' + "" + "<|tool_end|>" ) text, tool_calls = state.extract_from_segment(segment) @@ -267,7 +272,9 @@ def test_inactive_state_passes_through(self): stream=False, ) - segment = '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + segment = ( + '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + ) text, tool_calls = state.extract_from_segment(segment) self.assertEqual(text, segment) @@ -300,9 +307,7 @@ def 
test_tool_call_split_across_segments(self): ) # First segment: start marker + partial content - text1, calls1 = state.extract_from_segment( - '{"name": "get_weather",' - ) + text1, calls1 = state.extract_from_segment('{"name": "get_weather",') self.assertEqual(text1, "") self.assertEqual(calls1, []) self.assertTrue(state.in_tool_call) @@ -366,7 +371,9 @@ def test_stream_mode_index_increments(self): ) # First tool call - segment1 = '{"name": "get_weather", "arguments": {"location": "NYC"}}' + segment1 = ( + '{"name": "get_weather", "arguments": {"location": "NYC"}}' + ) _, calls1 = state.extract_from_segment(segment1) self.assertEqual(calls1[0]["index"], 0) @@ -446,7 +453,9 @@ def test_made_tool_call_flag(self): self.assertFalse(state.made_tool_call) - segment = '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + segment = ( + '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + ) state.extract_from_segment(segment) self.assertTrue(state.made_tool_call) @@ -474,7 +483,9 @@ def test_string_arguments_serialized_to_json(self): tools=SAMPLE_TOOLS, ) - segment = '{"name": "get_weather", "arguments": {"location": "北京"}}' + segment = ( + '{"name": "get_weather", "arguments": {"location": "北京"}}' + ) _, tool_calls = state.extract_from_segment(segment) # arguments should be a JSON string (not a dict) in the output @@ -491,7 +502,9 @@ def test_unicode_in_arguments(self): tools=SAMPLE_TOOLS, ) - segment = '{"name": "get_weather", "arguments": {"location": "東京"}}' + segment = ( + '{"name": "get_weather", "arguments": {"location": "東京"}}' + ) _, tool_calls = state.extract_from_segment(segment) args_str = tool_calls[0]["function"]["arguments"] From ed80339ce9780b47db69a8ceb5061603ed7884cd Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Sat, 7 Feb 2026 19:59:15 +0800 Subject: [PATCH 25/36] update --- docs/user_guide/install.md | 2 +- pyproject.toml | 6 +- src/parallax/server/executor/base_executor.py | 55 +++++++++++++++++++ 
src/parallax/server/scheduler.py | 9 +++ src/parallax/utils/tokenizer_utils.py | 30 ++++++++++ 5 files changed, 98 insertions(+), 4 deletions(-) diff --git a/docs/user_guide/install.md b/docs/user_guide/install.md index ecfbf16d..7c011a95 100644 --- a/docs/user_guide/install.md +++ b/docs/user_guide/install.md @@ -19,7 +19,7 @@ Note: If you are using DGX Spark, please refer to the Docker installation sectio ```sh git clone https://github.com/GradientHQ/parallax.git cd parallax -pip install -e ".[gpu]" && pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps +pip install -e ".[gpu]" && pip install mlx-lm==0.30.6 --no-deps ``` #### For macOS (Apple silicon): diff --git a/pyproject.toml b/pyproject.toml index 1ba0f035..27a44d8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,9 +55,9 @@ mac = [ gpu = [ "sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python", "mlx-lm==0.28.4", - "mlx[cpu]==0.30.0", - # due to transformers version conflict, we need to install mlx-lm and mlx separately - # pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps + "mlx[cpu]==0.30.4", + # due to transformers version conflict, we need to install mlx-lm separately + # pip install mlx-lm==0.30.6 --no-deps ] vllm = [ diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 0c01f3e5..f3c85e71 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -158,6 +158,7 @@ def __init__( self.eos_token_id = self._config_accessor.get_eos_token_id() + self._augment_eos_with_im_end() # Build multimodal config (only meaningful for VLM models) self.mm_config = self._config_accessor.build_mm_config() @@ -628,6 +629,37 @@ def shutdown(self): logger.debug("Executor shutdown complete.") + def _augment_eos_with_im_end(self): + """Add ``<|im_end|>`` to the EOS token list when it is present in the + vocabulary but missing from 
the configured ``eos_token_id``. + + Many chat models (Kimi-K2.5, Qwen, etc.) use ``<|im_end|>`` as the + turn-ending token, yet their ``config.json`` only lists ``[EOS]`` as + the EOS token. Without this augmentation the scheduler will never + detect end-of-turn and generation will run until ``max_tokens``. + """ + _get_vocab = getattr(self.tokenizer, "get_vocab", None) + vocab = _get_vocab() if _get_vocab else {} + im_end_id = vocab.get("<|im_end|>") + if im_end_id is None: + return + + # Normalise eos_token_id to a list for easy comparison + if self.eos_token_id is None: + self.eos_token_id = [im_end_id] + logger.info(f"Set eos_token_id to [{im_end_id}] (<|im_end|>)") + elif isinstance(self.eos_token_id, list): + if im_end_id not in self.eos_token_id: + self.eos_token_id.append(im_end_id) + logger.info(f"Added <|im_end|> (id={im_end_id}) to eos_token_id list") + elif isinstance(self.eos_token_id, int): + if self.eos_token_id != im_end_id: + self.eos_token_id = [self.eos_token_id, im_end_id] + logger.info( + f"Expanded eos_token_id to {self.eos_token_id} " + f"(added <|im_end|> id={im_end_id})" + ) + def _process_text_request(self, rid: str, messages: list, raw_request: Dict) -> list: """Process a text-only request using the tokenizer.""" if self.tokenizer.chat_template: @@ -748,6 +780,29 @@ def _handle_raw_request(self, raw_request: Dict): if "ignore_eos" in raw_sampling_params: sampling_params.ignore_eos = raw_sampling_params["ignore_eos"] + # Also read OpenAI-style top-level sampling parameters as fallback + if "temperature" in raw_request and raw_sampling_params is None: + sampling_params.temperature = raw_request["temperature"] + if sampling_params.temperature == 0.0: + sampling_params.temperature = 1.0 + sampling_params.top_k = 1 + if "top_p" in raw_request and raw_sampling_params is None: + sampling_params.top_p = raw_request["top_p"] + + # When tools are present, add tool-call-related stop token IDs so the + # scheduler halts generation at the tool-call 
boundary instead of + # running until max_tokens. + tools = raw_request.get("tools") + if tools and self.tokenizer is not None: + from parallax.utils.tokenizer_utils import get_tool_call_stop_token_ids + + tool_stop_ids = get_tool_call_stop_token_ids(self.tokenizer) + if tool_stop_ids: + if sampling_params.stop_token_ids is None: + sampling_params.stop_token_ids = set() + sampling_params.stop_token_ids.update(tool_stop_ids) + logger.debug(f"Added tool call stop token IDs for request {rid}: {tool_stop_ids}") + req = InitialRequest( request_id=rid, output_ids=None, diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 69ab5fbd..29b335c4 100644 --- a/src/parallax/server/scheduler.py +++ b/src/parallax/server/scheduler.py @@ -224,6 +224,15 @@ def check_and_update_request_status(self, request: InitialRequest) -> bool: ): request.update_status(RequestStatus.FINISHED_EOS) finished = True + elif ( + not finished + and not request.sampling_params.ignore_eos + and request.sampling_params.stop_token_ids + and last_token_id is not None + and last_token_id in request.sampling_params.stop_token_ids + ): + request.update_status(RequestStatus.FINISHED_EOS) + finished = True elif request.output_length >= request.max_new_tokens: request.update_status(RequestStatus.FINISHED_MAX_LENGTH) finished = True diff --git a/src/parallax/utils/tokenizer_utils.py b/src/parallax/utils/tokenizer_utils.py index 1d5f9d6c..1992516e 100755 --- a/src/parallax/utils/tokenizer_utils.py +++ b/src/parallax/utils/tokenizer_utils.py @@ -128,6 +128,36 @@ def load_tokenizer(model_path, trust_remote_code=True, tokenizer_config_extra=No return _mlx_load_tokenizer(model_path, tokenizer_config_extra=tokenizer_config_extra, **kwargs) +def get_tool_call_stop_token_ids(tokenizer) -> List[int]: + """Return token IDs that should act as *stop tokens* for tool call generation. 
+ + When the model generates one of these tokens the scheduler should treat it + as end-of-sequence so that the HTTP handler can inspect the generated text + and extract tool calls. + + Note: tool call *parsing* (``has_tool_calling``, ``tool_parser``, etc.) is + handled automatically by the updated ``mlx-lm`` ``TokenizerWrapper``. + This function only provides the stop-token IDs that the parallax scheduler + needs to halt generation at tool-call boundaries. + """ + stop_ids: List[int] = [] + _get_vocab = getattr(tokenizer, "get_vocab", None) + vocab = _get_vocab() if _get_vocab else {} + + # Markers whose token IDs should halt generation + markers = [ + "<|tool_calls_section_end|>", # Kimi K2 / K2.5 + "<|im_end|>", # common chat turn-end token + ] + + for marker in markers: + token_id = vocab.get(marker) + if token_id is not None: + stop_ids.append(token_id) + + return list(set(stop_ids)) + + @dataclass class ToolCallState: has_tool_calling: bool From 44c2234f671f273eca1c93ca3bbbab949a4da3ce Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 9 Feb 2026 13:16:15 +0800 Subject: [PATCH 26/36] check kimi2.5 sampling params --- src/parallax/server/http_server.py | 76 ++++++++++++++++++++ src/parallax/server/node_chat_http_server.py | 10 +++ 2 files changed, 86 insertions(+) diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index ad3ca173..f641f9d5 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -46,6 +46,77 @@ logger = get_logger(__name__) +# --------------------------------------------------------------------------- +# Kimi-K2.5 immutable parameter constraints +# These parameters are locked to specific values and the API must reject +# any request that attempts to override them with different values. 
+# See: https://github.com/MoonshotAI/Kimi-Vendor-Verifier +# --------------------------------------------------------------------------- +_KIMI_K25_CONSTRAINTS = { + "think": { + "temperature": 1.0, + "top_p": 0.95, + "presence_penalty": 0, + "frequency_penalty": 0, + "n": 1, + }, + "non_think": { + "temperature": 0.6, + "top_p": 0.95, + "presence_penalty": 0, + "frequency_penalty": 0, + "n": 1, + }, +} + +# Model names that should enforce immutable parameter constraints. +# Add more model name patterns as needed. +_KIMI_K25_MODEL_KEYWORDS = ("kimi-k2.5", "kimi_k2.5", "kimi-k2-5") + + +def _is_kimi_k25_model(model_name: str) -> bool: + """Check if the model name matches Kimi-K2.5.""" + name_lower = model_name.lower() + return any(kw in name_lower for kw in _KIMI_K25_MODEL_KEYWORDS) + + +def _detect_thinking_mode(request_json: dict) -> bool: + """Detect whether the request uses thinking mode.""" + # opensource format: chat_template_kwargs.thinking + ctk = request_json.get("chat_template_kwargs", {}) + if isinstance(ctk, dict) and ctk.get("thinking"): + return True + # kimi SaaS format: thinking.type == "enabled" + thinking = request_json.get("thinking", {}) + if isinstance(thinking, dict) and thinking.get("type") == "enabled": + return True + return False + + +def validate_kimi_k25_params(request_json: dict) -> Optional[str]: + """Validate immutable parameters for Kimi-K2.5 models. + + Returns an error message string if validation fails, or None if OK. 
+ """ + model = request_json.get("model", "") + if not _is_kimi_k25_model(model): + return None # Not a Kimi-K2.5 model, skip + + is_thinking = _detect_thinking_mode(request_json) + mode = "think" if is_thinking else "non_think" + constraints = _KIMI_K25_CONSTRAINTS[mode] + mode_label = "thinking" if is_thinking else "non-thinking" + + for param, expected in constraints.items(): + if param in request_json and request_json[param] != expected: + return ( + f"Invalid parameter for {model} ({mode_label} mode): " + f"'{param}' must be {expected}, got {request_json[param]}" + ) + + return None + + def get_exception_traceback(): """Traceback function to handle asyncio function errors""" etype, value, tb = sys.exc_info() @@ -494,6 +565,11 @@ async def v1_chat_completions(raw_request: fastapi.Request): except Exception as e: return create_error_response("Invalid request body, error: ", str(e)) + # Validate immutable parameter constraints (e.g. Kimi-K2.5) + param_error = validate_kimi_k25_params(request_json) + if param_error is not None: + return create_error_response(param_error, "BadRequestError", HTTPStatus.BAD_REQUEST) + # Check if request_json has "rid", otherwise generate new one request_id = request_json.get("rid") if request_id is None: diff --git a/src/parallax/server/node_chat_http_server.py b/src/parallax/server/node_chat_http_server.py index 6ecbc283..f97a32dc 100644 --- a/src/parallax/server/node_chat_http_server.py +++ b/src/parallax/server/node_chat_http_server.py @@ -15,6 +15,7 @@ from backend.server.rpc_connection_handler import RPCConnectionHandler from parallax_utils.file_util import get_project_root from parallax_utils.logging_config import get_logger +from parallax.server.http_server import validate_kimi_k25_params logger = get_logger(__name__) @@ -43,6 +44,15 @@ async def get_cluster_status(): async def openai_v1_chat_completions(raw_request: Request): """OpenAI v1/chat/complete post function""" request_data = await raw_request.json() + + # Validate 
immutable parameter constraints (e.g. Kimi-K2.5) + param_error = validate_kimi_k25_params(request_data) + if param_error is not None: + return JSONResponse( + content={"error": {"message": param_error, "type": "BadRequestError", "code": 400}}, + status_code=400, + ) + request_id = uuid.uuid4() received_ts = time.time() return await v1_chat_completions(request_data, request_id, received_ts) From 967b4a03b392da98da445efc2375af9f5a6ef8b7 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 9 Feb 2026 15:29:43 +0800 Subject: [PATCH 27/36] modify create time --- src/parallax/server/http_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index f641f9d5..3657c180 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -147,7 +147,7 @@ class HTTPRequestInfo: finish_reason: str = None object: str = "chat.completion" model: str = "default" - create_time: float = 0.0 + create_time: int = 0 update_time: float = 0.0 logprobs: float = None matched_stop: int = None @@ -216,7 +216,7 @@ def create_request(self, request: Dict): return_probs = request.get("return_probs", False) # Check if probs requested chat_object = "chat.completion.chunk" if stream else "chat.completion" detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap) - create_time = time.time() + create_time = int(time.time()) update_time = create_time request_info = HTTPRequestInfo( id=rid, From 7460d8906391abe73a4d274ca37987322ddd5bb1 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 9 Feb 2026 15:34:33 +0800 Subject: [PATCH 28/36] modify maxtokens limit --- src/parallax/server/executor/base_executor.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index f3c85e71..dff8a6ea 100755 --- a/src/parallax/server/executor/base_executor.py +++ 
b/src/parallax/server/executor/base_executor.py @@ -158,7 +158,12 @@ def __init__( self.eos_token_id = self._config_accessor.get_eos_token_id() + # Ensure <|im_end|> is in the EOS token list. Some models (e.g. + # Kimi-K2.5) use <|im_end|> to end assistant turns but only list + # [EOS] in config.json. Without this the scheduler will never + # detect end-of-turn and generation runs until max_tokens. self._augment_eos_with_im_end() + # Build multimodal config (only meaningful for VLM models) self.mm_config = self._config_accessor.build_mm_config() @@ -744,19 +749,22 @@ def _handle_raw_request(self, raw_request: Dict): max_new_tokens = raw_request.get("max_tokens", 2048) input_token_num = len(prompt) if input_token_num + max_new_tokens >= max_seq_len: - logger.warning( - f"Input token length {input_token_num} + max_new_tokens {max_new_tokens} exceeds max_sequence_length {max_seq_len}." - ) - if max_new_tokens > 2048: + available_new_tokens = max_seq_len - input_token_num + if available_new_tokens <= 0: logger.warning( - f"max_new_tokens {max_new_tokens} is too large, reduce to 2048 tokens." + f"Input token length {input_token_num} already exceeds " + f"max_sequence_length {max_seq_len}. " + f"Truncating input to keep last {max_seq_len - 1} tokens." ) - max_new_tokens = 2048 - if input_token_num + max_new_tokens >= max_seq_len: + prompt = prompt[-(max_seq_len - 1) :] + max_new_tokens = 1 + else: logger.warning( - f"Trunc input prompt, keep last {max_seq_len - max_new_tokens} tokens" + f"Input token length {input_token_num} + max_new_tokens {max_new_tokens} " + f"exceeds max_sequence_length {max_seq_len}. " + f"Reducing max_new_tokens to {available_new_tokens}." 
) - prompt = prompt[-(max_seq_len - max_new_tokens) :] + max_new_tokens = available_new_tokens max_total_length = len(prompt) + max_new_tokens logger.debug(f"Final max_new_tokens for request ID {rid}: {max_new_tokens}") @@ -801,7 +809,9 @@ def _handle_raw_request(self, raw_request: Dict): if sampling_params.stop_token_ids is None: sampling_params.stop_token_ids = set() sampling_params.stop_token_ids.update(tool_stop_ids) - logger.debug(f"Added tool call stop token IDs for request {rid}: {tool_stop_ids}") + logger.debug( + f"Added tool call stop token IDs for request {rid}: {tool_stop_ids}" + ) req = InitialRequest( request_id=rid, From 726dfb892fe4245cabcf306c597173e1763467ad Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 9 Feb 2026 16:41:58 +0800 Subject: [PATCH 29/36] fix pre-commit --- src/parallax/server/executor/base_executor.py | 4 +--- src/parallax/server/node_chat_http_server.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index dff8a6ea..08a0fcf8 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -809,9 +809,7 @@ def _handle_raw_request(self, raw_request: Dict): if sampling_params.stop_token_ids is None: sampling_params.stop_token_ids = set() sampling_params.stop_token_ids.update(tool_stop_ids) - logger.debug( - f"Added tool call stop token IDs for request {rid}: {tool_stop_ids}" - ) + logger.debug(f"Added tool call stop token IDs for request {rid}: {tool_stop_ids}") req = InitialRequest( request_id=rid, diff --git a/src/parallax/server/node_chat_http_server.py b/src/parallax/server/node_chat_http_server.py index f97a32dc..6b8b6ea6 100644 --- a/src/parallax/server/node_chat_http_server.py +++ b/src/parallax/server/node_chat_http_server.py @@ -13,9 +13,9 @@ from starlette.datastructures import State from backend.server.rpc_connection_handler import RPCConnectionHandler +from 
parallax.server.http_server import validate_kimi_k25_params from parallax_utils.file_util import get_project_root from parallax_utils.logging_config import get_logger -from parallax.server.http_server import validate_kimi_k25_params logger = get_logger(__name__) From fc2c34942ce944f3a6f3aa081c222662dfac1acd Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 9 Feb 2026 16:57:19 +0800 Subject: [PATCH 30/36] update --- src/parallax/sglang/multimodal_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/parallax/sglang/multimodal_utils.py b/src/parallax/sglang/multimodal_utils.py index 8bc28e04..d604d5a0 100644 --- a/src/parallax/sglang/multimodal_utils.py +++ b/src/parallax/sglang/multimodal_utils.py @@ -140,7 +140,9 @@ def process_images( # Handle different field names: Kimi K2.5 uses 'grid_thws', others use 'image_grid_thw' is_kimi_k25 = processor_class_name == "KimiK25Processor" - image_grid_thw = inputs.get("image_grid_thw") or inputs.get("grid_thws") + image_grid_thw = inputs.get("image_grid_thw") + if image_grid_thw is None: + image_grid_thw = inputs.get("grid_thws") image_sizes = inputs.get("image_sizes") logger.debug(f"image_grid_thw: {image_grid_thw}, is_kimi_k25: {is_kimi_k25}") From 6ae3863a40e07328f97e3089bda561fea6a10ac3 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 9 Feb 2026 18:12:10 +0800 Subject: [PATCH 31/36] modify tokenizer --- src/parallax/utils/tokenizer_utils.py | 46 +++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/src/parallax/utils/tokenizer_utils.py b/src/parallax/utils/tokenizer_utils.py index 1992516e..b6107fc7 100755 --- a/src/parallax/utils/tokenizer_utils.py +++ b/src/parallax/utils/tokenizer_utils.py @@ -3,12 +3,15 @@ """ import json +import logging import uuid from dataclasses import dataclass from functools import partial from json import JSONDecodeError from typing import Any, Callable, Dict, List, Optional, Tuple +logger = logging.getLogger(__name__) + from 
mlx_lm.tokenizer_utils import ( BPEStreamingDetokenizer, NaiveStreamingDetokenizer, @@ -147,7 +150,7 @@ def get_tool_call_stop_token_ids(tokenizer) -> List[int]: # Markers whose token IDs should halt generation markers = [ "<|tool_calls_section_end|>", # Kimi K2 / K2.5 - "<|im_end|>", # common chat turn-end token + "<|im_end|>", # common chat turn-end token ] for marker in markers: @@ -208,14 +211,51 @@ def _format_tool_call(self, tool_call: Dict[str, Any]): self.tool_call_idx += 1 return out + def _get_available_tool_names(self) -> Optional[set]: + """Extract the set of valid function names from the tools list.""" + if not self.tools: + return None + names = set() + for t in self.tools: + try: + names.add(t["function"]["name"]) + except (KeyError, TypeError): + continue + return names if names else None + + def _validate_tool_name(self, tool_call: Dict[str, Any], valid_names: Optional[set]) -> bool: + """Check if a parsed tool call's function name is in the available tools.""" + if valid_names is None: + return True # No tools list to validate against + name = tool_call.get("name", "") + if name not in valid_names: + logger.warning( + f"Tool call rejected: function '{name}' not in available tools {valid_names}" + ) + return False + return True + def _parse_tool_text(self, tool_text: str) -> Tuple[List[Dict[str, Any]], Optional[str]]: try: parsed = self.tool_parser(tool_text, self.tools) except Exception: fallback_text = f"{self.tool_call_start}{tool_text}{self.tool_call_end}" return [], fallback_text + + valid_names = self._get_available_tool_names() + if isinstance(parsed, list): - return [self._format_tool_call(tc) for tc in parsed], None + # Filter out tool calls with hallucinated function names + valid_calls = [tc for tc in parsed if self._validate_tool_name(tc, valid_names)] + if not valid_calls: + fallback_text = f"{self.tool_call_start}{tool_text}{self.tool_call_end}" + return [], fallback_text + return [self._format_tool_call(tc) for tc in valid_calls], 
None + + # Single tool call + if not self._validate_tool_name(parsed, valid_names): + fallback_text = f"{self.tool_call_start}{tool_text}{self.tool_call_end}" + return [], fallback_text return [self._format_tool_call(parsed)], None def extract_from_segment(self, segment: str) -> Tuple[str, List[Dict[str, Any]]]: @@ -233,7 +273,6 @@ def extract_from_segment(self, segment: str) -> Tuple[str, List[Dict[str, Any]]] if start_pos > idx: output_chunks.append(segment[idx:start_pos]) self.in_tool_call = True - self.made_tool_call = True self.tool_text = "" idx = start_pos + len(self.tool_call_start) else: @@ -245,6 +284,7 @@ def extract_from_segment(self, segment: str) -> Tuple[str, List[Dict[str, Any]]] parsed_calls, fallback_text = self._parse_tool_text(self.tool_text) if parsed_calls: new_tool_calls.extend(parsed_calls) + self.made_tool_call = True if fallback_text: output_chunks.append(fallback_text) self.tool_text = "" From 924f256ecdf26da86cc7ddfceee241114db8d516 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 9 Feb 2026 18:31:51 +0800 Subject: [PATCH 32/36] update --- src/parallax/server/executor/sglang_executor.py | 16 ++++++++++++++++ src/parallax/server/http_server.py | 4 ++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/executor/sglang_executor.py b/src/parallax/server/executor/sglang_executor.py index a8584bb8..4a7a4e01 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -547,6 +547,15 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A # Pre-check: Verify KV cache has enough space for prefill total_tokens_needed = sum(req.total_length for req in batched_requests) + try: + available = self.model_runner.token_to_kv_pool_allocator.available_size() + logger.debug( + f"[KV Cache Prefill] available={available}, needed={total_tokens_needed}, " + f"batch_size={batch_size}, " + f"req_lengths=[{', '.join(str(r.total_length) for r in 
batched_requests)}]" + ) + except Exception: + pass if not self._check_kv_cache_available(total_tokens_needed): self._abort_requests_due_to_kv_cache( batched_requests, @@ -629,6 +638,13 @@ def _prepare_decode_batch(self, batched_requests: List[Request]) -> Optional[Dic # Pre-check: Verify KV cache has enough space for decode (1 token per request) tokens_needed = batch_size + try: + available = self.model_runner.token_to_kv_pool_allocator.available_size() + logger.debug( + f"[KV Cache Decode] available={available}, needed={tokens_needed}, batch_size={batch_size}" + ) + except Exception: + pass if not self._check_kv_cache_available(tokens_needed): self._abort_requests_due_to_kv_cache( batched_requests, diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index 3657c180..90a392a7 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -492,7 +492,7 @@ async def _handle_loop(self): if is_finished: if recv_dict.get("abort", False): logger.warning(f"Request {rid} finished with abort") - request_info.finish_reason = "abort" + request_info.finish_reason = "stop" elif recv_dict.get("length", False): logger.debug(f"Request {rid} finished with length") request_info.finish_reason = "length" @@ -506,7 +506,7 @@ async def _handle_loop(self): request_info.matched_stop = next_token_id else: logger.debug(f"Request {rid} finished with unknown reason") - request_info.finish_reason = "unknown" + request_info.finish_reason = "stop" request_info.is_finish = True if request_info.stream: From ed665f1ad9044076f8c23f359c68bdf0d120b6a7 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Tue, 10 Feb 2026 20:51:35 +0800 Subject: [PATCH 33/36] pre-commit --- src/parallax/utils/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/utils/tokenizer_utils.py b/src/parallax/utils/tokenizer_utils.py index b6107fc7..99dd9cb9 100755 --- a/src/parallax/utils/tokenizer_utils.py +++ 
b/src/parallax/utils/tokenizer_utils.py @@ -150,7 +150,7 @@ def get_tool_call_stop_token_ids(tokenizer) -> List[int]: # Markers whose token IDs should halt generation markers = [ "<|tool_calls_section_end|>", # Kimi K2 / K2.5 - "<|im_end|>", # common chat turn-end token + "<|im_end|>", # common chat turn-end token ] for marker in markers: From be1f910b3787c403df11ee04980383210d076d7c Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 11 Feb 2026 10:09:42 +0800 Subject: [PATCH 34/36] update --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 27a44d8a..ced7de2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,9 +55,9 @@ mac = [ gpu = [ "sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python", "mlx-lm==0.28.4", - "mlx[cpu]==0.30.4", + "mlx[cpu]==0.30.0", # due to transformers version conflict, we need to install mlx-lm separately - # pip install mlx-lm==0.30.6 --no-deps + # pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps ] vllm = [ From e5e1b12b01ee3f2f7789c6a1b163211d284f7534 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 11 Feb 2026 10:38:19 +0800 Subject: [PATCH 35/36] update --- pyproject.toml | 4 ++-- tests/test_tool_call.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ced7de2f..27a44d8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,9 +55,9 @@ mac = [ gpu = [ "sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python", "mlx-lm==0.28.4", - "mlx[cpu]==0.30.0", + "mlx[cpu]==0.30.4", # due to transformers version conflict, we need to install mlx-lm separately - # pip install mlx-lm==0.30.6 "mlx[cpu]==0.30.4" --no-deps + # pip install mlx-lm==0.30.6 --no-deps ] vllm = [ diff --git a/tests/test_tool_call.py b/tests/test_tool_call.py index b7bbbc33..0c1b2759 100644 --- 
a/tests/test_tool_call.py +++ b/tests/test_tool_call.py @@ -460,8 +460,8 @@ def test_made_tool_call_flag(self): self.assertTrue(state.made_tool_call) - def test_made_tool_call_flag_set_even_on_parse_failure(self): - """made_tool_call should be True even if parsing fails (marker was still entered).""" + def test_made_tool_call_flag_not_set_on_parse_failure(self): + """made_tool_call should be False if parsing fails (no valid tool call was produced).""" state = make_tool_state( tool_call_start="", tool_call_end="", @@ -472,7 +472,7 @@ def test_made_tool_call_flag_set_even_on_parse_failure(self): segment = "invalid" state.extract_from_segment(segment) - self.assertTrue(state.made_tool_call) + self.assertFalse(state.made_tool_call) def test_string_arguments_serialized_to_json(self): """Tool call arguments dict should be serialized to JSON string in output.""" From 69f34035241bd05bd3b15618b51f9703ac056693 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Wed, 11 Feb 2026 10:48:38 +0800 Subject: [PATCH 36/36] add metal detect --- src/parallax/utils/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py index 2190d9b0..0cf5a050 100644 --- a/src/parallax/utils/utils.py +++ b/src/parallax/utils/utils.py @@ -25,7 +25,11 @@ def is_mps_available(): def is_metal_available(): - """Check if MLX Metal backend is available""" + """Check if MLX Metal backend is available (macOS only)""" + import sys + + if sys.platform != "darwin": + return False try: import mlx.core as mx