1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -29,6 +29,7 @@
from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel, Qwen3VLMOETpPartModel
from lightllm.models.gemma3.model import Gemma3TpPartModel
from lightllm.models.tarsier2.model import (
Tarsier2Qwen2TpPartModel,
1 change: 0 additions & 1 deletion lightllm/models/llama/model.py
@@ -206,7 +206,6 @@ def _init_to_get_rotary(self, default_base=10000):
/ rope_scaling_factor
)
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
6 changes: 6 additions & 0 deletions lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -383,6 +383,12 @@ def encode(self, images: List[ImageItem]):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
image_data = resize_image(
image_file=image_data,
factor=self.processor.patch_size * self.processor.merge_size,
min_pixels=self.processor.min_pixels,
max_pixels=self.processor.max_pixels,
)
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
img_tensors.append(pixel_values)
img_grids.append(image_grid_thw)
6 changes: 6 additions & 0 deletions lightllm/models/qwen2_vl/qwen2_visual.py
@@ -311,6 +311,12 @@ def encode(self, images: List[ImageItem]):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
image_data = resize_image(
image_file=image_data,
factor=self.processor.patch_size * self.processor.merge_size,
min_pixels=self.processor.min_pixels,
max_pixels=self.processor.max_pixels,
)
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
img_tensors.append(pixel_values)
img_grids.append(image_grid_thw)
7 changes: 7 additions & 0 deletions lightllm/models/qwen2_vl/vision_process.py
@@ -114,6 +114,13 @@ def __init__(
self.disable_grouping = disable_grouping
self.interpolation = interpolation
self.data_format = ChannelDimension.FIRST
if isinstance(self.size, dict):
shortest = self.size.get("shortest_edge", None)
longest = self.size.get("longest_edge", None)
if shortest is not None:
self.min_pixels = shortest
if longest is not None:
self.max_pixels = longest
self._fused_cache = {} # key: (do_norm, do_rescale, rescale_factor, device)

def _get_fused_mean_std(
56 changes: 56 additions & 0 deletions lightllm/models/qwen3_vl/infer_struct.py
@@ -0,0 +1,56 @@
import torch
import numpy as np
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.basemodel.infer_struct import InferStateInfo


class Qwen3VLInferStateInfo(LlamaInferStateInfo):
def __init__(self):
super().__init__()
self.input_ids = None
self.deepstack_features = []
self.deepstack_end_layer = None
self.img_start_token_ids = []
self.img_token_lens = []
self.img_start_locs = []

def apply_interleaved_mrope(self, freqs, mrope_section):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
interleaved [THTHWHTHW...TT], preserving frequency continuity.
args:
x: (3, bs, seq_len, head_dim // 2)
mrope_section: (3,)
returns:
x_t: (bs, seq_len, head_dim // 2)
Comment on lines +21 to +25 (Contributor, severity: medium)

The docstring for apply_interleaved_mrope is inaccurate. The argument is freqs with shape (3, seq_len, head_dim // 2), but the docstring refers to x with a bs (batch size) dimension, which is not present. The return value shape is also (seq_len, head_dim // 2). Please update the docstring for clarity.

Suggested change:
        args:
            freqs: (3, seq_len, head_dim // 2)
            mrope_section: (3,)
        returns:
            freqs_t: (seq_len, head_dim // 2)

"""
freqs_t = freqs[0] # just overwrite the first dimension T
for dim, offset in enumerate((1, 2), start=1): # H, W
length = mrope_section[dim] * 3
idx = slice(offset, length, 3)
freqs_t[..., idx] = freqs[dim, ..., idx]
return freqs_t
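
For reference, a minimal standalone sketch of the interleaving that apply_interleaved_mrope performs, using toy sizes (head_dim // 2 = 8 and mrope_section = [4, 2, 2]; the values 0/1/2 mark the T/H/W channels). The sizes and names below are illustrative only, not taken from the PR:

import torch

def interleave_mrope(freqs, mrope_section):
    # freqs: (3, seq_len, head_dim // 2), stacked as [T, H, W]
    freqs_t = freqs[0].clone()                      # start from the temporal channel
    for dim, offset in enumerate((1, 2), start=1):  # H then W
        idx = slice(offset, mrope_section[dim] * 3, 3)
        freqs_t[..., idx] = freqs[dim, ..., idx]    # fill every third slot
    return freqs_t                                  # (seq_len, head_dim // 2)

freqs = torch.stack([torch.full((1, 8), float(v)) for v in (0, 1, 2)])
print(interleave_mrope(freqs, [4, 2, 2]))
# tensor([[0., 1., 2., 0., 1., 2., 0., 0.]])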

def init_some_extra_state(self, model, input_ids: torch.Tensor):
rope_scaling = model.config.get("rope_scaling", {})
self.mrope_section = rope_scaling.get("mrope_section", None)
InferStateInfo.init_some_extra_state(self, model, input_ids)
pos = self.position_ids[None, :].expand(3, -1)
cos_T = torch.index_select(model._cos_cached, 0, pos[0]) # [L, d/2]
cos_H = torch.index_select(model._cos_cached, 0, pos[1])
cos_W = torch.index_select(model._cos_cached, 0, pos[2])
sin_T = torch.index_select(model._sin_cached, 0, pos[0])
sin_H = torch.index_select(model._sin_cached, 0, pos[1])
sin_W = torch.index_select(model._sin_cached, 0, pos[2])
cos_half = self.apply_interleaved_mrope(
torch.stack([cos_T, cos_H, cos_W], dim=0), self.mrope_section
) # [L, d/2]
sin_half = self.apply_interleaved_mrope(
torch.stack([sin_T, sin_H, sin_W], dim=0), self.mrope_section
) # [L, d/2]

self.position_cos = torch.cat([cos_half, cos_half], dim=-1).contiguous() # [L, d]
self.position_sin = torch.cat([sin_half, sin_half], dim=-1).contiguous()
if self.is_prefill:
pos = None
Comment (Contributor, severity: medium)

The line pos = None appears to be leftover debugging code, as the pos variable is not used after this assignment. It should be removed to improve code clarity.

return
93 changes: 93 additions & 0 deletions lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py
@@ -0,0 +1,93 @@
import torch
import torch.distributed as dist

from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo

from lightllm.server.embed_cache.utils import (
bytes2tensor,
read_shm,
get_shm_name_embed,
)
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce

from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer


class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer):
def __init__(self, network_config, mode):
super().__init__(network_config, mode)
return

def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
img_weight = []
img_start_loc = 0

infer_state.input_ids = input_ids
infer_state.img_start_token_ids = []
infer_state.img_token_lens = []
infer_state.img_start_locs = []

device = layer_weight.wte_weight_.device
dtype = layer_weight.wte_weight_.dtype
hidden_size = layer_weight.wte_weight_.shape[1]

infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)

for batch_id, p in enumerate(infer_state.multimodal_params):
for img in p["images"] + p["audios"]:
# skip the same image
if img["token_id"] in infer_state.img_start_token_ids or img["_prefill_"] is False:
continue

                # all_img_embed_df layout (a slicing sketch follows this file's diff):
                # image_embed (token_num, hidden_dim) followed by deepstack (token_num * layer_num, hidden_dim)
all_img_embed_df = bytes2tensor(read_shm(get_shm_name_embed(img["uuid"])))
per_image_deepstack = []

                # compute the number of deepstack layers
deepstack_layer_num = all_img_embed_df.shape[0] // img["token_num"] - 1
img_weight.append(all_img_embed_df[: img["token_num"]].cuda())

for layer in range(deepstack_layer_num):
start = img["token_num"] * (layer + 1)
end = img["token_num"] * (layer + 2)
per_image_deepstack.append(all_img_embed_df[start:end])

infer_state.deepstack_features.append(per_image_deepstack)
infer_state.img_start_token_ids.append(img["token_id"])
infer_state.img_token_lens.append(img["token_num"])
infer_state.img_start_locs.append(img_start_loc)
img_start_loc += img["token_num"]
out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)

if len(img_weight) > 0:
img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
else:
img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
assert img_weight.shape[1] == hidden_size, (
f"Dimension mismatch: text weight dimension is {hidden_size}, "
f"but image weight dimension is {img_weight.shape[1]}"
)
# each tp will fill the img embeds, should divide by world_size
img_weight = img_weight / self.tp_world_size_
img_start_token_ids = torch.Tensor(infer_state.img_start_token_ids).to(device=device, dtype=torch.long)
img_token_lens = torch.Tensor(infer_state.img_token_lens).to(device=device, dtype=torch.long)
img_start_locs = torch.Tensor(infer_state.img_start_locs).to(device=device, dtype=torch.long)

multimodal_emb(
out,
input_ids,
layer_weight.wte_weight_,
img_weight,
img_token_lens,
img_start_token_ids,
img_start_locs,
self.vob_start_id_,
self.vob_end_id_,
)
if self.tp_world_size_ > 1:
all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False)
return out
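
For context on the shared-memory buffer consumed in context_forward above: it holds the image-embedding rows first, followed by one block of token_num rows per deepstack layer. A minimal slicing sketch with illustrative sizes and names (not part of the PR):

import torch

token_num, layer_num, hidden_dim = 4, 3, 8  # illustrative sizes
# rows [0, token_num) hold the image embedding; each following block of
# token_num rows holds one deepstack layer's features
all_img_embed_df = torch.randn(token_num * (layer_num + 1), hidden_dim)

image_embed = all_img_embed_df[:token_num]                      # (token_num, hidden_dim)
per_image_deepstack = [
    all_img_embed_df[token_num * (i + 1): token_num * (i + 2)]  # (token_num, hidden_dim)
    for i in range(layer_num)
]
assert all(d.shape == (token_num, hidden_dim) for d in per_image_deepstack)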
55 changes: 55 additions & 0 deletions lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,55 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from functools import partial
from typing import Tuple
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from lightllm.distributed import all_reduce
from lightllm.utils.dist_utils import get_global_world_size
from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features


class Qwen3VLTransformerLayerInfer(Qwen3TransformerLayerInfer):
def __init__(self, layer_num, network_config, mode=[]):
super().__init__(layer_num, network_config, mode)
self.mrope_section = network_config["rope_scaling"]["mrope_section"]
axis_map = []
for i, n in enumerate(self.mrope_section * 2):
axis_map += [i % 3] * n
self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")

def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
self._post_cache_kv(cache_kv, infer_state, layer_weight)
o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
q = None
o = self._get_o(o, infer_state, layer_weight)
if self.tp_world_size_ > 1:
all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
ffn_out = self._ffn(input1, infer_state, layer_weight)
input1 = None
if self.tp_world_size_ > 1:
all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
apply_deepstack_features(
input_embeddings=input_embdings,
infer_state=infer_state,
layer_num=self.layer_num_,
)
return input_embdings
@@ -0,0 +1,36 @@
import numpy as np
from lightllm.common.basemodel import PreAndPostLayerWeight


class Qwen3VLPreAndPostLayerWeight(PreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
return

def load_hf_weights(self, weights):
vob_size = self.network_config_["vocab_size"]
split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64)
split_start = split_indexes[self.tp_rank_]
split_end = split_indexes[self.tp_rank_ + 1]
if "model.language_model.embed_tokens.weight" in weights:
self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :])
tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
if tie_word_embeddings:
self.lm_head_weight_ = self.wte_weight_
if "lm_head.weight" in weights:
self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
if "model.language_model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"])

return

def verify_load(self):
errors = "weights load not ok"
weights = [
self.wte_weight_,
self.lm_head_weight_,
self.final_norm_weight_,
]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return
@@ -0,0 +1,30 @@
from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight


class Qwen3VLTransformerLayerWeight(Qwen3TransformerLayerWeight):
def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
super().__init__(layer_num, data_type, network_config, mode, quant_cfg)

def _init_weight_names(self):
super()._init_weight_names()
self._q_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_proj.weight"
self._q_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_norm.weight"
self._q_bias_name = None
self._k_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_proj.weight"
self._k_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_norm.weight"
self._k_bias_name = None
self._v_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.v_proj.weight"
self._v_bias_name = None
self._o_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.o_proj.weight"
self._o_bias_name = None
self._att_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.input_layernorm.weight"
self._att_norm_bias_name = None
self._ffn_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.post_attention_layernorm.weight"
self._ffn_norm_bias_name = None

self._gate_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.gate_proj.weight"
self._gate_bias_name = None
self._up_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.up_proj.weight"
self._up_bias_name = None
self._down_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.down_proj.weight"
self._down_bias_name = None