Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
4f96e46
init
yuhao-zh Feb 2, 2026
7d6f718
update
yuhao-zh Feb 2, 2026
4358906
update
yuhao-zh Feb 2, 2026
43cf2a6
Merge branch 'main' into feat/vlm-support
yuhao-zh Feb 2, 2026
104a8a6
tmp add
yuhao-zh Feb 2, 2026
8723841
11
yuhao-zh Feb 2, 2026
605577a
done sglang
yuhao-zh Feb 3, 2026
adec101
refactor sglang vlm
yuhao-zh Feb 3, 2026
9a3707a
tmp add mlx
yuhao-zh Feb 3, 2026
0a67dad
load success
yuhao-zh Feb 3, 2026
8f7ce00
run success on mlx
yuhao-zh Feb 4, 2026
96fbf6a
pre-commit
yuhao-zh Feb 4, 2026
8b2ed0b
add config get utils
yuhao-zh Feb 4, 2026
4a68fd3
refactor baseexecutor
yuhao-zh Feb 4, 2026
c5f4361
update
yuhao-zh Feb 5, 2026
ba5d456
update sglang version
yuhao-zh Feb 5, 2026
9afa829
update config layers read
yuhao-zh Feb 5, 2026
ab53736
specfic for kimi
yuhao-zh Feb 5, 2026
28001b8
only print rank0 log and fix kimi token expand
yuhao-zh Feb 5, 2026
be6d1f1
fix weight fiter
yuhao-zh Feb 5, 2026
ec58578
Merge branch 'main' into feat/vlm-support
yuhao-zh Feb 6, 2026
4e2b1d9
update mlx-lm
yuhao-zh Feb 6, 2026
fbbde63
add tool test
yuhao-zh Feb 6, 2026
a410ed4
rebase pip install
yuhao-zh Feb 6, 2026
f353ff7
update pyprojection
yuhao-zh Feb 6, 2026
2dea44b
fix pre commit
yuhao-zh Feb 6, 2026
ed80339
update
yuhao-zh Feb 7, 2026
44c2234
check kimi2.5 sampling params
yuhao-zh Feb 9, 2026
967b4a0
moiify create time
yuhao-zh Feb 9, 2026
7460d89
modify maxtokens limit
yuhao-zh Feb 9, 2026
2b1eece
Merge branch 'main' into feat/vlm-support
yuhao-zh Feb 9, 2026
726dfb8
fix pre-commit
yuhao-zh Feb 9, 2026
fc2c349
update
yuhao-zh Feb 9, 2026
6ae3863
modify tokenizer
yuhao-zh Feb 9, 2026
924f256
update
yuhao-zh Feb 9, 2026
ed665f1
pre-commit
yuhao-zh Feb 10, 2026
be1f910
update
yuhao-zh Feb 11, 2026
e5e1b12
update
yuhao-zh Feb 11, 2026
69f3403
add metal detect
yuhao-zh Feb 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/user_guide/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Note: If you are using DGX Spark, please refer to the Docker installation sectio
```sh
git clone https://github.com/GradientHQ/parallax.git
cd parallax
pip install -e '.[gpu]'
pip install -e ".[gpu]" && pip install mlx-lm==0.30.6 --no-deps
```

#### For macOS (Apple silicon):
Expand All @@ -34,14 +34,14 @@ cd parallax
python3 -m venv ./venv
source ./venv/bin/activate

