-
Notifications
You must be signed in to change notification settings - Fork 288
Add qwen3 vl #1095
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
SangChengC
wants to merge
14
commits into
main
Choose a base branch
from
add-qwen3-vl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add qwen3 vl #1095
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
defa581
[add]add whisper sdpa
SangChengC 87c15dc
[add]add qwen3-vl-moe support
c318d72
fix1103
2ebdb58
add qwen3-vl support
1588ff3
1203
cd9c7ee
Merge branch 'main' into add-qwen3-vl
3ee963e
1204
0da89eb
1210
02486eb
1210
f6c5d64
Merge branch 'main' into add-qwen3-vl
1902799
1210
ebd5f7c
1210
ae29f70
1210
544f625
1210
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| import torch | ||
| import numpy as np | ||
| from lightllm.models.llama.infer_struct import LlamaInferStateInfo | ||
| from lightllm.common.basemodel.infer_struct import InferStateInfo | ||
|
|
||
|
|
||
| class Qwen3VLInferStateInfo(LlamaInferStateInfo): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.input_ids = None | ||
| self.deepstack_features = [] | ||
| self.deepstack_end_layer = None | ||
| self.img_start_token_ids = [] | ||
| self.img_token_lens = [] | ||
| self.img_start_locs = [] | ||
|
|
||
| def apply_interleaved_mrope(self, freqs, mrope_section): | ||
| """Apply interleaved MRoPE to 3D rotary embeddings. | ||
| Reorganizes frequency layout from chunked [TTT...HHH...WWW] to | ||
| interleaved [THTHWHTHW...TT], preserving frequency continuity. | ||
| args: | ||
| x: (3, bs, seq_len, head_dim // 2) | ||
| mrope_section: (3,) | ||
| returns: | ||
| x_t: (bs, seq_len, head_dim // 2) | ||
| """ | ||
| freqs_t = freqs[0] # just overwrite the first dimension T | ||
| for dim, offset in enumerate((1, 2), start=1): # H, W | ||
| length = mrope_section[dim] * 3 | ||
| idx = slice(offset, length, 3) | ||
| freqs_t[..., idx] = freqs[dim, ..., idx] | ||
| return freqs_t | ||
|
|
||
| def init_some_extra_state(self, model, input_ids: torch.Tensor): | ||
| rope_scaling = model.config.get("rope_scaling", {}) | ||
| self.mrope_section = rope_scaling.get("mrope_section", None) | ||
| InferStateInfo.init_some_extra_state(self, model, input_ids) | ||
| pos = self.position_ids[None, :].expand(3, -1) | ||
| cos_T = torch.index_select(model._cos_cached, 0, pos[0]) # [L, d/2] | ||
| cos_H = torch.index_select(model._cos_cached, 0, pos[1]) | ||
| cos_W = torch.index_select(model._cos_cached, 0, pos[2]) | ||
| sin_T = torch.index_select(model._sin_cached, 0, pos[0]) | ||
| sin_H = torch.index_select(model._sin_cached, 0, pos[1]) | ||
| sin_W = torch.index_select(model._sin_cached, 0, pos[2]) | ||
| cos_half = self.apply_interleaved_mrope( | ||
| torch.stack([cos_T, cos_H, cos_W], dim=0), self.mrope_section | ||
| ) # [L, d/2] | ||
| sin_half = self.apply_interleaved_mrope( | ||
| torch.stack([sin_T, sin_H, sin_W], dim=0), self.mrope_section | ||
| ) # [L, d/2] | ||
|
|
||
| self.position_cos = torch.cat([cos_half, cos_half], dim=-1).contiguous() # [L, d] | ||
| self.position_sin = torch.cat([sin_half, sin_half], dim=-1).contiguous() | ||
| if self.is_prefill: | ||
| pos = None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| return | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight | ||
| from lightllm.models.llama.infer_struct import LlamaInferStateInfo | ||
| from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo | ||
|
|
||
| from lightllm.server.embed_cache.utils import ( | ||
| bytes2tensor, | ||
| read_shm, | ||
| get_shm_name_embed, | ||
| ) | ||
| from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb | ||
| from lightllm.distributed.communication_op import all_reduce | ||
|
|
||
| from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer | ||
|
|
||
|
|
||
| class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer): | ||
| def __init__(self, network_config, mode): | ||
| super().__init__(network_config, mode) | ||
| return | ||
|
|
||
| def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): | ||
| img_weight = [] | ||
| img_start_loc = 0 | ||
|
|
||
| infer_state.input_ids = input_ids | ||
| infer_state.img_start_token_ids = [] | ||
| infer_state.img_token_lens = [] | ||
| infer_state.img_start_locs = [] | ||
|
|
||
| device = layer_weight.wte_weight_.device | ||
| dtype = layer_weight.wte_weight_.dtype | ||
| hidden_size = layer_weight.wte_weight_.shape[1] | ||
|
|
||
| infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids) | ||
|
|
||
| for batch_id, p in enumerate(infer_state.multimodal_params): | ||
| for img in p["images"] + p["audios"]: | ||
| # skip the same image | ||
| if img["token_id"] in infer_state.img_start_token_ids or img["_prefill_"] is False: | ||
| continue | ||
|
|
||
| # all_img_embed_df的shape是 | ||
| # image_embed(token_num, hidden_dim) + deepstack(token_num*layer_num, hidden_dim) | ||
| all_img_embed_df = bytes2tensor(read_shm(get_shm_name_embed(img["uuid"]))) | ||
| per_image_deepstack = [] | ||
|
|
||
| # 计算deepstack的层数 | ||
| deepstack_layer_num = all_img_embed_df.shape[0] // img["token_num"] - 1 | ||
| img_weight.append(all_img_embed_df[: img["token_num"]].cuda()) | ||
|
|
||
| for layer in range(deepstack_layer_num): | ||
| start = img["token_num"] * (layer + 1) | ||
| end = img["token_num"] * (layer + 2) | ||
| per_image_deepstack.append(all_img_embed_df[start:end]) | ||
|
|
||
| infer_state.deepstack_features.append(per_image_deepstack) | ||
| infer_state.img_start_token_ids.append(img["token_id"]) | ||
| infer_state.img_token_lens.append(img["token_num"]) | ||
| infer_state.img_start_locs.append(img_start_loc) | ||
| img_start_loc += img["token_num"] | ||
| out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) | ||
|
|
||
| if len(img_weight) > 0: | ||
| img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype) | ||
| else: | ||
| img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype) | ||
| assert img_weight.shape[1] == hidden_size, ( | ||
| f"Dimension mismatch: text weight dimension is {hidden_size}, " | ||
| f"but image weight dimension is {img_weight.shape[1]}" | ||
| ) | ||
| # each tp will fill the img embeds, should divide by world_size | ||
| img_weight = img_weight / self.tp_world_size_ | ||
| img_start_token_ids = torch.Tensor(infer_state.img_start_token_ids).to(device=device, dtype=torch.long) | ||
| img_token_lens = torch.Tensor(infer_state.img_token_lens).to(device=device, dtype=torch.long) | ||
| img_start_locs = torch.Tensor(infer_state.img_start_locs).to(device=device, dtype=torch.long) | ||
|
|
||
| multimodal_emb( | ||
| out, | ||
| input_ids, | ||
| layer_weight.wte_weight_, | ||
| img_weight, | ||
| img_token_lens, | ||
| img_start_token_ids, | ||
| img_start_locs, | ||
| self.vob_start_id_, | ||
| self.vob_end_id_, | ||
| ) | ||
| if self.tp_world_size_ > 1: | ||
| all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) | ||
| return out |
55 changes: 55 additions & 0 deletions
55
lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import torch | ||
| import torch.functional as F | ||
| import torch.distributed as dist | ||
| import numpy as np | ||
| from functools import partial | ||
| from typing import Tuple | ||
| from lightllm.common.basemodel.infer_struct import InferStateInfo | ||
| from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton | ||
| from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer | ||
| from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight | ||
| from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer | ||
| from lightllm.models.llama.infer_struct import LlamaInferStateInfo | ||
| from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo | ||
| from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward | ||
| from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd | ||
| from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd | ||
| from lightllm.distributed import all_reduce | ||
| from lightllm.utils.dist_utils import get_global_world_size | ||
| from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features | ||
|
|
||
|
|
||
| class Qwen3VLTransformerLayerInfer(Qwen3TransformerLayerInfer): | ||
| def __init__(self, layer_num, network_config, mode=[]): | ||
| super().__init__(layer_num, network_config, mode) | ||
| self.mrope_section = network_config["rope_scaling"]["mrope_section"] | ||
| axis_map = [] | ||
| for i, n in enumerate(self.mrope_section * 2): | ||
| axis_map += [i % 3] * n | ||
| self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda") | ||
shihaobai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, layer_weight): | ||
| input1 = self._att_norm(input_embdings, infer_state, layer_weight) | ||
| q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) | ||
| input1 = None | ||
| self._post_cache_kv(cache_kv, infer_state, layer_weight) | ||
| o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) | ||
| q = None | ||
| o = self._get_o(o, infer_state, layer_weight) | ||
| if self.tp_world_size_ > 1: | ||
| all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||
| input_embdings.add_(o.view(-1, self.embed_dim_)) | ||
| o = None | ||
|
|
||
| input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) | ||
| ffn_out = self._ffn(input1, infer_state, layer_weight) | ||
| input1 = None | ||
| if self.tp_world_size_ > 1: | ||
| all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) | ||
| input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) | ||
| apply_deepstack_features( | ||
| input_embeddings=input_embdings, | ||
| infer_state=infer_state, | ||
| layer_num=self.layer_num_, | ||
| ) | ||
| return input_embdings | ||
36 changes: 36 additions & 0 deletions
36
lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| import numpy as np | ||
| from lightllm.common.basemodel import PreAndPostLayerWeight | ||
|
|
||
|
|
||
| class Qwen3VLPreAndPostLayerWeight(PreAndPostLayerWeight): | ||
| def __init__(self, data_type, network_config, mode): | ||
| super().__init__(data_type, network_config, mode) | ||
| return | ||
|
|
||
| def load_hf_weights(self, weights): | ||
| vob_size = self.network_config_["vocab_size"] | ||
| split_indexes = np.linspace(0, vob_size, self.tp_world_size_ + 1, dtype=np.int64) | ||
| split_start = split_indexes[self.tp_rank_] | ||
| split_end = split_indexes[self.tp_rank_ + 1] | ||
| if "model.language_model.embed_tokens.weight" in weights: | ||
| self.wte_weight_ = self._cuda(weights["model.language_model.embed_tokens.weight"][split_start:split_end, :]) | ||
| tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) | ||
| if tie_word_embeddings: | ||
| self.lm_head_weight_ = self.wte_weight_ | ||
| if "lm_head.weight" in weights: | ||
| self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) | ||
| if "model.language_model.norm.weight" in weights: | ||
| self.final_norm_weight_ = self._cuda(weights["model.language_model.norm.weight"]) | ||
|
|
||
| return | ||
|
|
||
| def verify_load(self): | ||
| errors = "weights load not ok" | ||
| weights = [ | ||
| self.wte_weight_, | ||
| self.lm_head_weight_, | ||
| self.final_norm_weight_, | ||
| ] | ||
| for i in range(len(weights)): | ||
| assert weights[i] is not None, "index:" + str(i) + " " + errors | ||
| return |
30 changes: 30 additions & 0 deletions
30
lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight | ||
|
|
||
|
|
||
| class Qwen3VLTransformerLayerWeight(Qwen3TransformerLayerWeight): | ||
| def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): | ||
| super().__init__(layer_num, data_type, network_config, mode, quant_cfg) | ||
|
|
||
| def _init_weight_names(self): | ||
| super()._init_weight_names() | ||
| self._q_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_proj.weight" | ||
| self._q_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.q_norm.weight" | ||
| self._q_bias_name = None | ||
| self._k_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_proj.weight" | ||
| self._k_norm_name = f"model.language_model.layers.{self.layer_num_}.self_attn.k_norm.weight" | ||
| self._k_bias_name = None | ||
| self._v_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.v_proj.weight" | ||
| self._v_bias_name = None | ||
| self._o_weight_name = f"model.language_model.layers.{self.layer_num_}.self_attn.o_proj.weight" | ||
| self._o_bias_name = None | ||
| self._att_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.input_layernorm.weight" | ||
| self._att_norm_bias_name = None | ||
| self._ffn_norm_weight_name = f"model.language_model.layers.{self.layer_num_}.post_attention_layernorm.weight" | ||
| self._ffn_norm_bias_name = None | ||
|
|
||
| self._gate_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.gate_proj.weight" | ||
| self._gate_bias_name = None | ||
| self._up_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.up_proj.weight" | ||
| self._up_bias_name = None | ||
| self._down_weight_name = f"model.language_model.layers.{self.layer_num_}.mlp.down_proj.weight" | ||
| self._down_bias_name = None |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring for
apply_interleaved_mropeis inaccurate. The argument isfreqswith a shape of(3, seq_len, head_dim // 2), but the docstring refers toxwith abs(batch size) dimension, which is not present. The return value shape is also(seq_len, head_dim // 2). Please update the docstring for clarity.