diff --git a/docs/user_guide/install.md b/docs/user_guide/install.md index 02a30343..7c011a95 100644 --- a/docs/user_guide/install.md +++ b/docs/user_guide/install.md @@ -19,7 +19,7 @@ Note: If you are using DGX Spark, please refer to the Docker installation sectio ```sh git clone https://github.com/GradientHQ/parallax.git cd parallax -pip install -e '.[gpu]' +pip install -e ".[gpu]" && pip install mlx-lm==0.30.6 --no-deps ``` #### For macOS (Apple silicon): @@ -34,14 +34,14 @@ cd parallax python3 -m venv ./venv source ./venv/bin/activate -pip install -e '.[mac]' +pip install -e ".[mac]" ``` Next time to re-activate this virtual environment, run ```source ./venv/bin/activate```. #### Extra step for development: ```sh -pip install -e '.[dev]' +pip install -e ".[dev]" ``` ### Windows Application diff --git a/pyproject.toml b/pyproject.toml index 4c0ef40f..27a44d8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "dijkstar==2.6.0", "lattica==1.0.21", "orjson", + ] [project.scripts] @@ -47,12 +48,16 @@ mac = [ "torch==2.8.0", "mlx-lm==0.30.6", "mlx==0.30.4", + "mlx-vlm==0.3.10", + "torchvision==0.23.0" ] gpu = [ - "sglang[all]==0.5.7", + "sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python", "mlx-lm==0.28.4", - "mlx[cpu]==0.30.0", + "mlx[cpu]==0.30.4", + # due to transformers version conflict, we need to install mlx-lm separately + # pip install mlx-lm==0.30.6 --no-deps ] vllm = [ diff --git a/src/parallax/launch.py b/src/parallax/launch.py index b20f6250..9c52cc4f 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -26,6 +26,7 @@ from parallax.server.executor.factory import run_executor_process, stop_executor_process from parallax.server.http_server import launch_http_server, stop_http_server from parallax.server.server_args import parse_args +from parallax.utils.config_utils import get_config_value from parallax.utils.shared_state import SharedState 
from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port from parallax_utils.ascii_anime import display_parallax_join @@ -120,23 +121,25 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr check_latest_release() config = fetch_model_from_hf(args.model_path, local_files_only=args.use_hfcache) + num_layers = get_config_value(config, "num_hidden_layers") + if args.start_layer is None: args.start_layer = 0 if args.end_layer is None: - args.end_layer = config.get("num_hidden_layers") + args.end_layer = num_layers # only launch http server on head node if args.start_layer == 0: http_server_process = launch_http_server(args) # Launch P2P server as subprocess - if not (args.start_layer == 0 and args.end_layer == config.get("num_hidden_layers")): + if not (args.start_layer == 0 and args.end_layer == num_layers): p2p_server_process = launch_p2p_server_process( initial_peers=args.initial_peers, scheduler_addr=args.scheduler_addr, relay_servers=args.relay_servers, pp_start_layer=args.start_layer, pp_end_layer=args.end_layer, - hidden_layers=config.get("num_hidden_layers"), + hidden_layers=num_layers, tp_size=args.tp_size, dp_size=args.dp_size, tcp_port=args.tcp_port, diff --git a/src/parallax/models/kimi_vl.py b/src/parallax/models/kimi_vl.py new file mode 100644 index 00000000..f17434e5 --- /dev/null +++ b/src/parallax/models/kimi_vl.py @@ -0,0 +1,203 @@ +""" +Defines the KimiVL model for Parallax. + +KimiVL uses a DeepSeek-V3 based language model with MoE and a MoonViT vision encoder. +This module reuses components from mlx-vlm and adds PagedAttention support +for distributed inference. 
+""" + +from typing import Any, List, Optional + +import mlx.core as mx +from mlx_lm.models.base import scaled_dot_product_attention + +# Import from mlx-vlm kimi_vl language module +from mlx_vlm.models.kimi_vl.language import DeepseekV3Attention as MLXKimiVLAttention +from mlx_vlm.models.kimi_vl.language import ( + DeepseekV3DecoderLayer as MLXKimiVLDecoderLayer, +) + +from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache +from parallax.server.cache.base import BaseCache +from parallax.utils.prefix_cache_utils import compute_attention_with_prefix_cache +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +class ParallaxKimiVLAttention(MLXKimiVLAttention): + """KimiVL (DeepSeek-V3) Attention with PagedAttention support for Parallax. + + This extends the MLX-VLM KimiVL attention (DeepseekV3Attention) with: + - Paged KV cache support for efficient memory management + - Block-table based attention for decode phase + - Prefix cache support for prefill phase + """ + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[BaseCache] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + prefix_lens: Optional[mx.array] = None, + **kwargs, + ) -> mx.array: + batch, target_len, _ = x.shape + + # Q projection (with optional LoRA) + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(batch, target_len, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + + # KV projection (with MQA compression) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(batch, target_len, 1, 
self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + + kv = kv.reshape(batch, target_len, self.num_heads, -1) + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + k_nope = k_nope.transpose(0, 2, 1, 3) + + # Get KV cache + key_cache_global, value_cache_global = cache.get_cache() + + # Compute RoPE offsets + if target_len == 1: + # Decode phase: position is context_length - 1 + current_pos = context_lengths - 1 + elif prefix_lens is not None: + # Prefill phase with prefix cache + current_pos = prefix_lens + else: + # Prefill phase without prefix cache + current_pos = 0 + + # Apply RoPE + q_pe = self.rope(q_pe, offset=current_pos) + k_pe = self.rope(k_pe, offset=current_pos) + + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + queries = mx.concatenate([q_nope, q_pe], axis=-1) + keys = mx.concatenate([k_nope, k_pe], axis=-1) + + # Cache update with PagedAttention + block_size = key_cache_global.shape[3] + + reshape_and_cache( + keys.transpose(0, 2, 1, 3), + values, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + slot_mapping=slot_mapping, + ) + + if target_len == 1: + # Decode phase: Use Paged Attention + output = paged_attention( + queries, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + self.scale, + self.num_heads, + v_head_dim=values.shape[-1], + ) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + else: + # Prefill phase + has_prefix_cache = prefix_lens is not None and bool(mx.any(prefix_lens > 0)) + + if has_prefix_cache: + k_new = keys + v_new = values.transpose(0, 2, 1, 3) + output = compute_attention_with_prefix_cache( + queries, + k_new, + v_new, + cache, + block_tables, + prefix_lens, + target_len, + self.scale, + self.num_heads, + mask=mask, + ) + else: + # Standard self-attention + if mask is not None: + mask = mx.array(mask, dtype=queries.dtype) + + output = 
scaled_dot_product_attention( + queries, + keys, + values.transpose(0, 2, 1, 3), + scale=self.scale, + mask=mask, + cache=None, + ) + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + + return self.o_proj(output) + + +class ParallaxKimiVLBlock(MLXKimiVLDecoderLayer): + """KimiVL Transformer block with PagedAttention support. + + Extends the MLX-VLM KimiVL decoder layer to use ParallaxKimiVLAttention + and pass through paged attention arguments. + """ + + def __init__(self, args, layer_idx: int, local_layer_idx: int): + super().__init__(args, layer_idx=layer_idx) + # Replace attention with Parallax version + self.self_attn = ParallaxKimiVLAttention(args) + self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[List[Any]] = None, + lengths: Optional[mx.array] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + **kwargs, + ): + r = self.self_attn( + self.input_layernorm(x), + mask, + cache[self.local_layer_idx], + block_tables=block_tables, + context_lengths=context_lengths, + slot_mapping=slot_mapping, + **kwargs, + ) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "KimiVLForConditionalGeneration" + + +EntryClass = ParallaxKimiVLBlock diff --git a/src/parallax/models/qwen3_vl.py b/src/parallax/models/qwen3_vl.py new file mode 100644 index 00000000..de294550 --- /dev/null +++ b/src/parallax/models/qwen3_vl.py @@ -0,0 +1,171 @@ +""" +Defines the Qwen3VL model for Parallax. + +This module reuses components from mlx-vlm and adds PagedAttention support +for distributed inference. 
+""" + +from typing import Any, List, Optional + +import mlx.core as mx +from mlx import nn + +# Import from mlx-vlm +from mlx_vlm.models.qwen3_vl.language import MLP +from mlx_vlm.models.qwen3_vl.language import Attention as MLXQwen3VLAttention +from mlx_vlm.models.qwen3_vl.language import apply_multimodal_rotary_pos_emb + +from parallax.server.cache.base import BaseCache +from parallax_extensions.ops import paged_attention_v1, reshape_and_cache +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +class ParallaxQwen3VLAttention(MLXQwen3VLAttention): + """Qwen3VL Attention with PagedAttention support for Parallax.""" + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[BaseCache] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + prefix_lens: Optional[mx.array] = None, + position_ids: Optional[mx.array] = None, + **kwargs, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = self.q_norm(queries.reshape(B, L, self.n_heads, self.head_dim)).transpose( + 0, 2, 1, 3 + ) + keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, self.head_dim)).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, self.head_dim) + + # Get KV cache + key_cache_global, value_cache_global = cache.get_cache() + + # Compute RoPE position + if L == 1: + # Decode phase: use context_lengths - 1 as offset + current_pos = context_lengths - 1 + pos_ids = mx.broadcast_to(current_pos[:, None], (B, L)) + pos_ids = mx.broadcast_to(pos_ids[None, :, :], (3, B, L)) + elif position_ids is not None: + # Prefill with MRoPE position_ids + pos_ids = position_ids + elif prefix_lens is not None: + # Prefill with prefix cache + pos_ids = mx.arange(L)[None, :] + prefix_lens[:, None] + pos_ids = mx.broadcast_to(pos_ids[None, :, :], (3, B, L)) + else: + # 
Standard prefill + pos_ids = mx.arange(L)[None, :] + pos_ids = mx.broadcast_to(pos_ids, (B, L)) + pos_ids = mx.broadcast_to(pos_ids[None, :, :], (3, B, L)) + + cos, sin = self.rotary_emb(values, pos_ids) + queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) + + # Ensure dtype consistency with cache (RoPE may output float32) + cache_dtype = key_cache_global.dtype + if keys.dtype != cache_dtype: + keys = keys.astype(cache_dtype) + if values.dtype != cache_dtype: + values = values.astype(cache_dtype) + if queries.dtype != cache_dtype: + queries = queries.astype(cache_dtype) + + # Cache update with PagedAttention + block_size = key_cache_global.shape[3] + reshape_and_cache( + keys.transpose(0, 2, 1, 3), + values, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + slot_mapping=slot_mapping, + ) + + # Compute attention + if L == 1: + # Decode: use PagedAttention + output = paged_attention_v1( + queries, + key_cache_global, + value_cache_global, + block_tables, + context_lengths, + block_size, + self.scale, + self.n_kv_heads, + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + else: + # Prefill: standard attention + from mlx_lm.models.base import scaled_dot_product_attention + + output = scaled_dot_product_attention( + queries, + keys, + values.transpose(0, 2, 1, 3), + scale=self.scale, + mask=mask, + cache=None, + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output) + + +class ParallaxQwen3VLBlock(nn.Module): + """Qwen3VL Transformer block with PagedAttention support.""" + + def __init__(self, args, layer_idx: int, local_layer_idx: int): + super().__init__() + self.hidden_size = args.hidden_size + self.self_attn = ParallaxQwen3VLAttention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, 
eps=args.rms_norm_eps) + self.layer_idx = layer_idx + self.local_layer_idx = local_layer_idx + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[List[Any]] = None, + block_tables: Optional[mx.array] = None, + context_lengths: Optional[mx.array] = None, + slot_mapping: Optional[mx.array] = None, + **kwargs, + ): + r = self.self_attn( + self.input_layernorm(x), + mask, + cache[self.local_layer_idx], + block_tables=block_tables, + context_lengths=context_lengths, + slot_mapping=slot_mapping, + **kwargs, + ) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "Qwen3VLForConditionalGeneration" + + +EntryClass = ParallaxQwen3VLBlock diff --git a/src/parallax/server/executor/base_executor.py b/src/parallax/server/executor/base_executor.py index 7180242c..08a0fcf8 100755 --- a/src/parallax/server/executor/base_executor.py +++ b/src/parallax/server/executor/base_executor.py @@ -39,12 +39,14 @@ IntermediateRequest, Request, RequestStatus, + VLMInputs, ) from parallax.server.sampling.sampling_params import SamplingParams from parallax.server.scheduler import Scheduler +from parallax.utils.config_utils import ModelConfigAccessor from parallax.utils.shared_state import SharedState from parallax.utils.utils import get_current_device, get_device_dtype, get_zmq_socket -from parallax_utils.logging_config import get_logger +from parallax_utils.logging_config import get_logger, set_rank logger = get_logger(__name__) @@ -112,13 +114,24 @@ def __init__( # Pipe communication self.conn = conn + # Use VLM-aware config accessor for unified config access + self._config_accessor = ModelConfigAccessor(self.config) + self.is_vlm = self._config_accessor.is_vlm + + # Get num_hidden_layers using unified config accessor + num_hidden_layers = self._config_accessor.get_num_hidden_layers() + self.is_first_peer = 
start_layer == 0 - self.is_last_peer = end_layer == self.config.get("num_hidden_layers") + self.is_last_peer = end_layer == num_hidden_layers self.tp_size = tp_size self.tp_rank = tp_rank self.dp_size = dp_size self.dp_rank = dp_rank + # Configure logging to only print on rank 0 when using multiple GPUs + if tp_size > 1: + set_rank(tp_rank, enable_filter=True) + # Runtime weight refit for RL self.enable_weight_refit = enable_weight_refit self.weight_version = 0 @@ -143,7 +156,16 @@ def __init__( else: self.pad_token_id = self.tokenizer.pad_token_id - self.eos_token_id = self.config.get("eos_token_id", None) + self.eos_token_id = self._config_accessor.get_eos_token_id() + + # Ensure <|im_end|> is in the EOS token list. Some models (e.g. + # Kimi-K2.5) use <|im_end|> to end assistant turns but only list + # [EOS] in config.json. Without this the scheduler will never + # detect end-of-turn and generation runs until max_tokens. + self._augment_eos_with_im_end() + + # Build multimodal config (only meaningful for VLM models) + self.mm_config = self._config_accessor.build_mm_config() # Scheduler: derive final max_batch_size with KV constraints # Remove this for now as it's not working on gpu devices @@ -207,7 +229,8 @@ def __init__( f"(layers [{self.start_layer}, {self.end_layer}), " f"tp_rank={self.tp_rank}/{self.tp_size}, " f"device={self.device}, " - f"num_shard_layers={self.num_shard_layers})" + f"num_shard_layers={self.num_shard_layers}, " + f"is_vlm={self.is_vlm})" ) @abstractmethod @@ -611,15 +634,51 @@ def shutdown(self): logger.debug("Executor shutdown complete.") - def _handle_raw_request(self, raw_request: Dict): - assert "messages" in raw_request, "Request did not contain messages" + def _augment_eos_with_im_end(self): + """Add ``<|im_end|>`` to the EOS token list when it is present in the + vocabulary but missing from the configured ``eos_token_id``. - rid = raw_request["rid"] + Many chat models (Kimi-K2.5, Qwen, etc.) 
use ``<|im_end|>`` as the + turn-ending token, yet their ``config.json`` only lists ``[EOS]`` as + the EOS token. Without this augmentation the scheduler will never + detect end-of-turn and generation will run until ``max_tokens``. + """ + _get_vocab = getattr(self.tokenizer, "get_vocab", None) + vocab = _get_vocab() if _get_vocab else {} + im_end_id = vocab.get("<|im_end|>") + if im_end_id is None: + return + + # Normalise eos_token_id to a list for easy comparison + if self.eos_token_id is None: + self.eos_token_id = [im_end_id] + logger.info(f"Set eos_token_id to [{im_end_id}] (<|im_end|>)") + elif isinstance(self.eos_token_id, list): + if im_end_id not in self.eos_token_id: + self.eos_token_id.append(im_end_id) + logger.info(f"Added <|im_end|> (id={im_end_id}) to eos_token_id list") + elif isinstance(self.eos_token_id, int): + if self.eos_token_id != im_end_id: + self.eos_token_id = [self.eos_token_id, im_end_id] + logger.info( + f"Expanded eos_token_id to {self.eos_token_id} " + f"(added <|im_end|> id={im_end_id})" + ) + + def _process_text_request(self, rid: str, messages: list, raw_request: Dict) -> list: + """Process a text-only request using the tokenizer.""" if self.tokenizer.chat_template: - messages = raw_request["messages"] - process_message_content(messages) + has_non_text_content = any( + isinstance(msg.get("content"), list) + and any( + isinstance(part, dict) and part.get("type") != "text" + for part in msg.get("content") + ) + for msg in messages + ) + if not has_non_text_content: + process_message_content(messages) chat_template_kwargs = raw_request.get("chat_template_kwargs", {}) - # check extra_body for backward compatibility if "extra_body" in raw_request and "chat_template_kwargs" in raw_request["extra_body"]: chat_template_kwargs.update(raw_request["extra_body"]["chat_template_kwargs"]) @@ -631,27 +690,81 @@ def _handle_raw_request(self, raw_request: Dict): **chat_template_kwargs, ) else: - prompt = convert_chat(raw_request["messages"], 
raw_request.get("role_mapping")) + prompt = convert_chat(messages, raw_request.get("role_mapping")) prompt = self.tokenizer.encode(prompt) + return prompt + + def _process_request_prompt( + self, rid: str, messages: list, image_urls: list, raw_request: Dict + ) -> Tuple[list, Optional[VLMInputs]]: + """ + Process request messages and return (input_ids, vlm_inputs). + + Subclasses can override this method to implement custom VLM processing. + Default implementation only handles text requests. + + Args: + rid: Request ID + messages: List of message dicts + image_urls: List of image URLs extracted from messages + raw_request: Original raw request dict + + Returns: + Tuple of (input_ids, vlm_inputs). vlm_inputs is None for text-only requests. + """ + # Default: text-only processing, subclasses override for VLM support + prompt = self._process_text_request(rid, messages, raw_request) + return prompt, None + + def _handle_raw_request(self, raw_request: Dict): + assert "messages" in raw_request, "Request did not contain messages" + + rid = raw_request["rid"] + messages = raw_request["messages"] + + # Extract image URLs first to determine if this is a multimodal request + multimodal_params = None + image_urls = [] + for message in messages: + content = message.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image_url": + if multimodal_params is None: + multimodal_params = {"images": []} + image_url = part["image_url"] + multimodal_params["images"].append(image_url) + # Extract URL string for processing + if isinstance(image_url, dict): + image_urls.append(image_url.get("url", image_url)) + else: + image_urls.append(image_url) + + # Process request prompt (subclasses can override for VLM support) + prompt, vlm_inputs = self._process_request_prompt(rid, messages, image_urls, raw_request) + max_seq_len = self.max_sequence_length if self.max_sequence_length is not None else 4096 max_seq_len = 
max(max_seq_len, 4096) max_new_tokens = raw_request.get("max_tokens", 2048) input_token_num = len(prompt) if input_token_num + max_new_tokens >= max_seq_len: - logger.warning( - f"Input token length {input_token_num} + max_new_tokens {max_new_tokens} exceeds max_sequence_length {max_seq_len}." - ) - if max_new_tokens > 2048: + available_new_tokens = max_seq_len - input_token_num + if available_new_tokens <= 0: logger.warning( - f"max_new_tokens {max_new_tokens} is too large, reduce to 2048 tokens." + f"Input token length {input_token_num} already exceeds " + f"max_sequence_length {max_seq_len}. " + f"Truncating input to keep last {max_seq_len - 1} tokens." ) - max_new_tokens = 2048 - if input_token_num + max_new_tokens >= max_seq_len: + prompt = prompt[-(max_seq_len - 1) :] + max_new_tokens = 1 + else: logger.warning( - f"Trunc input prompt, keep last {max_seq_len - max_new_tokens} tokens" + f"Input token length {input_token_num} + max_new_tokens {max_new_tokens} " + f"exceeds max_sequence_length {max_seq_len}. " + f"Reducing max_new_tokens to {available_new_tokens}." 
) - prompt = prompt[-(max_seq_len - max_new_tokens) :] + max_new_tokens = available_new_tokens max_total_length = len(prompt) + max_new_tokens logger.debug(f"Final max_new_tokens for request ID {rid}: {max_new_tokens}") @@ -675,6 +788,29 @@ def _handle_raw_request(self, raw_request: Dict): if "ignore_eos" in raw_sampling_params: sampling_params.ignore_eos = raw_sampling_params["ignore_eos"] + # Also read OpenAI-style top-level sampling parameters as fallback + if "temperature" in raw_request and raw_sampling_params is None: + sampling_params.temperature = raw_request["temperature"] + if sampling_params.temperature == 0.0: + sampling_params.temperature = 1.0 + sampling_params.top_k = 1 + if "top_p" in raw_request and raw_sampling_params is None: + sampling_params.top_p = raw_request["top_p"] + + # When tools are present, add tool-call-related stop token IDs so the + # scheduler halts generation at the tool-call boundary instead of + # running until max_tokens. + tools = raw_request.get("tools") + if tools and self.tokenizer is not None: + from parallax.utils.tokenizer_utils import get_tool_call_stop_token_ids + + tool_stop_ids = get_tool_call_stop_token_ids(self.tokenizer) + if tool_stop_ids: + if sampling_params.stop_token_ids is None: + sampling_params.stop_token_ids = set() + sampling_params.stop_token_ids.update(tool_stop_ids) + logger.debug(f"Added tool call stop token IDs for request {rid}: {tool_stop_ids}") + req = InitialRequest( request_id=rid, output_ids=None, @@ -684,6 +820,8 @@ def _handle_raw_request(self, raw_request: Dict): max_total_length=max_total_length, lora_path=lora_path, return_probs=return_probs, + multimodal_params=multimodal_params, + vlm_inputs=vlm_inputs, ) if "routing_table" in raw_request: req.routing_table = raw_request["routing_table"] diff --git a/src/parallax/server/executor/factory.py b/src/parallax/server/executor/factory.py index 20ed93e4..8209d9fd 100755 --- a/src/parallax/server/executor/factory.py +++ 
b/src/parallax/server/executor/factory.py @@ -6,7 +6,7 @@ from typing import Any, List, Optional from parallax.utils.utils import get_current_device -from parallax_utils.logging_config import get_logger, set_log_level +from parallax_utils.logging_config import get_logger, set_log_level, set_rank logger = get_logger(__name__) @@ -109,7 +109,23 @@ def create_from_args( def run_executor_process(args, shared_state=None, conn=None): """Run executor as a subprocess""" + # Set rank to suppress logs on non-zero ranks + # Must be called AFTER set_log_level to override the level + tp_rank = getattr(args, "tp_rank", 0) + tp_size = getattr(args, "tp_size", 1) + + # For non-zero ranks, suppress logs before any imports + if tp_size > 1 and tp_rank != 0: + import logging + + logging.getLogger().setLevel(logging.CRITICAL + 1) + set_log_level(args.log_level) + + # Now set rank properly (will re-suppress for non-zero ranks) + if tp_size > 1: + set_rank(tp_rank, enable_filter=True) + executor = None try: executor = create_from_args(args, shared_state, conn) diff --git a/src/parallax/server/executor/mlx_executor.py b/src/parallax/server/executor/mlx_executor.py index c3ab5be3..a7202fa7 100755 --- a/src/parallax/server/executor/mlx_executor.py +++ b/src/parallax/server/executor/mlx_executor.py @@ -20,6 +20,7 @@ ) from parallax.server.sampling.sampler import SamplingBatchInfo from parallax.server.shard_loader import MLXModelLoader +from parallax.utils.config_utils import get_config_value from parallax.utils.utils import ( combine_padding_and_causal_masks, create_causal_mask, @@ -122,9 +123,48 @@ def __init__( self.model_shard = self.shard_loader.load_lora(self.model_shard, adapters) logger.debug( - f"MLX sharded model loaded in {(time.time() - t0) * 1000:.1f} ms; num_layers={self.config.get('num_hidden_layers')}" + f"MLX sharded model loaded in {(time.time() - t0) * 1000:.1f} ms; num_layers={get_config_value(self.config, 'num_hidden_layers')}" ) + # Load VLM processor if this is a VLM 
model (first peer only) + self.vlm_processor = None + self.model_type = self.config.get("model_type") + if hasattr(self.model_shard, "is_vlm") and self.model_shard.is_vlm and start_layer == 0: + processor_path = self.shard_loader.model_path_str + logger.debug(f"Trying to load VLM processor from: {processor_path}") + + processor_loaded = False + + try: + from transformers import AutoProcessor + + self.vlm_processor = AutoProcessor.from_pretrained( + processor_path, + trust_remote_code=True, + ) + processor_type = type(self.vlm_processor).__name__ + # Verify it has image processing capability + if ( + hasattr(self.vlm_processor, "image_processor") + and self.vlm_processor.image_processor is not None + ): + logger.info( + f"Loaded VLM processor (AutoProcessor -> {processor_type}) for {self.model_type}" + ) + processor_loaded = True + else: + logger.warning( + f"AutoProcessor loaded {processor_type} but it doesn't have image_processor, skipping" + ) + self.vlm_processor = None + except Exception as e: + logger.debug(f"AutoProcessor failed: {e}") + + if not processor_loaded: + logger.warning( + "VLM image processing will be disabled - no processor could be loaded." + ) + # TODO: Duplicate code to BaseExecutor since num_shard_layers and dtype are needed for initializing kv cache self.num_shard_layers = end_layer - start_layer self.dtype = get_device_dtype(dtype, device) @@ -133,13 +173,21 @@ def __init__( ) # Calculate feature dimensions for kv cache - num_key_value_heads = self.config.get("num_key_value_heads") + # Use helper to handle VLM models where these are in text_config + num_key_value_heads = get_config_value(self.config, "num_key_value_heads") if num_key_value_heads is None: # Step3.5 flash use num_attention_groups instead. 
- num_key_value_heads = self.config.get("num_attention_groups") - head_dim = self.config.get("head_dim") or self.config.get("hidden_size") // self.config.get( - "num_attention_heads" - ) + num_key_value_heads = get_config_value(self.config, "num_attention_groups") + head_dim = get_config_value(self.config, "head_dim") + if head_dim is None: + hidden_size = get_config_value(self.config, "hidden_size") + num_attention_heads = get_config_value(self.config, "num_attention_heads") + if hidden_size and num_attention_heads: + head_dim = hidden_size // num_attention_heads + else: + raise ValueError( + f"Cannot determine head_dim: hidden_size={hidden_size}, num_attention_heads={num_attention_heads}" + ) qk_nope_head_dim = self.config.get("qk_nope_head_dim", None) qk_rope_head_dim = self.config.get("qk_rope_head_dim", None) if qk_nope_head_dim is not None and qk_rope_head_dim is not None: @@ -247,6 +295,139 @@ def __init__( f"mlx_executor initialized; wired_limit set; prefix_cache={'on' if self.enable_prefix_cache else 'off'}, total memory usage: {mx.get_active_memory() / 1024**3 :.3f} GB" ) + # ========== VLM Processing Methods ========== + + def _process_request_prompt( + self, rid: str, messages: list, image_urls: list, raw_request: Dict + ) -> Tuple[list, Optional[Any]]: + """ + Override base class method to handle VLM requests for MLX. + + If VLM processor is available and images are present, use VLM processing. + Otherwise, fall back to text-only processing. 
+ """ + if image_urls and self.vlm_processor is not None: + return self._process_vlm_request(rid, messages, image_urls) + else: + prompt = self._process_text_request(rid, messages, raw_request) + return prompt, None + + def _process_vlm_request(self, rid: str, messages: list, image_urls: list): + """Process a VLM (multimodal) request using the VLM processor.""" + from parallax.server.request import VLMInputs + from parallax.utils.vlm_utils import load_image + + try: + images = [] + for url in image_urls: + try: + img = load_image(url) + images.append(img) + except Exception as e: + logger.warning(f"Failed to load image {url}: {e}") + + if not images: + logger.warning( + f"No images loaded for VLM request {rid}, falling back to text processing" + ) + return self._process_text_request(rid, messages, {}), None + + formatted_messages = self._format_messages_for_vlm(messages) + + if hasattr(self.vlm_processor, "apply_chat_template"): + text_prompt = self.vlm_processor.apply_chat_template( + formatted_messages, + tokenize=False, + add_generation_prompt=True, + ) + elif self.tokenizer.chat_template: + text_prompt = self.tokenizer.apply_chat_template( + formatted_messages, + tokenize=False, + add_generation_prompt=True, + ) + else: + text_prompt = "\n".join( + f"{msg.get('role', 'user')}: {self._extract_text_from_content(msg.get('content', ''))}" + for msg in formatted_messages + ) + + processor_inputs = self.vlm_processor( + text=text_prompt, + images=images, + return_tensors="pt", + ) + input_ids = processor_inputs.get("input_ids") + if input_ids is None: + raise ValueError("Processor did not return input_ids") + prompt = input_ids.flatten().tolist() + + pixel_values = processor_inputs.get("pixel_values") + image_grid_thw = processor_inputs.get("image_grid_thw") + + vlm_inputs = VLMInputs( + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + image_sizes=[(img.height, img.width) for img in images], + images_processed=False, + ) + + logger.debug( + f"VLM request 
{rid}: {len(images)} images, " + f"input_ids length={len(prompt)}, " + f"pixel_values shape={pixel_values.shape if pixel_values is not None else None}" + ) + + return prompt, vlm_inputs + + except Exception as e: + logger.error(f"Failed to process VLM request {rid}: {e}") + return self._process_text_request(rid, messages, {}), None + + def _format_messages_for_vlm(self, messages: list) -> list: + """ + Format messages for VLM processing. + Keep the original structure so chat template can handle image placeholders correctly. + """ + formatted = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + new_content = [] + for part in content: + if isinstance(part, dict): + if part.get("type") == "text": + new_content.append({"type": "text", "text": part.get("text", "")}) + elif part.get("type") == "image_url": + new_content.append({"type": "image"}) + elif isinstance(part, str): + new_content.append({"type": "text", "text": part}) + formatted.append( + { + "role": msg.get("role", "user"), + "content": new_content, + } + ) + else: + formatted.append(msg) + return formatted + + def _extract_text_from_content(self, content) -> str: + """Extract text from message content (handles both string and list formats).""" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + texts.append(part.get("text", "")) + elif isinstance(part, str): + texts.append(part) + return " ".join(texts) + return str(content) + + # ========== End VLM Processing Methods ========== + def _tensor_parallel_broadcast_pyobj(self, broadcast_obj): """Wrapper for broadcast pyobject in TP group using send/recv with explicit sync""" if self.tp_size <= 1: @@ -395,6 +576,9 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: slot_mapping=prepared_inputs.get("slot_mapping"), 
state_slot_mapping=prepared_inputs.get("state_slot_mapping"), prefix_lens=prepared_inputs.get("prefix_lens"), # For RoPE offset in prefix cache + # VLM inputs (only present for first peer with VLM model) + pixel_values=prepared_inputs.get("pixel_values"), + image_grid_thw=prepared_inputs.get("image_grid_thw"), ) mx.eval(hidden_states) @@ -661,6 +845,38 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A "prefix_lens": prefix_lens_tensor, # For RoPE offset calculation "actual_processed_lengths": actual_processed_lengths_tensor, # For correct logit selection } + + # VLM support: collect pixel_values and image metadata for first peer + if self.is_first_peer and hasattr(self.model_shard, "is_vlm") and self.model_shard.is_vlm: + pixel_values_list = [] + image_grid_thw_list = [] + has_vlm_inputs = False + + for req in batched_requests: + if req.vlm_inputs is not None and req.vlm_inputs.has_images(): + has_vlm_inputs = True + pixel_values_list.append(req.vlm_inputs.pixel_values) + if req.vlm_inputs.image_grid_thw is not None: + image_grid_thw_list.append(req.vlm_inputs.image_grid_thw) + else: + pixel_values_list.append(None) + + if has_vlm_inputs: + # For now, we only support single-image batching where all requests + # in the batch either have images or don't have images + # TODO: Support mixed batches with proper padding + non_none_pixels = [p for p in pixel_values_list if p is not None] + if len(non_none_pixels) > 0: + # Concatenate all pixel values along batch dimension + ret["pixel_values"] = mx.concatenate( + [mx.array(p) for p in non_none_pixels], axis=0 + ) + if image_grid_thw_list: + ret["image_grid_thw"] = mx.concatenate( + [mx.array(g) for g in image_grid_thw_list], axis=0 + ) + logger.debug(f"VLM batch: pixel_values shape={ret['pixel_values'].shape}") + logger.debug(f"Prepared MLX prefill batch (size={batch_size})") return ret diff --git a/src/parallax/server/executor/sglang_executor.py 
b/src/parallax/server/executor/sglang_executor.py index 555db81c..4a7a4e01 100755 --- a/src/parallax/server/executor/sglang_executor.py +++ b/src/parallax/server/executor/sglang_executor.py @@ -6,7 +6,9 @@ from typing import Any, Dict, List, Optional, Tuple import torch +from sglang.srt.environ import envs from sglang.srt.lora.lora_registry import LoRARef +from sglang.srt.managers.mm_utils import init_mm_embedding_cache from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache @@ -140,12 +142,14 @@ def __init__( logger.debug( f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" ) - self.model_runner, self.config, self.tokenizer = initialize_sgl_model_runner( - **model_runner_params + self.model_runner, self.config, self.tokenizer, self.processor = ( + initialize_sgl_model_runner(**model_runner_params) ) logger.debug( f"SGLang model runner initialized. 
num_layers={self.config.get('num_hidden_layers')}" ) + embedding_cache_size = envs.SGLANG_VLM_CACHE_SIZE_MB.get() + init_mm_embedding_cache(embedding_cache_size * 1024 * 1024) # Set device to specific CUDA device based on tp_rank # This ensures tensors are moved to the correct GPU @@ -543,6 +547,15 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A # Pre-check: Verify KV cache has enough space for prefill total_tokens_needed = sum(req.total_length for req in batched_requests) + try: + available = self.model_runner.token_to_kv_pool_allocator.available_size() + logger.debug( + f"[KV Cache Prefill] available={available}, needed={total_tokens_needed}, " + f"batch_size={batch_size}, " + f"req_lengths=[{', '.join(str(r.total_length) for r in batched_requests)}]" + ) + except Exception: + pass if not self._check_kv_cache_available(total_tokens_needed): self._abort_requests_due_to_kv_cache( batched_requests, @@ -599,6 +612,9 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A batched_requests, self.model_runner, self.page_tree_cache, + self.processor, + self.mm_config, + self.tokenizer, ) self.cur_batch = schedule_batch @@ -622,6 +638,13 @@ def _prepare_decode_batch(self, batched_requests: List[Request]) -> Optional[Dic # Pre-check: Verify KV cache has enough space for decode (1 token per request) tokens_needed = batch_size + try: + available = self.model_runner.token_to_kv_pool_allocator.available_size() + logger.debug( + f"[KV Cache Decode] available={available}, needed={tokens_needed}, batch_size={batch_size}" + ) + except Exception: + pass if not self._check_kv_cache_available(tokens_needed): self._abort_requests_due_to_kv_cache( batched_requests, diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index ad3ca173..90a392a7 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -46,6 +46,77 @@ logger = get_logger(__name__) +# 
# ---------------------------------------------------------------------------
# Kimi-K2.5 immutable parameter constraints
# These parameters are locked to specific values and the API must reject
# any request that attempts to override them with different values.
# See: https://github.com/MoonshotAI/Kimi-Vendor-Verifier
# ---------------------------------------------------------------------------
_KIMI_K25_CONSTRAINTS = {
    "think": {
        "temperature": 1.0,
        "top_p": 0.95,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "n": 1,
    },
    "non_think": {
        "temperature": 0.6,
        "top_p": 0.95,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "n": 1,
    },
}

# Model names that should enforce immutable parameter constraints.
# Add more model name patterns as needed.
_KIMI_K25_MODEL_KEYWORDS = ("kimi-k2.5", "kimi_k2.5", "kimi-k2-5")


def _is_kimi_k25_model(model_name: str) -> bool:
    """Check if the model name matches Kimi-K2.5.

    The match is a case-insensitive substring test against
    ``_KIMI_K25_MODEL_KEYWORDS``. Non-string values (e.g. a JSON ``null``
    model) never match instead of raising.
    """
    if not isinstance(model_name, str):
        return False
    name_lower = model_name.lower()
    return any(kw in name_lower for kw in _KIMI_K25_MODEL_KEYWORDS)


def _detect_thinking_mode(request_json: dict) -> bool:
    """Detect whether the request uses thinking mode.

    Returns True when either ``chat_template_kwargs.thinking`` is truthy or
    ``thinking.type == "enabled"``; malformed (non-dict) values are ignored.
    """
    # opensource format: chat_template_kwargs.thinking
    ctk = request_json.get("chat_template_kwargs", {})
    if isinstance(ctk, dict) and ctk.get("thinking"):
        return True
    # kimi SaaS format: thinking.type == "enabled"
    thinking = request_json.get("thinking", {})
    if isinstance(thinking, dict) and thinking.get("type") == "enabled":
        return True
    return False


def validate_kimi_k25_params(request_json: dict) -> Optional[str]:
    """Validate immutable parameters for Kimi-K2.5 models.

    Args:
        request_json: Decoded request body.

    Returns:
        An error message string if validation fails, or None if OK. None is
        also returned when the body is not a JSON object or the model is not
        a Kimi-K2.5 model, so that later stages report those problems.
    """
    # A non-object body (list, string, ...) is not ours to diagnose; crashing
    # here would turn a client error into a server error.
    if not isinstance(request_json, dict):
        return None

    model = request_json.get("model", "")
    if not _is_kimi_k25_model(model):
        return None  # Not a Kimi-K2.5 model, skip

    is_thinking = _detect_thinking_mode(request_json)
    mode = "think" if is_thinking else "non_think"
    constraints = _KIMI_K25_CONSTRAINTS[mode]
    mode_label = "thinking" if is_thinking else "non-thinking"

    for param, expected in constraints.items():
        value = request_json.get(param)
        # An explicit null means "not provided", not an override attempt.
        if value is None:
            continue
        if value != expected:
            return (
                f"Invalid parameter for {model} ({mode_label} mode): "
                f"'{param}' must be {expected}, got {value}"
            )

    return None
class VisionConfig:
    """Attribute bag for the ``vision_config`` section of a VLM's config.json.

    Every keyword argument becomes an instance attribute, so no schema is
    imposed on the vision architecture. Two attributes are always present:
    ``model_type`` (default ``"clip_vision_model"``) and ``hidden_size``
    (default ``1024``).
    """

    def __init__(self, **kwargs):
        # Mirror the incoming mapping onto the instance first ...
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

        # ... then backfill only the attributes the caller left out.
        for attr_name, fallback in (
            ("model_type", "clip_vision_model"),
            ("hidden_size", 1024),
        ):
            if not hasattr(self, attr_name):
                setattr(self, attr_name, fallback)

    @classmethod
    def from_dict(cls, params: Dict[str, Any]) -> "VisionConfig":
        """Build a VisionConfig from a mapping; ``None`` in gives ``None`` out."""
        return None if params is None else cls(**params)


@dataclass
class InputEmbeddingsOutput:
    """Result container returned by ``get_input_embeddings``."""

    inputs_embeds: mx.array
    attention_mask: Optional[mx.array] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return both fields as a plain dictionary."""
        return dict(
            inputs_embeds=self.inputs_embeds,
            attention_mask=self.attention_mask,
        )