pip install -e '.[mac]'
pip install -e ".[mac]"
```

Next time to re-activate this virtual environment, run ```source ./venv/bin/activate```.

#### Extra step for development:
```sh
pip install -e '.[dev]'
pip install -e ".[dev]"
```

### Windows Application
Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"dijkstar==2.6.0",
"lattica==1.0.21",
"orjson",

]

[project.scripts]
Expand All @@ -47,12 +48,16 @@ mac = [
"torch==2.8.0",
"mlx-lm==0.30.6",
"mlx==0.30.4",
"mlx-vlm==0.3.10",
"torchvision==0.23.0"
]

gpu = [
"sglang[all]==0.5.7",
"sglang[all] @ git+https://github.com/sgl-project/sglang.git@9409c43593f2d6d64595981abf216a15752b0875#subdirectory=python",
"mlx-lm==0.28.4",
"mlx[cpu]==0.30.0",
"mlx[cpu]==0.30.4",
# due to transformers version conflict, we need to install mlx-lm separately
# pip install mlx-lm==0.30.6 --no-deps
]

vllm = [
Expand Down
9 changes: 6 additions & 3 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from parallax.server.executor.factory import run_executor_process, stop_executor_process
from parallax.server.http_server import launch_http_server, stop_http_server
from parallax.server.server_args import parse_args
from parallax.utils.config_utils import get_config_value
from parallax.utils.shared_state import SharedState
from parallax.utils.utils import fetch_model_from_hf, initialize_nccl_port
from parallax_utils.ascii_anime import display_parallax_join
Expand Down Expand Up @@ -120,23 +121,25 @@ def _wait_executors_check_layer_change(shared_state: SharedState, executor_subpr
check_latest_release()

config = fetch_model_from_hf(args.model_path, local_files_only=args.use_hfcache)
num_layers = get_config_value(config, "num_hidden_layers")

if args.start_layer is None:
args.start_layer = 0
if args.end_layer is None:
args.end_layer = config.get("num_hidden_layers")
args.end_layer = num_layers

# only launch http server on head node
if args.start_layer == 0:
http_server_process = launch_http_server(args)
# Launch P2P server as subprocess
if not (args.start_layer == 0 and args.end_layer == config.get("num_hidden_layers")):
if not (args.start_layer == 0 and args.end_layer == num_layers):
p2p_server_process = launch_p2p_server_process(
initial_peers=args.initial_peers,
scheduler_addr=args.scheduler_addr,
relay_servers=args.relay_servers,
pp_start_layer=args.start_layer,
pp_end_layer=args.end_layer,
hidden_layers=config.get("num_hidden_layers"),
hidden_layers=num_layers,
tp_size=args.tp_size,
dp_size=args.dp_size,
tcp_port=args.tcp_port,
Expand Down
203 changes: 203 additions & 0 deletions src/parallax/models/kimi_vl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""
Defines the KimiVL model for Parallax.

KimiVL uses a DeepSeek-V3 based language model with MoE and a MoonViT vision encoder.
This module reuses components from mlx-vlm and adds PagedAttention support
for distributed inference.
"""

from typing import Any, List, Optional

import mlx.core as mx
from mlx_lm.models.base import scaled_dot_product_attention

# Import from mlx-vlm kimi_vl language module
from mlx_vlm.models.kimi_vl.language import DeepseekV3Attention as MLXKimiVLAttention
from mlx_vlm.models.kimi_vl.language import (
DeepseekV3DecoderLayer as MLXKimiVLDecoderLayer,
)

from parallax.metal.paged_attention.kernel import paged_attention, reshape_and_cache
from parallax.server.cache.base import BaseCache
from parallax.utils.prefix_cache_utils import compute_attention_with_prefix_cache
from parallax_utils.logging_config import get_logger

logger = get_logger(__name__)


class ParallaxKimiVLAttention(MLXKimiVLAttention):
"""KimiVL (DeepSeek-V3) Attention with PagedAttention support for Parallax.

This extends the MLX-VLM KimiVL attention (DeepseekV3Attention) with:
- Paged KV cache support for efficient memory management
- Block-table based attention for decode phase
- Prefix cache support for prefill phase
"""

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[BaseCache] = None,
offset: int = 0,
lengths: Optional[mx.array] = None,
block_tables: Optional[mx.array] = None,
context_lengths: Optional[mx.array] = None,
slot_mapping: Optional[mx.array] = None,
prefix_lens: Optional[mx.array] = None,
**kwargs,
) -> mx.array:
batch, target_len, _ = x.shape

# Q projection (with optional LoRA)
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))

q = q.reshape(batch, target_len, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)

# KV projection (with MQA compression)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(batch, target_len, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))

kv = kv.reshape(batch, target_len, self.num_heads, -1)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
k_nope = k_nope.transpose(0, 2, 1, 3)

# Get KV cache
key_cache_global, value_cache_global = cache.get_cache()

# Compute RoPE offsets
if target_len == 1:
# Decode phase: position is context_length - 1
current_pos = context_lengths - 1
elif prefix_lens is not None:
# Prefill phase with prefix cache
current_pos = prefix_lens
else:
# Prefill phase without prefix cache
current_pos = 0

# Apply RoPE
q_pe = self.rope(q_pe, offset=current_pos)
k_pe = self.rope(k_pe, offset=current_pos)

k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)

# Cache update with PagedAttention
block_size = key_cache_global.shape[3]

reshape_and_cache(
keys.transpose(0, 2, 1, 3),
values,
key_cache_global,
value_cache_global,
block_tables,
context_lengths,
block_size,
slot_mapping=slot_mapping,
)

if target_len == 1:
# Decode phase: Use Paged Attention
output = paged_attention(
queries,
key_cache_global,
value_cache_global,
block_tables,
context_lengths,
block_size,
self.scale,
self.num_heads,
v_head_dim=values.shape[-1],
)
output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1)
else:
# Prefill phase
has_prefix_cache = prefix_lens is not None and bool(mx.any(prefix_lens > 0))

if has_prefix_cache:
k_new = keys
v_new = values.transpose(0, 2, 1, 3)
output = compute_attention_with_prefix_cache(
queries,
k_new,
v_new,
cache,
block_tables,
prefix_lens,
target_len,
self.scale,
self.num_heads,
mask=mask,
)
else:
# Standard self-attention
if mask is not None:
mask = mx.array(mask, dtype=queries.dtype)

output = scaled_dot_product_attention(
queries,
keys,
values.transpose(0, 2, 1, 3),
scale=self.scale,
mask=mask,
cache=None,
)
output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1)

return self.o_proj(output)


class ParallaxKimiVLBlock(MLXKimiVLDecoderLayer):
"""KimiVL Transformer block with PagedAttention support.

Extends the MLX-VLM KimiVL decoder layer to use ParallaxKimiVLAttention
and pass through paged attention arguments.
"""

def __init__(self, args, layer_idx: int, local_layer_idx: int):
super().__init__(args, layer_idx=layer_idx)
# Replace attention with Parallax version
self.self_attn = ParallaxKimiVLAttention(args)
self.layer_idx = layer_idx
self.local_layer_idx = local_layer_idx

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[List[Any]] = None,
lengths: Optional[mx.array] = None,
block_tables: Optional[mx.array] = None,
context_lengths: Optional[mx.array] = None,
slot_mapping: Optional[mx.array] = None,
**kwargs,
):
r = self.self_attn(
self.input_layernorm(x),
mask,
cache[self.local_layer_idx],
block_tables=block_tables,
context_lengths=context_lengths,
slot_mapping=slot_mapping,
**kwargs,
)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out

@classmethod
def get_architecture(cls):
"""Get the architecture name for the block."""
return "KimiVLForConditionalGeneration"


EntryClass = ParallaxKimiVLBlock
Loading