""" def __init__( @@ -32,6 +76,13 @@ def __init__( *, has_norm_in: bool = False, dtype: Optional[mx.Dtype] = None, + # VLM support + vision_config: Optional[Dict[str, Any]] = None, + vision_tower_class: Optional[Type[nn.Module]] = None, + multi_modal_projector_class: Optional[Type[nn.Module]] = None, + image_token_index: Optional[int] = None, + vision_feature_layer: Optional[int] = -2, + vision_feature_select_strategy: str = "default", ): super().__init__() self.config = config @@ -49,13 +100,50 @@ def __init__( self.is_last_shard = end_layer == config.num_hidden_layers self.n_layers_in_shard = end_layer - start_layer + # VLM configuration + self.is_vlm = vision_config is not None and vision_tower_class is not None + self.vision_config = VisionConfig.from_dict(vision_config) if vision_config else None + self.image_token_index = image_token_index + self.vision_feature_layer = vision_feature_layer + self.vision_feature_select_strategy = vision_feature_select_strategy + if self.is_first_shard: self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size) if has_norm_in: self.norm_in = nn.RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + if self.is_vlm: + logger.info( + f"Initializing VLM components: vision_tower ({self.vision_config.model_type})" + ) + self.vision_tower = vision_tower_class(self.vision_config) + if multi_modal_projector_class is not None: + try: + self.multi_modal_projector = multi_modal_projector_class(config) + except (TypeError, AttributeError): + combined_config = type( + "CombinedConfig", + (), + { + "vision_config": self.vision_config, + "text_config": config, + }, + )() + self.multi_modal_projector = multi_modal_projector_class(combined_config) + logger.info("Initialized projector with combined vision+text config") + else: + self.multi_modal_projector = None + logger.info( + "No separate projector class - projector is integrated into VisionModel" + ) + else: + self.vision_tower = None + self.multi_modal_projector = None else: 
def get_input_embeddings(
    self,
    input_ids: mx.array,
    pixel_values: Optional[mx.array] = None,
    **kwargs,
) -> InputEmbeddingsOutput:
    """Get input embeddings, optionally with vision features merged in.

    This method handles:
    1. Text-only inputs: Simply embed tokens
    2. VLM inputs: Embed tokens, encode images, and merge vision features

    Args:
        input_ids: (batch, seq_len) Token IDs
        pixel_values: (batch, C, H, W) or (num_patches, C, H, W) Image pixel values
        **kwargs: Additional arguments (e.g., image_grid_thw for Qwen2-VL);
            forwarded unchanged to ``_encode_images``.

    Returns:
        InputEmbeddingsOutput with merged embeddings. ``attention_mask`` is
        never set here and keeps its default.

    Raises:
        ValueError: If called on a shard that is not the first shard.
    """
    if not self.is_first_shard:
        raise ValueError("get_input_embeddings should only be called on the first shard")

    inputs_embeds = self.embed_tokens(input_ids)
    # Text-only path: no pixels supplied, or this shard has no vision stack.
    if pixel_values is None or not self.is_vlm:
        return InputEmbeddingsOutput(inputs_embeds=inputs_embeds)
    image_features = self._encode_images(pixel_values, **kwargs)
    final_embeds = self._merge_input_ids_with_image_features(
        image_features, inputs_embeds, input_ids
    )

    return InputEmbeddingsOutput(inputs_embeds=final_embeds)


def _encode_images(
    self,
    pixel_values: mx.array,
    image_grid_thw: Optional[mx.array] = None,
    **kwargs,
) -> mx.array:
    """Encode images through vision tower and projector.

    The calling convention of the vision tower is picked from
    ``vision_config.model_type``:

    * contains both "qwen" and "vl": ``vision_tower(pixel_values, image_grid_thw)``
    * equals "moonvit": ``vision_tower(pixel_values, grid_thw=..., output_hidden_states=True)``
    * anything else, or no ``image_grid_thw`` given:
      ``vision_tower(pixel_values, output_hidden_states=True)``

    Args:
        pixel_values: Image tensor; cast to the vision tower's weight dtype
            before use.
        image_grid_thw: Grid descriptor required for the grid-based branch.
        **kwargs: Accepted and ignored.

    Returns:
        The selected vision features, passed through
        ``multi_modal_projector`` when one is configured.

    Raises:
        ValueError: If the vision tower was not initialised.
    """
    if self.vision_tower is None:
        raise ValueError("Vision tower not initialized for this model")

    model_type = getattr(self.vision_config, "model_type", "") if self.vision_config else ""
    is_qwen_vl = "qwen" in model_type.lower() and "vl" in model_type.lower()
    is_moonvit = model_type.lower() == "moonvit"
    uses_grid_thw = is_qwen_vl or is_moonvit

    # Ensure correct dtype
    # (prefer the patch-embedding weight's dtype, else fall back to the shard dtype)
    if hasattr(self.vision_tower, "patch_embed") and hasattr(
        self.vision_tower.patch_embed, "proj"
    ):
        pixel_values = pixel_values.astype(self.vision_tower.patch_embed.proj.weight.dtype)
    else:
        pixel_values = pixel_values.astype(self.dtype)

    # A grid-based model without image_grid_thw falls through to the
    # CLIP/SigLIP-style branch below.
    if uses_grid_thw and image_grid_thw is not None:
        if is_moonvit:
            # MoonViT (KimiVL) expects NHWC input
            # (axis 1 of size 1/3/4 is taken to be the channel axis)
            if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]:
                pixel_values = pixel_values.transpose(0, 2, 3, 1)
            vision_outputs = self.vision_tower(
                pixel_values, grid_thw=image_grid_thw, output_hidden_states=True
            )
        else:
            # Qwen-VL expects flat patches
            vision_outputs = self.vision_tower(pixel_values, image_grid_thw)

        # Tuple outputs contribute their first element; everything else
        # (lists included) is used as returned.
        if isinstance(vision_outputs, tuple):
            selected_features = vision_outputs[0]
        elif isinstance(vision_outputs, list):
            selected_features = vision_outputs
        else:
            selected_features = vision_outputs
    else:
        # CLIP/SigLIP style: NCHW -> NHWC
        if pixel_values.ndim == 4 and pixel_values.shape[1] in [1, 3, 4]:
            pixel_values = pixel_values.transpose(0, 2, 3, 1)

        vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)

        if isinstance(vision_outputs, tuple):
            if len(vision_outputs) >= 3:
                # Element 2 is indexed per layer via vision_feature_layer.
                hidden_states = vision_outputs[2]
                if isinstance(self.vision_feature_layer, int):
                    selected_features = hidden_states[self.vision_feature_layer]
                    # "default" strategy drops position 0 along axis 1
                    # (presumably the CLS token -- TODO confirm).
                    if self.vision_feature_select_strategy == "default":
                        selected_features = selected_features[:, 1:]
                else:
                    # Several layers requested: concatenate along the feature axis.
                    hs_pool = [hidden_states[idx] for idx in self.vision_feature_layer]
                    if self.vision_feature_select_strategy == "default":
                        hs_pool = [hs[:, 1:] for hs in hs_pool]
                    selected_features = mx.concatenate(hs_pool, axis=-1)
            else:
                # Shorter tuples: use element 1 directly.
                selected_features = vision_outputs[1]
                if self.vision_feature_select_strategy == "default":
                    selected_features = selected_features[:, 1:]
        else:
            selected_features = vision_outputs

    # Models without a separate projector return the tower features as-is.
    if self.multi_modal_projector is not None:
        image_features = self.multi_modal_projector(selected_features)
    else:
        image_features = selected_features

    return image_features


def _merge_input_ids_with_image_features(
    self,
    image_features: mx.array,
    inputs_embeds: mx.array,
    input_ids: mx.array,
) -> mx.array:
    """Replace placeholder tokens with actual image feature embeddings.

    Positions where ``input_ids == self.image_token_index`` are overwritten,
    in order, with consecutive rows of ``image_features``; rows are consumed
    across the batch, so sample ``i`` continues where sample ``i - 1``
    stopped.

    Args:
        image_features: Vision features; a 3-D tensor is flattened to
            ``(-1, feature_dim)`` first.
        inputs_embeds: Token embeddings, unpacked as
            ``(batch_size, seq_len, hidden_dim)``.
        input_ids: Token ids compared against ``image_token_index``.

    Returns:
        The merged embeddings, stacked along axis 0. When
        ``image_token_index`` is unset, ``inputs_embeds`` is returned
        unchanged after a warning.

    Raises:
        ValueError: If fewer feature rows remain than a sample has
            placeholder positions.
    """
    if self.image_token_index is None:
        logger.warning("image_token_index not set, cannot merge image features")
        return inputs_embeds

    batch_size, seq_len, hidden_dim = inputs_embeds.shape
    image_positions = input_ids == self.image_token_index

    if image_features.ndim == 3:
        image_features = image_features.reshape(-1, image_features.shape[-1])

    image_features = image_features.astype(inputs_embeds.dtype)

    batch_outputs = []
    # Running offset into the flattened feature rows.
    feature_start_idx = 0

    for batch_idx in range(batch_size):
        batch_mask = image_positions[batch_idx]
        num_positions = int(mx.sum(batch_mask).item())

        if num_positions > 0:
            batch_features = image_features[
                feature_start_idx : feature_start_idx + num_positions
            ]

            # Slicing past the end yields fewer rows, which this check catches.
            if batch_features.shape[0] != num_positions:
                raise ValueError(
                    f"Image token positions ({num_positions}) does not match "
                    f"image features ({batch_features.shape[0]}) for batch {batch_idx}"
                )

            # cumsum - 1 maps the k-th placeholder to feature row k; other
            # positions get index 0, whose gathered value is discarded by
            # the mx.where below.
            cumsum = mx.cumsum(batch_mask.astype(mx.int32))
            feature_indices = mx.where(batch_mask, cumsum - 1, 0)
            gathered_features = batch_features[feature_indices]
            batch_mask_expanded = mx.expand_dims(batch_mask, axis=-1)
            batch_output = mx.where(
                batch_mask_expanded, gathered_features, inputs_embeds[batch_idx]
            )
            feature_start_idx += num_positions
        else:
            batch_output = inputs_embeds[batch_idx]

        batch_outputs.append(batch_output)

    return mx.stack(batch_outputs, axis=0)
sharded model. + Args: h_or_tokens: (batch, target_len_padded, D) or (batch, target_len_padded) for prefill, (batch, 1, D) or (batch, 1) for decode. cache: List of layer caches (KVCache or LinearCache). Legacy mode: (key_cache_global, value_cache_global) tuple. - lengths: (batch,) true lengths of each sequence in batch. mask: Optional causal mask for the current segment. - window_size: Optional int, if provided, will use a sliding window attention mask. block_tables: (batch, max_blocks) for PagedAttention. context_lengths: (batch,) for PagedAttention. slot_mapping: (total_tokens,) for PagedAttention. + pixel_values: (batch, C, H, W) or (num_patches, C, H, W) for VLM. + Image pixel values to be processed by vision tower. + inputs_embeds: (batch, seq_len, hidden_dim) Pre-computed embeddings. + If provided, skips embedding and vision processing. + **kwargs: Additional model-specific arguments (e.g., image_grid_thw). + + Returns: + For last shard: logits (batch, seq_len, vocab_size) + For other shards: hidden states (batch, seq_len, hidden_dim) """ h = h_or_tokens target_len = h.shape[1] if self.is_first_shard: - if self.embed_tokens is None: - raise ValueError("embed_tokens is None for the first shard.") - h = self.embed_tokens(h) + if inputs_embeds is not None: + # Use pre-computed embeddings (e.g., from get_input_embeddings) + h = inputs_embeds + else: + if self.embed_tokens is None: + raise ValueError("embed_tokens is None for the first shard.") + + # Check if we need to process vision inputs + if pixel_values is not None and self.is_vlm: + # Use get_input_embeddings for VLM processing + embed_output = self.get_input_embeddings(h, pixel_values, **kwargs) + h = embed_output.inputs_embeds + else: + # Standard text embedding + h = self.embed_tokens(h) + if self.has_norm_in and self.norm_in: h = self.norm_in(h) diff --git a/src/parallax/server/node_chat_http_server.py b/src/parallax/server/node_chat_http_server.py index 6ecbc283..6b8b6ea6 100644 --- 
@dataclass
class VLMInputs:
    """Bundle of vision inputs that travels with a request.

    The first peer consumes ``pixel_values``; later peers receive image
    embeddings already merged into the hidden states and only need the
    remaining metadata.
    """

    pixel_values: Optional[Any] = None
    image_grid_thw: Optional[Any] = None
    image_token_counts: Optional[List[int]] = None
    image_sizes: Optional[List[tuple]] = None
    images_processed: bool = False

    def has_images(self) -> bool:
        """Return True when ``pixel_values`` is present and non-empty."""
        if self.pixel_values is None:
            return False
        # Use len() rather than truthiness: array types may not define bool().
        return len(self.pixel_values) > 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the five fields as a plain dictionary (values are not copied)."""
        return dict(
            pixel_values=self.pixel_values,
            image_grid_thw=self.image_grid_thw,
            image_token_counts=self.image_token_counts,
            image_sizes=self.image_sizes,
            images_processed=self.images_processed,
        )

    @classmethod
    def from_dict(cls, data: Optional[Dict[str, Any]]) -> Optional["VLMInputs"]:
        """Rebuild an instance from ``to_dict`` output; ``None`` maps to ``None``.

        Missing keys fall back to the field defaults.
        """
        if data is None:
            return None
        lookup = data.get
        return cls(
            pixel_values=lookup("pixel_values"),
            image_grid_thw=lookup("image_grid_thw"),
            image_token_counts=lookup("image_token_counts"),
            image_sizes=lookup("image_sizes"),
            images_processed=lookup("images_processed", False),
        )
-161,6 +216,8 @@ def __init__( status: RequestStatus = RequestStatus.PREFILLING, lora_path: Optional[str] = None, return_probs: bool = False, + multimodal_params: Optional[Dict] = None, + vlm_inputs: Optional[VLMInputs] = None, ): if not prompt and not input_ids: raise ValueError("prompt or input_ids cannot be empty.") @@ -171,6 +228,8 @@ def __init__( input_ids=input_ids, sampling_params=sampling_params, lora_path=lora_path, + multimodal_params=multimodal_params, + vlm_inputs=vlm_inputs, ) self.prompt = prompt self.return_probs = return_probs @@ -268,6 +327,7 @@ def __init__( lora_path: Optional[str] = None, token_prob: Optional[float] = None, return_probs: bool = False, + vlm_inputs: Optional[VLMInputs] = None, ): super().__init__( request_id=request_id, @@ -276,6 +336,7 @@ def __init__( input_ids=input_ids, sampling_params=sampling_params, lora_path=lora_path, + vlm_inputs=vlm_inputs, ) # Hidden states from the previous peer's computation. # Shape: @@ -332,6 +393,16 @@ def from_initial_request( else: next_token_id = initial_request.output_ids[-1] + vlm_inputs = None + if initial_request.vlm_inputs is not None: + vlm_inputs = VLMInputs( + pixel_values=None, + image_grid_thw=initial_request.vlm_inputs.image_grid_thw, + image_token_counts=initial_request.vlm_inputs.image_token_counts, + image_sizes=initial_request.vlm_inputs.image_sizes, + images_processed=True, + ) + return IntermediateRequest( request_id=initial_request.request_id, status=initial_request.status, @@ -344,6 +415,7 @@ def from_initial_request( lora_path=lora_path, token_prob=token_prob, return_probs=initial_request.return_probs, + vlm_inputs=vlm_inputs, ) @classmethod @@ -370,6 +442,7 @@ def from_intermediate_request( lora_path=lora_path, token_prob=token_prob, return_probs=old_request.return_probs, + vlm_inputs=old_request.vlm_inputs, # Pass through VLM metadata ) def __repr__(self): diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py index 69ab5fbd..29b335c4 100644 --- 
# VLM model_type -> text-backbone model type used to resolve ModelArgs.
VLM_TEXT_CONFIG_MAP = {
    "qwen3_vl": "qwen3",
    "qwen2_vl": "qwen2",
    "qwen2_5_vl": "qwen2",
    "kimi_vl": "deepseek_v3",
}

# VLM model_type -> (module path, class name) of a stand-alone projector.
VLM_SPECIAL_PROJECTOR_MAP = {
    "llava": ("mlx_vlm.models.llava.llava", "LlavaMultiModalProjector"),
    "llava_next": ("mlx_vlm.models.llava_next.llava_next", "LlavaMultiModalProjector"),
    "kimi_vl": ("mlx_vlm.models.kimi_vl.kimi_vl", "KimiVLMultiModalProjector"),
}


def _get_vlm_classes(
    model_type: str, config: Dict[str, Any]
) -> Tuple[Optional[Type[nn.Module]], Optional[Type[nn.Module]], Optional[Dict[str, Any]]]:
    """
    Resolve the vision classes that belong to a model type.

    The vision tower is ``VisionModel`` from ``mlx_vlm.models.<model_type>``.
    A projector class is resolved only for model types listed in
    ``VLM_SPECIAL_PROJECTOR_MAP``; otherwise it is ``None``.

    Args:
        model_type: The model type from config.json
        config: Full model configuration

    Returns:
        Tuple of (vision_tower_class, projector_class, vision_config).
        (None, None, None) when the config has no ``vision_config`` or when
        importing either class fails (a warning is logged in that case).
    """
    vision_config = config.get("vision_config")
    if vision_config is None:
        return None, None, None

    try:
        tower_module = importlib.import_module(f"mlx_vlm.models.{model_type}")
        tower_cls = tower_module.VisionModel

        if model_type not in VLM_SPECIAL_PROJECTOR_MAP:
            projector_cls = None
            logger.info(f"Loaded VLM classes for {model_type}: VisionModel (projector integrated)")
        else:
            proj_module_path, proj_class_name = VLM_SPECIAL_PROJECTOR_MAP[model_type]
            projector_cls = getattr(importlib.import_module(proj_module_path), proj_class_name)
            logger.info(f"Loaded VLM classes for {model_type}: VisionModel + {proj_class_name}")
    except (ImportError, AttributeError) as e:
        logger.warning(f"Failed to load VLM classes for {model_type}: {e}")
        return None, None, None

    return tower_cls, projector_cls, vision_config
if block_class is None: raise ValueError(f"block_class not found for architecture: {architecture}") - num_hidden_layers = config.get("num_hidden_layers", 0) - current_start_layer = self.start_layer if self.start_layer is not None else 0 - current_end_layer = self.end_layer if self.end_layer is not None else num_hidden_layers - # We need the model object to know its structure and which layers it owns. # This part mirrors the logic from the provided utils.py to get model_args. model_type = config.get("model_type") if not model_type: raise ValueError("model_type not found in config.json") - if model_type in MODEL_CLASS_MAP: - model_class = MODEL_CLASS_MAP[model_type] + config_for_args = config + model_class_type = model_type + + if model_type in VLM_TEXT_CONFIG_MAP: + text_config = config.get("text_config", {}) + if text_config: + config_for_args = {**config, **text_config} + if "num_hidden_layers" not in config and "num_hidden_layers" in text_config: + config["num_hidden_layers"] = text_config["num_hidden_layers"] + model_class_type = VLM_TEXT_CONFIG_MAP[model_type] + logger.info( + f"VLM model {model_type} using {model_class_type} ModelArgs with text_config" + ) + + num_hidden_layers = config.get("num_hidden_layers", 0) + current_start_layer = self.start_layer if self.start_layer is not None else 0 + current_end_layer = self.end_layer if self.end_layer is not None else num_hidden_layers + + if model_class_type in MODEL_CLASS_MAP: + model_class = MODEL_CLASS_MAP[model_class_type] else: - model_class = f"mlx_lm.models.{model_type}" + model_class = f"mlx_lm.models.{model_class_type}" try: arch_module = importlib.import_module(model_class) model_args_class = getattr(arch_module, "ModelArgs") - model_args = model_args_class.from_dict(config) + model_args = model_args_class.from_dict(config_for_args) except (ImportError, AttributeError) as e: - raise ValueError(f"Failed to load architecture for model_type '{model_type}'.") from e + raise ValueError( + f"Failed to load 
architecture for model_type '{model_type}' (using {model_class})." + ) from e dtype = getattr(mx, config.get("torch_dtype", "bfloat16")) @@ -284,6 +359,25 @@ def load( model_id = model_id.split("/")[-1] else: # If it's already a clean name or a local path (take basename) model_id = pathlib.Path(model_id).name + + vision_tower_class, projector_class, vision_config = _get_vlm_classes(model_type, config) + is_vlm = vision_config is not None and vision_tower_class is not None + + image_token_index = ( + config.get("image_token_index") + or config.get("image_token_id") + or config.get("media_placeholder_token_id") + ) + vision_feature_layer = config.get("vision_feature_layer", -2) + vision_feature_select_strategy = config.get("vision_feature_select_strategy", "default") + + if is_vlm: + logger.info( + f"Detected VLM model: {model_type}, " + f"image_token_index={image_token_index}, " + f"vision_feature_layer={vision_feature_layer}" + ) + model_shard = ShardedModel( config=model_args, model_id=model_id, @@ -291,6 +385,13 @@ def load( end_layer=current_end_layer, block_class=block_class, dtype=dtype, + # VLM parameters + vision_config=vision_config if is_vlm else None, + vision_tower_class=vision_tower_class, + multi_modal_projector_class=projector_class, + image_token_index=image_token_index, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, ) weight_files = glob.glob(str(model_path / "model*.safetensors")) @@ -313,6 +414,7 @@ def load( is_first_shard=model_shard.is_first_shard, is_last_shard=model_shard.is_last_shard, config=config, + is_vlm=is_vlm, ) if not weight_files and strict: @@ -321,7 +423,22 @@ def load( # Instead of loading all weights, we iterate through files and keys, # loading only what we need. 
shard_weights = {} - layer_key_prefix = "model.layers" # Common prefix + + layer_key_prefixes = [ + ("language_model.model.layers.", 3), # mlx-vlm style: parts[3] is layer index + ("model.language_model.layers.", 3), # HF VLM style: parts[3] is layer index + ("model.layers.", 2), # Standard style: parts[2] is layer index + ] + + vlm_weight_prefixes = [ + "vision_tower.", + "vision_model.", + "visual.", # Qwen-VL style + "multi_modal_projector.", + "mm_projector.", + ] + + tie_word_embeddings = get_config_value(config, "tie_word_embeddings", False) for file_idx, wf in enumerate(weight_files): logger.debug( @@ -333,44 +450,85 @@ def load( is_needed = False remapped_key = None - # Check if the key belongs to the shard and remap it - if ( - model_shard.is_first_shard - and "embed_tokens" in key - and key.startswith("model.") - ): + if model_shard.is_first_shard and "embed_tokens" in key: is_needed = True - remapped_key = key.replace("model.", "", 1) - if model_shard.is_last_shard and config.get("tie_word_embeddings", False): - # Also add lm_head mapping for tied embeddings + if "language_model.model.embed_tokens" in key: + remapped_key = key.replace("language_model.model.", "") + elif "language_model.embed_tokens" in key: + remapped_key = key.split("language_model.")[-1] + elif key.startswith("model."): + remapped_key = key.replace("model.", "", 1) + else: + remapped_key = key + if model_shard.is_last_shard and tie_word_embeddings: lm_head_key = remapped_key.replace("embed_tokens", "lm_head") shard_weights[lm_head_key] = f[key] + elif model_shard.is_last_shard: - if "model.norm" in key: - is_needed = True - remapped_key = key.replace("model.", "", 1) + if ".norm." 
in key or key.endswith(".norm.weight"): + is_final_norm = ( + "language_model.model.norm" in key + or "language_model.norm" in key + or (key.startswith("model.norm") and "layers" not in key) + ) + if is_final_norm: + is_needed = True + if "language_model.model.norm" in key: + remapped_key = key.replace("language_model.model.", "") + elif "language_model.norm" in key: + remapped_key = key.split("language_model.")[-1] + else: + remapped_key = key.replace("model.", "", 1) if "lm_head" in key: is_needed = True - remapped_key = key - elif ( - config.get("tie_word_embeddings", False) - and "embed_tokens" in key - and key.startswith("model.embed_tokens") - ): + if key.startswith("language_model."): + remapped_key = key.replace("language_model.", "") + else: + remapped_key = key + elif tie_word_embeddings and "embed_tokens" in key: is_needed = True - remapped_key = key.replace("model.", "", 1).replace( - "embed_tokens", "lm_head" - ) - if layer_key_prefix in key: - try: - parts = key.split(".") - layer_idx = int(parts[2]) - if current_start_layer <= layer_idx < current_end_layer: + if "language_model.model.embed_tokens" in key: + remapped_key = key.replace("language_model.model.", "").replace( + "embed_tokens", "lm_head" + ) + elif "language_model.embed_tokens" in key: + remapped_key = key.split("language_model.")[-1].replace( + "embed_tokens", "lm_head" + ) + else: + remapped_key = key.replace("model.", "", 1).replace( + "embed_tokens", "lm_head" + ) + + # VLM: Load vision tower and projector weights on first shard + if model_shard.is_first_shard and is_vlm: + for prefix in vlm_weight_prefixes: + if key.startswith(prefix): is_needed = True - local_layer_idx = layer_idx - current_start_layer - remapped_key = f"layers.{local_layer_idx}.{'.'.join(parts[3:])}" - except (ValueError, IndexError): - continue + remapped_key = key + break + if key.startswith(f"model.{prefix}"): + is_needed = True + remapped_key = key.replace("model.", "", 1) + break + + if not is_needed: + for 
layer_prefix, layer_idx_pos in layer_key_prefixes: + if layer_prefix in key: + try: + parts = key.split(".") + layer_idx = int(parts[layer_idx_pos]) + if current_start_layer <= layer_idx < current_end_layer: + is_needed = True + local_layer_idx = layer_idx - current_start_layer + # Remap to layers.{local_idx}.{rest} + rest_parts = parts[layer_idx_pos + 1 :] + remapped_key = ( + f"layers.{local_layer_idx}.{'.'.join(rest_parts)}" + ) + break + except (ValueError, IndexError): + continue # If the key is needed, load only that tensor from the file if is_needed: @@ -431,19 +589,52 @@ def class_predicate(p, m): class_predicate=class_predicate, ) - model_shard.load_weights(list(shard_weights.items()), strict=strict) + # Log weight keys before loading + logger.info( + f"Loading {len(shard_weights)} weights. Sample keys: {list(shard_weights.keys())[:20]}" + ) + + # Try strict mode first to catch any mismatch, then fall back to non-strict + try: + model_shard.load_weights(list(shard_weights.items()), strict=True) + except Exception as e: + logger.warning(f"Strict weight loading failed: {e}. 
Retrying with strict=False.") + model_shard.load_weights(list(shard_weights.items()), strict=False) model_shard.shard_layers() + # Log VLM-specific weight loading info + if is_vlm and model_shard.is_first_shard: + vlm_weight_count = sum( + 1 + for k in shard_weights.keys() + if any( + k.startswith(p) + for p in [ + "vision_tower", + "vision_model", + "visual", + "multi_modal_projector", + "mm_projector", + ] + ) + ) + logger.info(f"Loaded {vlm_weight_count} VLM weights (vision_tower + projector)") + + logger.info(f"Total weights loaded: {len(shard_weights)}") + shard_weights.clear() mx.eval(model_shard.parameters()) # Synchronize processes to avoid timeout mx.eval(mx.distributed.all_sum(mx.array(1.0))) model_shard.eval() + + vlm_info = f", VLM={is_vlm}" if is_vlm else "" logger.info( - "Successfully loaded model shard (layers [%d-%d)), memory usage: %.3f GB", + "Successfully loaded model shard (layers [%d-%d)%s), memory usage: %.3f GB", current_start_layer, current_end_layer, + vlm_info, mx.get_active_memory() / 1024**3, ) return model_shard, config, tokenizer diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 56720040..819b9d1f 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -1,15 +1,12 @@ """ Store information about a SGLang batch. 
-The following is the flow of data structures for a batch in SGLang: - -ScheduleBatch -> ModelWorkerBatch -> ForwardBatch """ from types import SimpleNamespace -from typing import List, Optional +from typing import Any, List, Optional import torch -from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.schedule_batch import MultimodalInputs, Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import ( @@ -23,6 +20,7 @@ from parallax.server.sampling.sampling_params import ( SamplingParams as ParallaxSamplingParams, ) +from parallax.sglang.multimodal_utils import process_multimodal_request from parallax_utils.logging_config import get_logger logger = get_logger(__name__) @@ -48,20 +46,57 @@ def transform_sampling_params_to_sglang(old_params: ParallaxSamplingParams) -> S def transform_requests_to_sglang( - old_requests: List[Request], page_tree_cache: Optional[PageRadixCache] = None + old_requests: List[Request], + page_tree_cache: Optional[PageRadixCache] = None, + processor: Optional[Any] = None, + hf_config: Optional[dict] = None, + tokenizer: Optional[Any] = None, ) -> List[Req]: - """Transforms Parallax Request to SGLang.Req format""" reqs = [] + mm_config = hf_config or {} + for old_req in old_requests: sampling_params = transform_sampling_params_to_sglang(old_req.sampling_params) req = Req( rid=old_req.request_id, origin_input_text="", - origin_input_ids=old_req.input_ids, + origin_input_ids=list(old_req.input_ids), sampling_params=sampling_params, lora_id=old_req.lora_id, ) + if old_req.multimodal_params is not None: + try: + if "mm_items" in old_req.multimodal_params: + req.multimodal_inputs = MultimodalInputs.from_dict(old_req.multimodal_params) + + elif "images" in old_req.multimodal_params and processor is not None: + image_urls = old_req.multimodal_params["images"] + 
multimodal_inputs, padded_input_ids = process_multimodal_request( + image_urls=image_urls, + input_ids=old_req.input_ids, + processor=processor, + tokenizer=tokenizer, + mm_config=mm_config, + ) + if multimodal_inputs is not None: + req.multimodal_inputs = multimodal_inputs + req.origin_input_ids = padded_input_ids + + else: + logger.warning( + f"Cannot process multimodal_params: no 'mm_items' or 'images' key, " + f"or processor not available. " + f"Params keys: {old_req.multimodal_params.keys()}, Processor: {processor is not None}" + ) + # Don't assign raw dict - SGLang expects MultimodalInputs object + req.multimodal_inputs = None + + except Exception as e: + logger.exception(f"Failed to construct MultimodalInputs: {e}") + # Don't assign raw dict - SGLang expects MultimodalInputs object + req.multimodal_inputs = None + # Debug: Log before cache lookup if page_tree_cache is not None: logger.debug( @@ -92,12 +127,16 @@ def form_sgl_batch_prefill( requests: List[Request], model_runner: ModelRunner, page_tree_cache: Optional[PageRadixCache] = None, + processor: Optional[Any] = None, + hf_config: Optional[dict] = None, + tokenizer: Optional[Any] = None, ) -> ForwardBatch: - """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" - sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache) + sgl_reqs = transform_requests_to_sglang( + requests, page_tree_cache, processor, hf_config, tokenizer + ) - def dummy_evict(*args): + def dummy_function(*args): pass dummy_tree_cache = SimpleNamespace( @@ -105,8 +144,12 @@ def dummy_evict(*args): device=model_runner.device, token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, evictable_size=0, + sliding_window_size=0, ) - dummy_tree_cache.evict = dummy_evict + dummy_tree_cache.evict = dummy_function + dummy_tree_cache.supports_swa = dummy_function + dummy_tree_cache.is_chunk_cache = dummy_function + dummy_tree_cache.supports_mamba = dummy_function schedule_batch = 
ScheduleBatch.init_new( reqs=sgl_reqs, req_to_token_pool=model_runner.req_to_token_pool, diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 0c2667a9..f0edc61c 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -40,6 +40,7 @@ from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( set_layer_range_for_filtering, ) +from parallax.utils.config_utils import ModelConfigAccessor from parallax.utils.tokenizer_utils import load_tokenizer logger = logging.getLogger(__name__) @@ -74,8 +75,14 @@ def __init__( """Add pp_start_layer and pp_end_layer for decentralized model""" self.pp_start_layer = pp_start_layer self.pp_end_layer = pp_end_layer - num_hidden_layers = model_config.hf_config.num_hidden_layers - set_layer_range_for_filtering(pp_start_layer, pp_end_layer, num_hidden_layers) + config_accessor = ModelConfigAccessor(model_config.hf_config) + num_hidden_layers = config_accessor.get_num_hidden_layers() + if num_hidden_layers is None: + raise ValueError("num_hidden_layers is required but not found in model config") + is_vlm = config_accessor.is_vlm + set_layer_range_for_filtering( + pp_start_layer, pp_end_layer, num_hidden_layers, is_vlm=is_vlm + ) super().__init__( model_config=model_config, @@ -278,6 +285,7 @@ def initialize_sgl_model_runner( - model_runner: SGL model runner - config: model config driven by mlx-lm - tokenizer: tokenizer driven by mlx-lm + - processor: optional processor for multimodal models """ apply_parallax_sglang_monkey_patch() @@ -302,6 +310,16 @@ def initialize_sgl_model_runner( tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") + # Load processor if available (for multimodal models) + processor = None + try: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + logger.info(f"Loaded processor for 
model {model_path}") + except Exception as e: + logger.debug(f"No processor loaded (normal for text-only models): {e}") + if nccl_port is None: nccl_port = random.randint(4000, 5000) @@ -374,7 +392,7 @@ def initialize_sgl_model_runner( dp_rank=dp_rank, dp_size=dp_size, ) - return model_runner, config, tokenizer + return model_runner, config, tokenizer, processor def refit_sgl_model( diff --git a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py index 7bd48082..261bbc20 100644 --- a/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py +++ b/src/parallax/sglang/monkey_patch_utils/weight_loader_filter.py @@ -11,11 +11,17 @@ _layer_range_cache = {} -def set_layer_range_for_filtering(pp_start_layer: int, pp_end_layer: int, num_hidden_layers: int): +def set_layer_range_for_filtering( + pp_start_layer: int, + pp_end_layer: int, + num_hidden_layers: int, + is_vlm: bool = False, +): global _layer_range_cache _layer_range_cache["pp_start_layer"] = pp_start_layer _layer_range_cache["pp_end_layer"] = pp_end_layer _layer_range_cache["num_hidden_layers"] = num_hidden_layers + _layer_range_cache["is_vlm"] = is_vlm def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: @@ -24,6 +30,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: pp_start_layer = _layer_range_cache.get("pp_start_layer") pp_end_layer = _layer_range_cache.get("pp_end_layer") num_hidden_layers = _layer_range_cache.get("num_hidden_layers") + is_vlm = _layer_range_cache.get("is_vlm", False) if pp_start_layer is None or pp_end_layer is None: logger.debug("No layer range set, loading all weight files") @@ -43,6 +50,7 @@ def _filter_weight_files_by_cache(hf_weights_files: List[str]) -> List[str]: end_layer=pp_end_layer, is_first_shard=is_first_shard, is_last_shard=is_last_shard, + is_vlm=is_vlm, ) return filtered_files diff --git a/src/parallax/sglang/multimodal_utils.py 
b/src/parallax/sglang/multimodal_utils.py new file mode 100644 index 00000000..d604d5a0 --- /dev/null +++ b/src/parallax/sglang/multimodal_utils.py @@ -0,0 +1,414 @@ +""" +Multimodal utilities for SGLang backend. +""" + +import logging +from io import BytesIO +from typing import Any, List, Optional, Tuple + +import requests +import torch +from PIL import Image +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) + +logger = logging.getLogger(__name__) + + +def load_image(url: Any) -> Image.Image: + import base64 + + if isinstance(url, dict): + url = url.get("url") + if not isinstance(url, str): + raise ValueError(f"Unsupported image url type: {type(url)}") + + if url.startswith("http"): + response = requests.get(url, timeout=10) + response.raise_for_status() + image_data = BytesIO(response.content) + return Image.open(image_data).convert("RGB") + elif url.startswith("data:image"): + header, encoded = url.split(",", 1) + # Strip whitespace from base64 encoded data (some clients add spaces after comma) + encoded = encoded.strip() + image_data = BytesIO(base64.b64decode(encoded)) + return Image.open(image_data).convert("RGB") + else: + return Image.open(url).convert("RGB") + + +def process_images( + image_urls: List[Any], + processor: Any, + input_text: str = "", + mm_config: Optional[dict] = None, +) -> tuple[List[MultimodalDataItem], Optional[torch.Tensor], List[int]]: + if not image_urls: + return [], None, [] + + mm_items = [] + images = [] + + for url in image_urls: + try: + image = load_image(url) + images.append(image) + logger.debug(f"Loaded image: size={image.size}, mode={image.mode}") + except Exception as e: + logger.exception(f"Failed to load image {url}: {e}") + continue + + if not images: + return [], None, [] + + try: + # Check if this is a Kimi K2.5 processor (has different interface) + processor_class_name = 
processor.__class__.__name__ + if processor_class_name == "KimiK25Processor": + # Kimi K2.5 requires special handling: + # 1. Expand image tokens based on actual image size + # 2. Use 'medias' parameter instead of 'images' + + # The image token for Kimi K2.5 is <|media_pad|> + image_token = "<|media_pad|>" + + # Expand single image tokens to the correct number of tokens + if image_token in input_text and hasattr(processor, "media_processor"): + parts = input_text.split(image_token) + result = [parts[0]] + for i, (image, part) in enumerate(zip(images, parts[1:])): + try: + # Calculate how many tokens this image needs + num_tokens = processor.media_processor.media_tokens_calculator( + {"type": "image", "image": image} + ) + logger.debug(f"Kimi K2.5: Image {i} expanded to {num_tokens} tokens") + except Exception as e: + logger.warning( + f"Failed to calculate media tokens for image {i}: {e}, using 1" + ) + num_tokens = 1 + result.append(image_token * num_tokens + part) + input_text = "".join(result) + logger.debug(f"Kimi K2.5: Expanded input_text length: {len(input_text)}") + + # Kimi K2.5 requires 'medias' parameter with specific format + medias = [{"type": "image", "image": img} for img in images] + inputs = processor(medias=medias, text=input_text, return_tensors="pt") + else: + # Standard HuggingFace processor interface + inputs = processor(text=input_text, images=images, return_tensors="pt") + + if inputs is None: + logger.error("Processor returned None") + return [], None, [] + + # Debug: Log all keys returned by the processor + logger.debug(f"Processor output keys: {list(inputs.keys())}") + for key, value in inputs.items(): + if hasattr(value, "shape"): + logger.debug(f" {key}: shape={value.shape}, dtype={value.dtype}") + elif hasattr(value, "__len__"): + logger.debug(f" {key}: len={len(value)}, type={type(value)}") + else: + logger.debug(f" {key}: {type(value)}") + + pixel_values = inputs.get("pixel_values") + if pixel_values is None: + logger.error("Processor 
output missing pixel_values") + return [], None, [] + + logger.debug( + f"pixel_values shape: {pixel_values.shape}, dtype: {pixel_values.dtype}, device: {pixel_values.device}" + ) + logger.debug( + f"pixel_values stats: min={pixel_values.min().item():.4f}, max={pixel_values.max().item():.4f}, mean={pixel_values.mean().item():.4f}" + ) + + expanded_input_ids = inputs.get("input_ids") + if expanded_input_ids is not None: + expanded_input_ids = expanded_input_ids.flatten().tolist() + else: + expanded_input_ids = [] + + logger.debug(f"Processor expanded input_ids length: {len(expanded_input_ids)}") + + # Handle different field names: Kimi K2.5 uses 'grid_thws', others use 'image_grid_thw' + is_kimi_k25 = processor_class_name == "KimiK25Processor" + image_grid_thw = inputs.get("image_grid_thw") + if image_grid_thw is None: + image_grid_thw = inputs.get("grid_thws") + image_sizes = inputs.get("image_sizes") + + logger.debug(f"image_grid_thw: {image_grid_thw}, is_kimi_k25: {is_kimi_k25}") + + # Determine the correct field name for grid data + # Kimi K2.5 expects 'grid_thws', others expect 'image_grid_thw' + grid_field_name = "grid_thws" if is_kimi_k25 else "image_grid_thw" + + model_specific_data = {} + if image_grid_thw is not None: + model_specific_data[grid_field_name] = image_grid_thw + if image_sizes is not None: + model_specific_data["image_sizes"] = image_sizes + + if image_grid_thw is not None and len(image_grid_thw) == len(images): + num_images = len(images) + + patches_per_image = [] + for i in range(num_images): + grid = image_grid_thw[i] + if isinstance(grid, torch.Tensor): + num_patches = int(torch.prod(grid).item()) + else: + num_patches = int(torch.prod(torch.tensor(grid)).item()) + patches_per_image.append(num_patches) + + patch_start = 0 + for i in range(num_images): + num_patches = patches_per_image[i] + item_pixel_values = pixel_values[patch_start : patch_start + num_patches] + item_grid_thw = image_grid_thw[i : i + 1] + + item = MultimodalDataItem( + 
modality=Modality.IMAGE, + feature=item_pixel_values, + model_specific_data={ + grid_field_name: item_grid_thw, + }, + ) + mm_items.append(item) + patch_start += num_patches + else: + item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=pixel_values, + model_specific_data=model_specific_data, + ) + mm_items.append(item) + + return mm_items, image_grid_thw, expanded_input_ids + + except Exception as e: + logger.exception(f"Failed to process images: {e}") + return [], None, [] + + +def get_image_token_offsets( + input_ids: List[int], + image_token_id: Optional[int], + vision_start_id: Optional[int] = None, + vision_end_id: Optional[int] = None, +) -> List[Tuple[int, int]]: + offsets = [] + + if vision_start_id is not None and vision_end_id is not None: + start_indices = [i for i, tok in enumerate(input_ids) if tok == vision_start_id] + end_indices = [i for i, tok in enumerate(input_ids) if tok == vision_end_id] + + for start, end in zip(start_indices, end_indices): + if start < end: + offsets.append((start + 1, end - 1)) + elif image_token_id is not None: + start = None + for i, tok in enumerate(input_ids): + if tok == image_token_id: + if start is None: + start = i + elif start is not None: + offsets.append((start, i - 1)) + start = None + if start is not None: + offsets.append((start, len(input_ids) - 1)) + + return offsets + + +def compute_mrope_positions( + input_ids: List[int], + image_grid_thw: Optional[torch.Tensor], + mm_config: dict, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + + import logging + + logger = logging.getLogger(__name__) + + seq_len = len(input_ids) + + def get_default_positions(): + positions = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).repeat(3, 1) + delta = torch.zeros(1, dtype=torch.long) + return positions, delta + + if image_grid_thw is None: + return get_default_positions() + + image_token_id = mm_config.get("image_token_id") + vision_start_id = mm_config.get("vision_start_token_id") + 
video_token_id = mm_config.get("video_token_id") + model_type = mm_config.get("model_type") + vision_config = mm_config.get("vision_config", {}) + + spatial_merge_size = ( + vision_config.get("spatial_merge_size") if isinstance(vision_config, dict) else None + ) + tokens_per_second = ( + vision_config.get("tokens_per_second") if isinstance(vision_config, dict) else None + ) + + if image_token_id is None or spatial_merge_size is None: + logger.debug( + f"Missing mrope config: image_token_id={image_token_id}, " + f"spatial_merge_size={spatial_merge_size}. Using default positions." + ) + return get_default_positions() + + try: + from sglang.srt.layers.rotary_embedding import MRotaryEmbedding + + input_ids_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=spatial_merge_size, + image_token_id=image_token_id, + video_token_id=video_token_id, + vision_start_token_id=vision_start_id, + model_type=model_type, + input_ids=input_ids_tensor, + image_grid_thw=image_grid_thw, + tokens_per_second=tokens_per_second, + ) + + mrope_positions = mrope_positions.squeeze(1) + return mrope_positions, mrope_position_delta + + except Exception as e: + logger.warning(f"Failed to compute mrope_positions: {e}. 
Using default positions.") + return get_default_positions() + + +def prepare_sglang_multimodal_inputs( + mm_items: List[MultimodalDataItem], + image_grid_thw: Optional[torch.Tensor] = None, + mm_config: Optional[dict] = None, + input_ids: Optional[List[int]] = None, +) -> MultimodalInputs: + mm_config = mm_config or {} + + # Extract token IDs from config + image_token_id = mm_config.get("image_token_id") + vision_start_id = mm_config.get("vision_start_token_id") + vision_end_id = mm_config.get("vision_end_token_id") + video_token_id = mm_config.get("video_token_id") + audio_token_id = mm_config.get("audio_token_id") + + for item in mm_items: + if item.pad_value is None: + item.set_pad_value() + + mrope_positions = None + mrope_position_delta = None + + if input_ids is not None: + mrope_positions, mrope_position_delta = compute_mrope_positions( + input_ids, image_grid_thw, mm_config + ) + + return MultimodalInputs( + mm_items=mm_items, + im_token_id=image_token_id, + im_start_id=vision_start_id, + im_end_id=vision_end_id, + video_token_id=video_token_id, + audio_token_id=audio_token_id, + mrope_positions=mrope_positions, + mrope_position_delta=mrope_position_delta, + ) + + +def process_multimodal_request( + image_urls: List[Any], + input_ids: List[int], + processor: Any, + tokenizer: Any, + mm_config: dict, +) -> Tuple[Optional[MultimodalInputs], List[int]]: + input_text = "" + if tokenizer is not None: + try: + input_text = tokenizer.decode(input_ids, skip_special_tokens=False) + logger.debug(f"Decoded input text (length={len(input_text)}): {input_text[:100]}...") + except Exception as e: + logger.warning(f"Failed to decode input_ids: {e}") + + mm_items, image_grid_thw, expanded_input_ids = process_images( + image_urls, + processor, + input_text=input_text, + mm_config=mm_config, + ) + + if not mm_items: + return None, list(input_ids) + + if expanded_input_ids and len(expanded_input_ids) > len(input_ids): + logger.debug(f"Using expanded input_ids: {len(input_ids)} 
-> {len(expanded_input_ids)}") + input_ids_for_offsets = expanded_input_ids + else: + input_ids_for_offsets = list(input_ids) + + image_token_id = mm_config.get("image_token_id") + vision_start_id = mm_config.get("vision_start_token_id") + vision_end_id = mm_config.get("vision_end_token_id") + + offsets = get_image_token_offsets( + input_ids_for_offsets, + image_token_id, + vision_start_id, + vision_end_id, + ) + + if len(offsets) == len(mm_items): + for item, offset in zip(mm_items, offsets): + item.offsets = [offset] + elif len(offsets) > 0: + for item in mm_items: + item.offsets = offsets + + for item in mm_items: + item.set_pad_value() + + combined_grid_thw = None + if image_grid_thw is not None: + combined_grid_thw = image_grid_thw + else: + grids = [] + for item in mm_items: + grid = item.model_specific_data.get("image_grid_thw") + if grid is not None: + grids.append(grid) + if grids: + combined_grid_thw = torch.cat(grids, dim=0) + + multimodal_inputs = prepare_sglang_multimodal_inputs( + mm_items=mm_items, + image_grid_thw=combined_grid_thw, + mm_config=mm_config, + input_ids=input_ids_for_offsets, + ) + + padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens() + padded_input_ids = padding_pattern.pad_input_tokens(input_ids_for_offsets, multimodal_inputs) + + logger.debug( + f"Successfully processed {len(mm_items)} images, " + f"offsets={offsets}, padded_input_ids_len={len(padded_input_ids)}" + ) + + return multimodal_inputs, padded_input_ids diff --git a/src/parallax/utils/config_utils.py b/src/parallax/utils/config_utils.py new file mode 100644 index 00000000..7e65fe50 --- /dev/null +++ b/src/parallax/utils/config_utils.py @@ -0,0 +1,239 @@ +""" +Configuration utilities for handling model configs. + +Provides VLM-aware config access for models with text_config/vision_config structure. 
+""" + +from typing import Any, List, Optional + +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +class ModelConfigAccessor: + """ + VLM-aware model configuration accessor. + + VLM (Vision-Language Models) typically have a nested config structure: + - text_config: Contains text model parameters (num_hidden_layers, eos_token_id, etc.) + - vision_config: Contains vision encoder parameters + + This class provides unified access to config values, automatically handling + the nested structure for VLM models. + """ + + def __init__(self, config: Any): + """ + Initialize the config accessor. + + Args: + config: Model configuration (dict or object with attributes) + """ + self._config = config + self._is_vlm: Optional[bool] = None + + @property + def is_vlm(self) -> bool: + """Check if the model is a VLM (has both text_config and vision_config).""" + if self._is_vlm is None: + text_config = self._raw_get("text_config") + vision_config = self._raw_get("vision_config") + self._is_vlm = text_config is not None and vision_config is not None + return self._is_vlm + + @property + def text_config(self) -> Optional[Any]: + """Get the text_config if available.""" + return self._raw_get("text_config") + + @property + def vision_config(self) -> Optional[Any]: + """Get the vision_config if available.""" + return self._raw_get("vision_config") + + def _raw_get(self, key: str, default: Any = None) -> Any: + """ + Low-level config access without VLM-aware logic. 
+ + Args: + key: Configuration key + default: Default value if key not found + + Returns: + Config value or default + """ + if isinstance(self._config, dict): + return self._config.get(key, default) + return getattr(self._config, key, default) + + def _get_from_subconfig(self, subconfig: Any, key: str) -> Optional[Any]: + """Get a value from a subconfig (dict or object).""" + if subconfig is None: + return None + if isinstance(subconfig, dict): + return subconfig.get(key) + return getattr(subconfig, key, None) + + def get( + self, + key: str, + default: Any = None, + fallback_keys: Optional[List[str]] = None, + prefer_text_config: bool = False, + ) -> Any: + """ + Get a configuration value with VLM-aware logic. + + For VLM models, text-related parameters (num_hidden_layers, eos_token_id, etc.) + are typically stored in 'text_config' rather than at the root level. + + Args: + key: The primary key to look up. + default: Default value if key is not found anywhere. + fallback_keys: Alternative keys to try if the primary key is not found. + prefer_text_config: If True and this is a VLM model, look in text_config first. + + Returns: + The configuration value or default. 
+ """ + # For VLM models, prefer text_config for text-related parameters + if prefer_text_config and self.is_vlm: + value = self._get_from_subconfig(self.text_config, key) + if value is not None: + return value + + # Try primary key at root level + value = self._raw_get(key) + if value is not None: + return value + + # Try fallback keys at root level + if fallback_keys: + for fallback_key in fallback_keys: + value = self._raw_get(fallback_key) + if value is not None: + return value + + # For non-VLM models, also check text_config as fallback + # (some models might have this structure without being full VLMs) + if not self.is_vlm and prefer_text_config: + value = self._get_from_subconfig(self.text_config, key) + if value is not None: + return value + + return default + + def get_num_hidden_layers(self) -> Optional[int]: + """Get the number of hidden layers (common parameter).""" + return self.get( + "num_hidden_layers", + fallback_keys=["n_layer", "num_layers"], + prefer_text_config=True, + ) + + def get_eos_token_id(self) -> Optional[int]: + """Get the EOS token ID.""" + return self.get("eos_token_id", prefer_text_config=True) + + def build_mm_config(self) -> dict: + """ + Build the multimodal configuration dictionary. + + Returns: + Dictionary containing multimodal-related config values. 
+ """ + vision_config_raw = self._raw_get("vision_config", {}) + + # Normalize vision_config to dict format + if vision_config_raw is None: + vision_config = {} + elif isinstance(vision_config_raw, dict): + vision_config = vision_config_raw + else: + # Convert object-style config to dict + vision_config = { + "spatial_merge_size": getattr(vision_config_raw, "spatial_merge_size", None), + "tokens_per_second": getattr(vision_config_raw, "tokens_per_second", None), + } + + # Get image_token_id with fallbacks for different models + # Kimi K2.5 uses 'media_placeholder_token_id' instead of 'image_token_id' + image_token_id = self._raw_get("image_token_id") + if image_token_id is None: + image_token_id = self._raw_get("media_placeholder_token_id") + + return { + "model_type": self._raw_get("model_type"), + "image_token_id": image_token_id, + "vision_start_token_id": self._raw_get("vision_start_token_id"), + "vision_end_token_id": self._raw_get("vision_end_token_id"), + "video_token_id": self._raw_get("video_token_id"), + "audio_token_id": self._raw_get("audio_token_id"), + "vision_config": vision_config, + } + + +# ============================================================================ +# Convenience functions for simple use cases +# ============================================================================ + + +def is_vlm_model(config: Any) -> bool: + """ + Check if the config represents a VLM (Vision-Language Model). + + VLM models have both text_config and vision_config. 
def get_config_value(config: Any, key: str, default: Any = None) -> Any:
    """
    Get config value with text_config fallback for VLM models.

    This is a simple function interface for cases where you don't need
    the full ModelConfigAccessor functionality.

    The root level is consulted first; if the key is absent there (or stored
    as None), the nested ``text_config`` is consulted. A value of None is
    treated as "not set" at both levels, so ``default`` is returned instead of
    leaking a None from ``text_config`` to the caller.

    Args:
        config: Model configuration (dict or object)
        key: Configuration key to look up
        default: Default value if not found

    Returns:
        Configuration value or default
    """

    def _lookup(container: Any, name: str) -> Any:
        # Uniform access for dict-style and attribute-style configs.
        if container is None:
            return None
        if isinstance(container, dict):
            return container.get(name)
        return getattr(container, name, None)

    # Try root level first
    value = _lookup(config, key)
    if value is not None:
        return value

    # Fallback to text_config (for VLM models)
    value = _lookup(_lookup(config, "text_config"), key)
    return default if value is None else value
def get_tool_call_stop_token_ids(tokenizer) -> List[int]:
    """Return token IDs that should act as *stop tokens* for tool call generation.

    When the model generates one of these tokens the scheduler should treat it
    as end-of-sequence so that the HTTP handler can inspect the generated text
    and extract tool calls.

    Note: tool call *parsing* (``has_tool_calling``, ``tool_parser``, etc.) is
    handled automatically by the updated ``mlx-lm`` ``TokenizerWrapper``.
    This function only provides the stop-token IDs that the parallax scheduler
    needs to halt generation at tool-call boundaries.

    Args:
        tokenizer: Any object; if it exposes a callable ``get_vocab`` the
            returned mapping (token string -> token ID) is searched.

    Returns:
        The IDs of the known markers present in the vocabulary, without
        duplicates, in marker order. Empty if the tokenizer has no usable
        vocabulary.
    """
    get_vocab = getattr(tokenizer, "get_vocab", None)
    # Tolerate tokenizers without get_vocab and ones that return None.
    vocab = (get_vocab() if callable(get_vocab) else None) or {}

    # Markers whose token IDs should halt generation
    markers = (
        "<|tool_calls_section_end|>",  # Kimi K2 / K2.5
        "<|im_end|>",  # common chat turn-end token
    )

    found = (vocab.get(marker) for marker in markers)
    # dict.fromkeys de-duplicates while keeping first-seen (marker) order,
    # unlike set(), whose iteration order is unrelated to the marker list.
    return list(dict.fromkeys(tid for tid in found if tid is not None))
logger.warning( + f"Tool call rejected: function '{name}' not in available tools {valid_names}" + ) + return False + return True + def _parse_tool_text(self, tool_text: str) -> Tuple[List[Dict[str, Any]], Optional[str]]: try: parsed = self.tool_parser(tool_text, self.tools) except Exception: fallback_text = f"{self.tool_call_start}{tool_text}{self.tool_call_end}" return [], fallback_text + + valid_names = self._get_available_tool_names() + if isinstance(parsed, list): - return [self._format_tool_call(tc) for tc in parsed], None + # Filter out tool calls with hallucinated function names + valid_calls = [tc for tc in parsed if self._validate_tool_name(tc, valid_names)] + if not valid_calls: + fallback_text = f"{self.tool_call_start}{tool_text}{self.tool_call_end}" + return [], fallback_text + return [self._format_tool_call(tc) for tc in valid_calls], None + + # Single tool call + if not self._validate_tool_name(parsed, valid_names): + fallback_text = f"{self.tool_call_start}{tool_text}{self.tool_call_end}" + return [], fallback_text return [self._format_tool_call(parsed)], None def extract_from_segment(self, segment: str) -> Tuple[str, List[Dict[str, Any]]]: @@ -203,7 +273,6 @@ def extract_from_segment(self, segment: str) -> Tuple[str, List[Dict[str, Any]]] if start_pos > idx: output_chunks.append(segment[idx:start_pos]) self.in_tool_call = True - self.made_tool_call = True self.tool_text = "" idx = start_pos + len(self.tool_call_start) else: @@ -215,6 +284,7 @@ def extract_from_segment(self, segment: str) -> Tuple[str, List[Dict[str, Any]]] parsed_calls, fallback_text = self._parse_tool_text(self.tool_text) if parsed_calls: new_tool_calls.extend(parsed_calls) + self.made_tool_call = True if fallback_text: output_chunks.append(fallback_text) self.tool_text = "" diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py index 2190d9b0..0cf5a050 100644 --- a/src/parallax/utils/utils.py +++ b/src/parallax/utils/utils.py @@ -25,7 +25,11 @@ def 
is_mps_available(): def is_metal_available(): - """Check if MLX Metal backend is available""" + """Check if MLX Metal backend is available (macOS only)""" + import sys + + if sys.platform != "darwin": + return False try: import mlx.core as mx diff --git a/src/parallax/utils/vlm_utils.py b/src/parallax/utils/vlm_utils.py new file mode 100644 index 00000000..87ca2bf2 --- /dev/null +++ b/src/parallax/utils/vlm_utils.py @@ -0,0 +1,285 @@ +""" +VLM (Vision Language Model) utilities for MLX backend. + +Provides image loading and preprocessing functionality for multimodal models. +""" + +import base64 +from io import BytesIO +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + +# Lazy imports for optional dependencies +_PIL_AVAILABLE = None +_REQUESTS_AVAILABLE = None + + +def _check_pil(): + global _PIL_AVAILABLE + if _PIL_AVAILABLE is None: + try: + _PIL_AVAILABLE = True + except ImportError: + _PIL_AVAILABLE = False + return _PIL_AVAILABLE + + +def _check_requests(): + global _REQUESTS_AVAILABLE + if _REQUESTS_AVAILABLE is None: + try: + _REQUESTS_AVAILABLE = True + except ImportError: + _REQUESTS_AVAILABLE = False + return _REQUESTS_AVAILABLE + + +def load_image(source: Any) -> "Image.Image": + """ + Load an image from various sources. + + Supports: + - URL (http/https) + - Local file path + - Base64 encoded data URL + - PIL Image object (pass through) + - Dict with "url" key + + Args: + source: Image source (URL, path, base64, or PIL Image) + + Returns: + PIL Image in RGB format + + Raises: + ImportError: If PIL is not installed + ValueError: If source type is not supported + """ + if not _check_pil(): + raise ImportError( + "PIL (Pillow) is required for image loading. 
Install with: pip install Pillow" + ) + + from PIL import Image + + # Handle dict with url key + if isinstance(source, dict): + source = source.get("url") + + # Pass through PIL Image + if isinstance(source, Image.Image): + return source.convert("RGB") + + if not isinstance(source, str): + raise ValueError(f"Unsupported image source type: {type(source)}") + + # Load from URL + if source.startswith("http://") or source.startswith("https://"): + if not _check_requests(): + raise ImportError( + "requests is required for URL loading. Install with: pip install requests" + ) + import requests + + response = requests.get(source, timeout=30) + response.raise_for_status() + return Image.open(BytesIO(response.content)).convert("RGB") + + # Load from base64 data URL + if source.startswith("data:image"): + # Format: data:image/png;base64, + header, encoded = source.split(",", 1) + image_data = base64.b64decode(encoded) + return Image.open(BytesIO(image_data)).convert("RGB") + + # Load from local file path + return Image.open(source).convert("RGB") + + +def process_images_for_vlm( + images: List[Any], + processor: Any, + text: str = "", + model_type: Optional[str] = None, +) -> Dict[str, Any]: + """ + Process images for VLM input using a HuggingFace processor. 
+ + Args: + images: List of image sources (URLs, paths, base64, or PIL Images) + processor: HuggingFace processor (e.g., Qwen2VLProcessor, LlavaProcessor) + text: Text prompt to process alongside images + model_type: Optional model type for model-specific processing + + Returns: + Dictionary containing: + - pixel_values: numpy array of processed images + - image_grid_thw: (optional) grid sizes for dynamic resolution models + - image_sizes: (optional) original image sizes + - input_ids: (optional) tokenized input with image tokens + """ + if not images: + return {} + + # Load all images + pil_images = [] + image_sizes = [] + for img_src in images: + try: + img = load_image(img_src) + pil_images.append(img) + image_sizes.append((img.height, img.width)) + except Exception as e: + logger.warning(f"Failed to load image: {e}") + continue + + if not pil_images: + logger.warning("No images were successfully loaded") + return {} + + # Process with HuggingFace processor + try: + inputs = processor( + text=text, + images=pil_images, + return_tensors="np", # Use numpy for MLX compatibility + ) + except Exception as e: + logger.error(f"Image processing failed: {e}") + return {} + + result = {"image_sizes": image_sizes} + + # Extract pixel_values + if hasattr(inputs, "pixel_values"): + pixel_values = inputs.pixel_values + if hasattr(pixel_values, "numpy"): + pixel_values = pixel_values.numpy() + result["pixel_values"] = pixel_values + elif "pixel_values" in inputs: + pixel_values = inputs["pixel_values"] + if hasattr(pixel_values, "numpy"): + pixel_values = pixel_values.numpy() + result["pixel_values"] = pixel_values + + # Extract image_grid_thw (for Qwen-VL style models) + if hasattr(inputs, "image_grid_thw"): + grid_thw = inputs.image_grid_thw + if hasattr(grid_thw, "numpy"): + grid_thw = grid_thw.numpy() + result["image_grid_thw"] = grid_thw + elif "image_grid_thw" in inputs: + grid_thw = inputs["image_grid_thw"] + if hasattr(grid_thw, "numpy"): + grid_thw = grid_thw.numpy() 
def get_image_token_count(
    image_grid_thw: Optional[np.ndarray] = None,
    image_size: Optional[Tuple[int, int]] = None,
    patch_size: int = 14,
    merge_size: int = 2,
    temporal_patch_size: int = 2,
) -> int:
    """
    Calculate the number of image tokens for a given image.

    Different VLM models have different token counting strategies:
    - LLaVA: Fixed number based on image size / patch_size
    - Qwen-VL: Dynamic based on image_grid_thw (temporal * height * width)

    ``image_grid_thw`` takes precedence over ``image_size`` when both are given.

    Args:
        image_grid_thw: Grid sizes (temporal, height, width) for Qwen-VL style.
            Any array-like is accepted (ndarray, tuple, list, nested list); it
            is flattened and only its first three values are used.
        image_size: Image size (height, width) for LLaVA style
        patch_size: Size of each image patch
        merge_size: Merge factor for Qwen-VL (reduces tokens by merge_size^2)
        temporal_patch_size: Temporal patch size for video. Not used by the
            current computation; kept in the signature for callers.

    Returns:
        Number of image tokens (always a plain Python int)

    Raises:
        ValueError: If ``image_grid_thw`` holds fewer than three values.
    """
    if image_grid_thw is not None:
        # Normalise every array-like the same way. The plain tuple-unpack
        # previously used for non-ndarray inputs failed on nested grids such
        # as [[t, h, w]], which the ndarray path accepted.
        grid = np.asarray(image_grid_thw).reshape(-1)
        if grid.size < 3:
            raise ValueError(
                f"image_grid_thw must contain at least 3 values (t, h, w), got {grid.size}"
            )
        t, h, w = grid[:3]
        # After merge: (t * h * w) / (merge_size^2)
        return int((t * h * w) // (merge_size**2))

    if image_size is not None:
        # LLaVA style: (H / patch_size) * (W / patch_size)
        h, w = image_size
        return int((h // patch_size) * (w // patch_size))

    # Default fallback
    return 576  # Common default for 336x336 image with 14x14 patches
"vision_model", + "visual", + "multi_modal_projector", + "mm_projector", + ] + for prefix in vlm_prefixes: + if key.startswith(prefix) or key.startswith(f"model.{prefix}"): + return True + if is_last_shard: - if "model.norm" in key or "lm_head" in key: + if ("lm_head" in key) or ( + (".norm." in key or key.endswith(".norm.weight")) and "layers" not in key + ): return True - if tie_word_embeddings and "embed" in key and key.startswith("model.embed_tokens"): + if tie_word_embeddings and "embed_tokens" in key: return True if "layers." in key: parts = key.split(".") for i, part in enumerate(parts): if part == "layers" and i + 1 < len(parts): - layer_idx = int(parts[i + 1]) - return start_layer <= layer_idx < end_layer + try: + layer_idx = int(parts[i + 1]) + return start_layer <= layer_idx < end_layer + except ValueError: + continue return False @@ -41,6 +63,7 @@ def filter_weight_files_by_layer_range_for_load( is_first_shard: bool, is_last_shard: bool, config: Optional[Dict] = None, + is_vlm: bool = False, ) -> List[str]: index_file = model_path / "model.safetensors.index.json" @@ -58,13 +81,13 @@ def filter_weight_files_by_layer_range_for_load( tie_word_embeddings = False if config: - tie_word_embeddings = config.get("tie_word_embeddings", False) + tie_word_embeddings = get_config_value(config, "tie_word_embeddings", False) else: config_file = model_path / "config.json" if config_file.exists(): with open(config_file, "r") as f: cfg = json.load(f) - tie_word_embeddings = cfg.get("tie_word_embeddings", False) + tie_word_embeddings = get_config_value(cfg, "tie_word_embeddings", False) needed_files: Set[str] = set() @@ -78,6 +101,7 @@ def filter_weight_files_by_layer_range_for_load( is_first_shard=is_first_shard, is_last_shard=is_last_shard, tie_word_embeddings=tie_word_embeddings, + is_vlm=is_vlm, ): needed_files.add(filename) @@ -95,9 +119,18 @@ def filter_weight_files_by_layer_range_for_load( logger.debug( f"Filtered weight files from {len(weight_files)} to 
{len(filtered_files)} " - f"for layers [{start_layer}, {end_layer})" + f"for layers [{start_layer}, {end_layer}), is_vlm={is_vlm}, is_first_shard={is_first_shard}" ) + # If filtering resulted in no files but we had input files, + # fall back to original files (file naming may differ from index) + if not filtered_files and weight_files: + logger.debug( + f"Filtering resulted in no files, falling back to all {len(weight_files)} weight files. " + f"needed_files={needed_files}, input_files={[Path(wf).name for wf in weight_files]}" + ) + return weight_files + return filtered_files @@ -110,16 +143,19 @@ def determine_needed_weight_files_for_download( is_first_shard = start_layer == 0 is_last_shard = False + is_vlm_flag = False if config: - num_hidden_layers = config.get("num_hidden_layers", 0) + num_hidden_layers = get_config_value(config, "num_hidden_layers", 0) is_last_shard = end_layer >= num_hidden_layers + is_vlm_flag = is_vlm_model(config) else: config_file = model_path / "config.json" if config_file.exists(): with open(config_file, "r") as f: cfg = json.load(f) - num_hidden_layers = cfg.get("num_hidden_layers", 0) + num_hidden_layers = get_config_value(cfg, "num_hidden_layers", 0) is_last_shard = end_layer >= num_hidden_layers + is_vlm_flag = is_vlm_model(cfg) index_file = model_path / "model.safetensors.index.json" @@ -147,9 +183,9 @@ def determine_needed_weight_files_for_download( logger.debug("weight_map is empty in index file") return [] - tie_word_embeddings = False - if config: - tie_word_embeddings = config.get("tie_word_embeddings", False) + tie_word_embeddings = ( + get_config_value(config, "tie_word_embeddings", False) if config else False + ) needed_files: Set[str] = set() @@ -163,6 +199,7 @@ def determine_needed_weight_files_for_download( is_first_shard=is_first_shard, is_last_shard=is_last_shard, tie_word_embeddings=tie_word_embeddings, + is_vlm=is_vlm_flag, ): needed_files.add(filename) diff --git a/src/parallax/vllm/batch_info.py 
b/src/parallax/vllm/batch_info.py index d7ea47b3..7b23f87c 100644 --- a/src/parallax/vllm/batch_info.py +++ b/src/parallax/vllm/batch_info.py @@ -141,6 +141,7 @@ def _build_vllm_request( model_runner: Any, *, include_outputs: bool, + multimodal_params: Optional[Dict] = None, ) -> VLLMRequest: block_hasher = getattr(model_runner, "request_block_hasher", None) @@ -171,6 +172,7 @@ def _build_vllm_request( arrival_time=getattr(req, "arrival_time", 0.0), block_hasher=block_hasher, lora_request=lora_req, + multi_modal_data=multimodal_params, ) if include_outputs: output_ids = getattr(req, "output_ids", None) or [] @@ -204,8 +206,15 @@ def form_vllm_batch_prefill( for req in batched_requests: sampling_params = transform_sampling_params_to_vllm(req.sampling_params) - - vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=False) + multimodal_params = getattr(req, "multimodal_params", None) + + vllm_req = _build_vllm_request( + req, + sampling_params, + model_runner, + include_outputs=False, + multimodal_params=multimodal_params, + ) created_vllm_requests.append(vllm_req) computed_blocks, num_computed_tokens = kv_cache_manager.get_computed_blocks(vllm_req) @@ -328,7 +337,14 @@ def form_vllm_batch_decode( new_token_ids.append([]) sampling_params = transform_sampling_params_to_vllm(req.sampling_params) - vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=True) + multimodal_params = getattr(req, "multimodal_params", None) + vllm_req = _build_vllm_request( + req, + sampling_params, + model_runner, + include_outputs=True, + multimodal_params=multimodal_params, + ) prompt_ids = getattr(req, "input_ids", None) or [] # For decode stage, computed_token_count should be the total number of tokens diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py index def32deb..e210cfcc 100644 --- a/src/parallax/vllm/model_runner.py +++ b/src/parallax/vllm/model_runner.py @@ -408,10 +408,13 @@ def 
def set_rank(rank: int, enable_filter: bool = True):
    """
    Record the rank of the current process and optionally silence its logs.

    With filtering enabled only rank 0 keeps printing logs, which avoids
    duplicated messages when several processes (multi-GPU TP/DP) run the same
    code. A non-zero rank is silenced by raising the root logger's level above
    CRITICAL.

    Args:
        rank: The rank of the current process (0 for master).
        enable_filter: If True, only rank 0 will print logs.
    """
    global _current_rank, _rank_filter_enabled
    _current_rank, _rank_filter_enabled = rank, enable_filter

    # Nothing to silence when filtering is off or this is the master rank.
    if not enable_filter or rank == 0:
        return

    # Disable logging for non-zero ranks by setting to a high level
    logging.getLogger().setLevel(logging.CRITICAL + 1)
def multi_tool_parser(tool_text: str, tools):
    """Parser that returns a list of tool calls (e.g. kimi-k2 style).

    ``tool_text`` is expected to hold one or more JSON objects written back to
    back (optionally separated by whitespace). The objects are decoded one
    after another with ``JSONDecoder.raw_decode`` instead of splitting the
    text on ``"}{"``: a textual split leaves the first chunk with unbalanced
    braces whenever ``arguments`` is itself an object, and also breaks on
    ``"}{"`` appearing inside a string value.

    Args:
        tool_text: Raw text found between the tool-call markers.
        tools: Available tool definitions (unused by this parser).

    Returns:
        One dict per decoded object with keys ``id`` (``"call_0"`` when the
        object has none), ``name`` and ``arguments``.

    Raises:
        json.JSONDecodeError: If the text is empty or not valid JSON.
        KeyError: If an object lacks ``name`` or ``arguments``.
    """
    decoder = json.JSONDecoder()
    text = tool_text.strip()
    end = len(text)
    calls = []
    pos = 0
    while True:
        parsed, pos = decoder.raw_decode(text, pos)
        calls.append(
            {
                "id": parsed.get("id", "call_0"),
                "name": parsed["name"],
                "arguments": parsed["arguments"],
            }
        )
        # raw_decode does not skip leading whitespace, so advance manually.
        while pos < end and text[pos].isspace():
            pos += 1
        if pos >= end:
            return calls
def test_from_tokenizer_without_tool_support(self):
    """A tokenizer exposing no tool-calling attributes yields an inactive state."""
    bare_tokenizer = MagicMock(spec=[])  # empty spec: the mock exposes no attributes

    state = ToolCallState.from_tokenizer(bare_tokenizer, SAMPLE_TOOLS, stream=False)

    self.assertFalse(state.has_tool_calling)
    # None of the marker/parser fields may be populated on an inactive state.
    for field_name in ("tool_call_start", "tool_call_end", "tool_parser"):
        self.assertIsNone(getattr(state, field_name))
def test_text_after_tool_call(self):
    """Text after tool call end marker should be preserved."""
    # NOTE(review): the marker literals were empty strings in this copy, which
    # looks like stripped markup; "<tool_call>" / "</tool_call>" are restored
    # as the presumable originals -- confirm against the upstream test file.
    # The test only needs the state and the segment to agree on the markers.
    state = make_tool_state(
        tool_call_start="<tool_call>",
        tool_call_end="</tool_call>",
        tool_parser=json_tool_parser,
        tools=SAMPLE_TOOLS,
    )

    segment = (
        "<tool_call>"
        '{"name": "get_weather", "arguments": {"location": "Tokyo"}}'
        "</tool_call>"
        " Done!"
    )
    text, tool_calls = state.extract_from_segment(segment)

    # Only the trailing text survives; the wrapped call is extracted.
    self.assertEqual(text, " Done!")
    self.assertEqual(len(tool_calls), 1)
+ text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, segment) + self.assertEqual(tool_calls, []) + + def test_inactive_state_passes_through(self): + """Inactive ToolCallState (no tool support) should pass through text unchanged.""" + state = ToolCallState( + has_tool_calling=False, + tool_call_start=None, + tool_call_end=None, + tool_parser=None, + tools=None, + stream=False, + ) + + segment = ( + '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + ) + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, segment) + self.assertEqual(tool_calls, []) + + def test_empty_segment(self): + """Empty segment should return empty text and no tool calls.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + text, tool_calls = state.extract_from_segment("") + self.assertEqual(text, "") + self.assertEqual(tool_calls, []) + + +class TestToolCallStreaming(unittest.TestCase): + """Test streaming scenarios where tool calls arrive across multiple segments.""" + + def test_tool_call_split_across_segments(self): + """Tool call split across two segments should accumulate and parse correctly.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + # First segment: start marker + partial content + text1, calls1 = state.extract_from_segment('{"name": "get_weather",') + self.assertEqual(text1, "") + self.assertEqual(calls1, []) + self.assertTrue(state.in_tool_call) + + # Second segment: rest of content + end marker + text2, calls2 = state.extract_from_segment( + ' "arguments": {"location": "Beijing"}}' + ) + self.assertEqual(text2, "") + self.assertEqual(len(calls2), 1) + self.assertEqual(calls2[0]["function"]["name"], "get_weather") + self.assertFalse(state.in_tool_call) + + def test_tool_call_split_three_segments(self): + """Tool call split across three 
segments.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + text1, calls1 = state.extract_from_segment("") + self.assertEqual(calls1, []) + self.assertTrue(state.in_tool_call) + + text2, calls2 = state.extract_from_segment( + '{"name": "multiply", "arguments": {"a": 3, "b": 7}}' + ) + self.assertEqual(calls2, []) + + text3, calls3 = state.extract_from_segment("") + self.assertEqual(len(calls3), 1) + self.assertEqual(calls3[0]["function"]["name"], "multiply") + + def test_stream_mode_has_index(self): + """In stream mode, tool calls should have an 'index' field.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + stream=True, + ) + + segment = '{"name": "get_weather", "arguments": {"location": "NYC"}}' + _, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(len(tool_calls), 1) + self.assertIn("index", tool_calls[0]) + self.assertEqual(tool_calls[0]["index"], 0) + + def test_stream_mode_index_increments(self): + """In stream mode, tool call index should increment for each call.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + stream=True, + ) + + # First tool call + segment1 = ( + '{"name": "get_weather", "arguments": {"location": "NYC"}}' + ) + _, calls1 = state.extract_from_segment(segment1) + self.assertEqual(calls1[0]["index"], 0) + + # Second tool call + segment2 = '{"name": "multiply", "arguments": {"a": 2, "b": 3}}' + _, calls2 = state.extract_from_segment(segment2) + self.assertEqual(calls2[0]["index"], 1) + + +class TestToolCallMultiple(unittest.TestCase): + """Test multiple tool calls in a single segment.""" + + def test_two_tool_calls_in_one_segment(self): + """Two consecutive tool calls in one segment should both be extracted.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + 
tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = ( + '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + '{"name": "multiply", "arguments": {"a": 3, "b": 5}}' + ) + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "") + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0]["function"]["name"], "get_weather") + self.assertEqual(tool_calls[1]["function"]["name"], "multiply") + + def test_multi_return_parser(self): + """Parser that returns a list of tool calls should all be formatted.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=multi_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = '{"id": "call_1", "name": "get_weather", "arguments": {"location": "London"}}' + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "") + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["function"]["name"], "get_weather") + self.assertEqual(tool_calls[0]["id"], "call_1") + + +class TestToolCallEdgeCases(unittest.TestCase): + """Test edge cases and error handling.""" + + def test_malformed_json_falls_back(self): + """Malformed JSON inside tool call markers should fall back to raw text.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = "not valid json at all" + text, tool_calls = state.extract_from_segment(segment) + + self.assertEqual(text, "not valid json at all") + self.assertEqual(tool_calls, []) + + def test_made_tool_call_flag(self): + """made_tool_call flag should be set after a tool call is detected.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + self.assertFalse(state.made_tool_call) + + segment = ( + '{"name": "get_weather", "arguments": {"location": "Beijing"}}' + ) + state.extract_from_segment(segment) + + 
self.assertTrue(state.made_tool_call) + + def test_made_tool_call_flag_not_set_on_parse_failure(self): + """made_tool_call should be False if parsing fails (no valid tool call was produced).""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = "invalid" + state.extract_from_segment(segment) + + self.assertFalse(state.made_tool_call) + + def test_string_arguments_serialized_to_json(self): + """Tool call arguments dict should be serialized to JSON string in output.""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = ( + '{"name": "get_weather", "arguments": {"location": "北京"}}' + ) + _, tool_calls = state.extract_from_segment(segment) + + # arguments should be a JSON string (not a dict) in the output + args_str = tool_calls[0]["function"]["arguments"] + self.assertIsInstance(args_str, str) + self.assertEqual(json.loads(args_str), {"location": "北京"}) + + def test_unicode_in_arguments(self): + """Unicode characters in arguments should be preserved (ensure_ascii=False).""" + state = make_tool_state( + tool_call_start="", + tool_call_end="", + tool_parser=json_tool_parser, + tools=SAMPLE_TOOLS, + ) + + segment = ( + '{"name": "get_weather", "arguments": {"location": "東京"}}' + ) + _, tool_calls = state.extract_from_segment(segment) + + args_str = tool_calls[0]["function"]["arguments"] + # ensure_ascii=False means the Chinese/Japanese chars should appear directly + self.assertIn("東京", args_str) + + +if __name__ == "__main__": + unittest.main()