From 6b1cbe7ae113eb39c4041f7d982b3f59e80ed82c Mon Sep 17 00:00:00 2001 From: Yash Agarwal Date: Wed, 26 Nov 2025 14:19:34 -0500 Subject: [PATCH 1/4] Qwen3VL - Config, Conv Template, Basic structure --- 3rdparty/tvm | 2 +- .../mlc_llm/conversation_template/__init__.py | 1 + .../mlc_llm/conversation_template/qwen3_vl.py | 20 ++++ python/mlc_llm/interface/gen_config.py | 1 + python/mlc_llm/model/model.py | 13 +++ python/mlc_llm/model/qwen3_vl/__init__.py | 0 .../mlc_llm/model/qwen3_vl/qwen3_vl_config.py | 109 ++++++++++++++++++ .../mlc_llm/model/qwen3_vl/qwen3_vl_loader.py | 9 ++ .../mlc_llm/model/qwen3_vl/qwen3_vl_model.py | 10 ++ .../model/qwen3_vl/qwen3_vl_quantization.py | 9 ++ python/mlc_llm/support/config.py | 4 +- 11 files changed, 176 insertions(+), 2 deletions(-) create mode 100644 python/mlc_llm/conversation_template/qwen3_vl.py create mode 100644 python/mlc_llm/model/qwen3_vl/__init__.py create mode 100644 python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py create mode 100644 python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py create mode 100644 python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py create mode 100644 python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py diff --git a/3rdparty/tvm b/3rdparty/tvm index c6fb2be79f..e16f5512aa 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c6fb2be79f654588fd94727a74b9ca0754f63fa4 +Subproject commit e16f5512aa635b6fa19cdb1ce94e25d22abca801 diff --git a/python/mlc_llm/conversation_template/__init__.py b/python/mlc_llm/conversation_template/__init__.py index 25a969caef..28eb94f26d 100644 --- a/python/mlc_llm/conversation_template/__init__.py +++ b/python/mlc_llm/conversation_template/__init__.py @@ -24,6 +24,7 @@ orion, phi, qwen2, + qwen3_vl, redpajama, rwkv, stablelm, diff --git a/python/mlc_llm/conversation_template/qwen3_vl.py b/python/mlc_llm/conversation_template/qwen3_vl.py new file mode 100644 index 0000000000..6bc19542c0 --- /dev/null +++ b/python/mlc_llm/conversation_template/qwen3_vl.py @@ -0,0 +1,20 @@ +"""Qwen3-VL default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Qwen3-VL +ConvTemplateRegistry.register_conv_template( + Conversation( + name="qwen3_vl", + system_template=f"<|im_start|>system\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\n", + system_message="You are a helpful assistant.", + roles={"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}, + seps=["<|im_end|>\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_str=["<|endoftext|>", "<|im_end|>"], + stop_token_ids=[151643, 151645], + ) +) diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index af24afbd9a..aa6af2988f 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -293,6 +293,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "chatml", "chatml_nosystem", "qwen2", + "qwen3_vl", "open_hermes_mistral", "neural_hermes_mistral", "llama_default", diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 875f6e11c4..1a97f2fd2e 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -42,6 +42,7 @@ from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model, qwen2_moe_quantization from .qwen3 import qwen3_loader, qwen3_model, qwen3_quantization from .qwen3_moe import qwen3_moe_loader, qwen3_moe_model, qwen3_moe_quantization +from .qwen3_vl import 
qwen3_vl_loader, qwen3_vl_model, qwen3_vl_quantization from .rwkv5 import rwkv5_loader, rwkv5_model, rwkv5_quantization from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization @@ -374,6 +375,18 @@ class Model: "block-scale-quant": qwen3_moe_quantization.block_scale_quant, }, ), + "qwen3_vl": Model( + name="qwen3_vl", + model=qwen3_vl_model.Qwen3VLForConditionalGeneration, + config=qwen3_vl_model.Qwen3VLConfig, + source={ + "huggingface-torch": qwen3_vl_loader.huggingface, + "huggingface-safetensor": qwen3_vl_loader.huggingface, + }, + quantize={ + "no-quant": qwen3_vl_quantization.no_quant, + }, + ), "deepseek_v2": Model( name="deepseek_v2", model=deepseek_v2_model.DeepseekV2ForCausalLM, diff --git a/python/mlc_llm/model/qwen3_vl/__init__.py b/python/mlc_llm/model/qwen3_vl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py new file mode 100644 index 0000000000..8a08a4e0e8 --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py @@ -0,0 +1,109 @@ +""" +Configuration for Qwen3-VL model. +""" + +import dataclasses +from typing import Any, Dict, Optional, Tuple + +from mlc_llm.model.qwen3.qwen3_model import Qwen3Config +from mlc_llm.support import logging +from mlc_llm.support.config import ConfigBase + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Qwen3VLVisionConfig(ConfigBase): + """Configuration for the vision module of Qwen3-VL.""" + + depth: int + hidden_size: int + hidden_act: str + intermediate_size: int + num_heads: int + in_channels: int + patch_size: int + spatial_merge_size: int + temporal_patch_size: int + out_hidden_size: int + num_position_embeddings: int + deepstack_visual_indexes: list[int] + initializer_range: float = 0.02 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class Qwen3VLConfig(ConfigBase): + """Configuration for Qwen3-VL model.""" + + text_config: Qwen3Config + vision_config: Qwen3VLVisionConfig + image_token_id: int = 151655 + video_token_id: int = 151656 + vision_start_token_id: int = 151652 + vision_end_token_id: int = 151653 + tie_word_embeddings: bool = False + max_batch_size: int = 128 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + @property + def vocab_size(self) -> int: + return self.text_config.vocab_size + + @property + def prefill_chunk_size(self) -> int: + return self.text_config.prefill_chunk_size + + @property + def context_window_size(self) -> int: + return self.text_config.context_window_size + + @property + def tensor_parallel_shards(self) -> int: + return self.text_config.tensor_parallel_shards + + def __post_init__(self): + if isinstance(self.text_config, dict): + self.text_config = Qwen3Config.from_dict(self.text_config) + if isinstance(self.vision_config, dict): + self.vision_config = Qwen3VLVisionConfig.from_dict(self.vision_config) + + @classmethod + def from_huggingface(cls, config_json: Dict[str, Any]) -> "Qwen3VLConfig": + """Create Qwen3VLConfig from HuggingFace config.""" + # Extract text config + text_config_dict = config_json.get("text_config", {}) + # Ensure model_type is set correctly for Qwen3Config if needed, or just pass as is + # Qwen3Config might expect certain fields. 
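+        # Hedged usage sketch (assumption, not an API defined by this patch): callers
+        # are expected to hand in the parsed HuggingFace config.json, e.g.
+        #     import json
+        #     with open("config.json", encoding="utf-8") as f:
+        #         cfg = Qwen3VLConfig.from_huggingface(json.load(f))
+        # The nested "text_config" / "vision_config" dicts are converted below.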
+ + # Extract vision config + vision_config_dict = config_json.get("vision_config", {}) + + # Extract top-level fields + image_token_id = config_json.get("image_token_id", 151655) + video_token_id = config_json.get("video_token_id", 151656) + vision_start_token_id = config_json.get("vision_start_token_id", 151652) + vision_end_token_id = config_json.get("vision_end_token_id", 151653) + tie_word_embeddings = config_json.get("tie_word_embeddings", False) + + return cls( + text_config=Qwen3Config.from_dict(text_config_dict), + vision_config=Qwen3VLVisionConfig.from_dict(vision_config_dict), + image_token_id=image_token_id, + video_token_id=video_token_id, + vision_start_token_id=vision_start_token_id, + vision_end_token_id=vision_end_token_id, + tie_word_embeddings=tie_word_embeddings, + kwargs=config_json, + ) + +# Testing command +# conda activate tvm-dev +# export LOCAL_MODEL_PATH=../mlc-models/Qwen3-VL-2B-Instruct/ +# export MLC_MODEL_PATH=../mlc-models/mlc-qwen/ +# export QUANTIZATION=q0f16 +# export CONV_TEMPLATE=qwen3_vl +# python -m mlc_llm gen_config $LOCAL_MODEL_PATH \ +# --quantization $QUANTIZATION \ +# --conv-template $CONV_TEMPLATE \ +# -o $MLC_MODEL_PATH \ No newline at end of file diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py new file mode 100644 index 0000000000..dc569bfeea --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py @@ -0,0 +1,9 @@ +""" +Minimal loader for Qwen3-VL. +""" +from typing import Any, Dict + +from mlc_llm.loader import Loader + +def huggingface(model_config, quantization): + return None diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py new file mode 100644 index 0000000000..5c7170025d --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py @@ -0,0 +1,10 @@ +""" +Minimal model for Qwen3-VL. +""" +from tvm.relax.frontend import nn +from .qwen3_vl_config import Qwen3VLConfig + +class Qwen3VLForConditionalGeneration(nn.Module): + def __init__(self, config: Qwen3VLConfig): + super().__init__() + self.config = config diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py new file mode 100644 index 0000000000..4078205267 --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py @@ -0,0 +1,9 @@ +""" +Minimal quantization for Qwen3-VL. +""" +from typing import Any, Dict, Tuple +from tvm.relax.frontend import nn +from mlc_llm.loader import QuantizeMapping + +def no_quant(model_config, quantization) -> Tuple[nn.Module, QuantizeMapping]: + return None, None diff --git a/python/mlc_llm/support/config.py b/python/mlc_llm/support/config.py index 715a4b2fa4..f854926f19 100644 --- a/python/mlc_llm/support/config.py +++ b/python/mlc_llm/support/config.py @@ -46,8 +46,10 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass: An instance of the config object. 
""" field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type] - fields = {k: v for k, v in source.items() if k in field_names} + fields = {k: v for k, v in source.items() if k in field_names and k != "kwargs"} kwargs = {k: v for k, v in source.items() if k not in field_names} + if "kwargs" in source and isinstance(source["kwargs"], dict): + kwargs.update(source["kwargs"]) return cls(**fields, kwargs=kwargs) # type: ignore[call-arg] @classmethod From 3b47c043d513779c1c2f217910789635c91a5dbd Mon Sep 17 00:00:00 2001 From: Yash Agarwal Date: Mon, 8 Dec 2025 13:13:56 -0500 Subject: [PATCH 2/4] Qwen3VL - Add model definition, loader, quant and tests --- python/mlc_llm/model/model.py | 3 + python/mlc_llm/model/qwen3/qwen3_model.py | 4 + python/mlc_llm/model/qwen3_vl/qwen2_vl.py | 201 ++++++++++ .../mlc_llm/model/qwen3_vl/qwen3_vl_config.py | 102 ++++- .../mlc_llm/model/qwen3_vl/qwen3_vl_loader.py | 155 +++++++- .../mlc_llm/model/qwen3_vl/qwen3_vl_model.py | 357 +++++++++++++++++- .../model/qwen3_vl/qwen3_vl_quantization.py | 72 +++- .../mlc_llm/model/qwen3_vl/qwen3_vl_text.py | 342 +++++++++++++++++ .../mlc_llm/model/qwen3_vl/qwen3_vl_vision.py | 272 +++++++++++++ python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py | 67 ++++ .../data/qwen3_vl_2b_instruct/config.json | 63 ++++ tests/python/model/test_qwen3vl.py | 48 +++ tests/python/model/test_qwen3vl_loader.py | 41 ++ 13 files changed, 1709 insertions(+), 18 deletions(-) create mode 100644 python/mlc_llm/model/qwen3_vl/qwen2_vl.py create mode 100644 python/mlc_llm/model/qwen3_vl/qwen3_vl_text.py create mode 100644 python/mlc_llm/model/qwen3_vl/qwen3_vl_vision.py create mode 100644 python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py create mode 100644 tests/python/model/data/qwen3_vl_2b_instruct/config.json create mode 100644 tests/python/model/test_qwen3vl.py create mode 100644 tests/python/model/test_qwen3vl_loader.py diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 1a97f2fd2e..e8c4b06a8f 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -385,6 +385,9 @@ class Model: }, quantize={ "no-quant": qwen3_vl_quantization.no_quant, + "group-quant": qwen3_vl_quantization.group_quant, + "ft-quant": qwen3_vl_quantization.ft_quant, + "block-scale-quant": qwen3_vl_quantization.block_scale_quant, }, ), "deepseek_v2": Model( diff --git a/python/mlc_llm/model/qwen3/qwen3_model.py b/python/mlc_llm/model/qwen3/qwen3_model.py index a4468ffe47..513a924f25 100644 --- a/python/mlc_llm/model/qwen3/qwen3_model.py +++ b/python/mlc_llm/model/qwen3/qwen3_model.py @@ -153,10 +153,14 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: "relu": nn.relu, "silu": nn.silu, "swish": nn.silu, + "gelu_new": partial(nn.gelu, approximate=True), + "quick_gelu": partial(nn.gelu, approximate=True), + "gelu_pytorch_tanh": partial(nn.gelu, approximate=True), } + class Qwen3Embedding(nn.Embedding): """The embedding module specialized for Qwen3 so that it can be shared with the final lm_head. 
diff --git a/python/mlc_llm/model/qwen3_vl/qwen2_vl.py b/python/mlc_llm/model/qwen3_vl/qwen2_vl.py new file mode 100644 index 0000000000..6b35fceedd --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/qwen2_vl.py @@ -0,0 +1,201 @@ +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from typing import Optional, Tuple + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + + # TODO - i am assuming tvm has the same conv3d as pytorch + self.proj = nn.Conv3D(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: Tensor) -> Tensor: + + ''' + TODO - translate pytorch to tvm + ''' + + raise NotImplementedError + # target_dtype = self.proj.weight.dtype + # hidden_states = hidden_states.view( + # -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + # ) + # hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + # return hidden_states + + + +class VisionRotaryEmbedding(nn.Module): + inv_freq: Tensor + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + + self.dim = dim + self.theta = theta + + def forward(self, seqlen: int) -> Tensor: + # TODO - assuming op.arange syntax == torch.arange, changed dtype to the string literal float32, idk how tvm does dtypes + self.inv_freq = 1.0 / (self.theta ** (op.arange(0, self.dim, 2, dtype="float32") / self.dim)) + pass + + # TODO - translate pytorch to tvm + raise NotImplementedError + # seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + # freqs = torch.outer(seq, self.inv_freq) + # return freqs + + +class VisionAttention(nn.Module): + + # fyi this expects a Qwen2VLVisionConfig + def __init__(self, config) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: Tensor, + cu_seqlens: Tensor, + rotary_pos_emb: Optional[Tensor] = None, + position_embeddings: Optional[tuple[Tensor, Tensor]] = None, + **kwargs, + ) -> Tensor: + + # TODO - translate pytorch to tvm + + raise NotImplementedError + + # seq_length = hidden_states.shape[0] + # query_states, key_states, value_states = ( + # self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + # ) + # cos, sin = position_embeddings + # query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + # query_states = query_states.transpose(0, 1).unsqueeze(0) + # key_states = key_states.transpose(0, 1).unsqueeze(0) + # value_states = value_states.transpose(0, 1).unsqueeze(0) + + # attention_interface: Callable = eager_attention_forward + # if self.config._attn_implementation != "eager": + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + # if 
self.config._attn_implementation == "flash_attention_2": + # # Flash Attention 2: Use cu_seqlens for variable length attention + # max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + # attn_output, _ = attention_interface( + # self, + # query_states, + # key_states, + # value_states, + # attention_mask=None, + # scaling=self.scaling, + # dropout=0.0 if not self.training else self.attention_dropout, + # cu_seq_lens_q=cu_seqlens, + # cu_seq_lens_k=cu_seqlens, + # max_length_q=max_seqlen, + # max_length_k=max_seqlen, + # is_causal=False, + # **kwargs, + # ) + # else: + # # Other implementations: Process each chunk separately + # lengths = cu_seqlens[1:] - cu_seqlens[:-1] + # splits = [ + # torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + # ] + + # attn_outputs = [ + # attention_interface( + # self, + # q, + # k, + # v, + # attention_mask=None, + # scaling=self.scaling, + # dropout=0.0 if not self.training else self.attention_dropout, + # is_causal=False, + # **kwargs, + # )[0] + # for q, k, v in zip(*splits) + # ] + # attn_output = torch.cat(attn_outputs, dim=1) + + # attn_output = attn_output.reshape(seq_length, -1).contiguous() + # attn_output = self.proj(attn_output) + # return attn_output + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + + # fyi assuming nn.Parameter is a thing + + # do i need to have nn.Parameter, or can it just be op.ones? + self.weight = nn.Parameter((hidden_size,), dtype="float32") + self.variance_epsilon = eps + + def forward(self, hidden_states: Tensor) -> Tensor: + # TODO - translate pytorch to tvm + + raise NotImplementedError + # input_dtype = hidden_states.dtype + # hidden_states = hidden_states.to("float32") + # variance = hidden_states.pow(2).mean(-1, keepdim=True) + # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + + +class Qwen2VLModel(nn.Module): + base_model_prefix = "model" + accepts_loss_kwargs = False + + # expects qwen2vlconfig object + def __init__(self, config): + self.rope_deltas = None # cache rope_deltas here + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_rope_index( + self, + input_ids: Optional[Tensor] = None, + image_grid_thw: Optional[Tensor] = None, + video_grid_thw: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ) -> tuple[Tensor, Tensor]: + raise NotImplementedError \ No newline at end of file diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py index 8a08a4e0e8..3b358c84c5 100644 --- a/python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_config.py @@ -5,9 +5,9 @@ import dataclasses from typing import Any, Dict, Optional, Tuple -from mlc_llm.model.qwen3.qwen3_model import Qwen3Config from mlc_llm.support import logging from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold logger = logging.getLogger(__name__) @@ -32,11 +32,98 @@ class Qwen3VLVisionConfig(ConfigBase): kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) +@dataclasses.dataclass +class 
Qwen3VLTextConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Qwen3-VL text model.""" + + hidden_act: str + hidden_size: int + intermediate_size: int + attention_bias: bool + num_attention_heads: int + num_hidden_layers: int + num_key_value_heads: int + rms_norm_eps: float + vocab_size: int + rope_theta: float = 500000.0 + tie_word_embeddings: bool = False + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + head_dim: int = 0 + dtype: str = "float32" + max_batch_size: int = 1 + weight_block_size: Optional[Tuple[int, int]] = None + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if "quantization_config" in self.kwargs: + quantization_config = self.kwargs.get("quantization_config") + if ( + isinstance(quantization_config, dict) + and quantization_config.get("activation_scheme", "") == "dynamic" + and quantization_config.get("fmt", "") == "e4m3" + and quantization_config.get("quant_method", "") == "fp8" + and "weight_block_size" in quantization_config + ): + self.weight_block_size = quantization_config.get("weight_block_size") + if ( + not isinstance(self.weight_block_size, (tuple, list)) + or len(self.weight_block_size) != 2 + ): + raise ValueError( + "Invalid DeepSeek model quantization config: " + "weight_block_size must be a tuple of two integers, " + f"got {self.weight_block_size} of type {type(self.weight_block_size)}" + ) + else: + raise ValueError( + "Invalid DeepSeek model quantization config: unrecognized quantization config: " + f"{quantization_config}" + ) + + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + # Default to 128000 for Qwen3-VL text if not found + self.context_window_size = 128000 + logger.info( + "%s not found in config.json. 
Falling back to default %d", + bold("context_window_size"), + self.context_window_size, + ) + + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %d", + bold("prefill_chunk_size"), + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + min(self.context_window_size, 2048), + ) + self.prefill_chunk_size = min(self.context_window_size, 2048) + + @dataclasses.dataclass class Qwen3VLConfig(ConfigBase): """Configuration for Qwen3-VL model.""" - text_config: Qwen3Config + text_config: Qwen3VLTextConfig vision_config: Qwen3VLVisionConfig image_token_id: int = 151655 video_token_id: int = 151656 @@ -64,7 +151,7 @@ def tensor_parallel_shards(self) -> int: def __post_init__(self): if isinstance(self.text_config, dict): - self.text_config = Qwen3Config.from_dict(self.text_config) + self.text_config = Qwen3VLTextConfig.from_dict(self.text_config) if isinstance(self.vision_config, dict): self.vision_config = Qwen3VLVisionConfig.from_dict(self.vision_config) @@ -73,8 +160,6 @@ def from_huggingface(cls, config_json: Dict[str, Any]) -> "Qwen3VLConfig": """Create Qwen3VLConfig from HuggingFace config.""" # Extract text config text_config_dict = config_json.get("text_config", {}) - # Ensure model_type is set correctly for Qwen3Config if needed, or just pass as is - # Qwen3Config might expect certain fields. # Extract vision config vision_config_dict = config_json.get("vision_config", {}) @@ -87,7 +172,7 @@ def from_huggingface(cls, config_json: Dict[str, Any]) -> "Qwen3VLConfig": tie_word_embeddings = config_json.get("tie_word_embeddings", False) return cls( - text_config=Qwen3Config.from_dict(text_config_dict), + text_config=Qwen3VLTextConfig.from_dict(text_config_dict), vision_config=Qwen3VLVisionConfig.from_dict(vision_config_dict), image_token_id=image_token_id, video_token_id=video_token_id, @@ -103,7 +188,4 @@ def from_huggingface(cls, config_json: Dict[str, Any]) -> "Qwen3VLConfig": # export MLC_MODEL_PATH=../mlc-models/mlc-qwen/ # export QUANTIZATION=q0f16 # export CONV_TEMPLATE=qwen3_vl -# python -m mlc_llm gen_config $LOCAL_MODEL_PATH \ -# --quantization $QUANTIZATION \ -# --conv-template $CONV_TEMPLATE \ -# -o $MLC_MODEL_PATH \ No newline at end of file +# python -m mlc_llm gen_config $LOCAL_MODEL_PATH --quantization $QUANTIZATION --conv-template $CONV_TEMPLATE -o $MLC_MODEL_PATH \ No newline at end of file diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py index dc569bfeea..0f70bdf10e 100644 --- a/python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_loader.py @@ -1,9 +1,156 @@ """ Minimal loader for Qwen3-VL. 
""" -from typing import Any, Dict +import functools +from typing import Callable, List -from mlc_llm.loader import Loader +import numpy as np +from mlc_llm.loader import ExternMapping, QuantizeMapping +from mlc_llm.quantization import BlockScaleQuantize, Quantization +from mlc_llm.model.qwen3_vl.qwen3_vl_config import Qwen3VLConfig +from mlc_llm.model.qwen3_vl.qwen3_vl_model import Qwen3VLForConditionalGeneration -def huggingface(model_config, quantization): - return None + +def huggingface(model_config: Qwen3VLConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : Qwen3VLConfig + The configuration of the Qwen3-VL model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = Qwen3VLForConditionalGeneration(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + + if isinstance(quantization, BlockScaleQuantize): + # Convert the model to block-scale quantized model before loading parameters + model = quantization.quantize_model(model, QuantizeMapping({}, {}), "") + if model_config.text_config.weight_block_size is None: + raise ValueError( + "The input Qwen3-VL model is not fp8 block quantized. " + "Thus BlockScaleQuantize is not supported." + ) + + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + if ( + not isinstance(quantization, BlockScaleQuantize) + and model_config.text_config.weight_block_size is not None + ): + raise ValueError( + "The input Qwen3-VL model is fp8 block quantized. " + "Please use BlockScaleQuantize for the model." 
+ ) + + # Helper function to add both weight and scale mappings + def add_weight_and_scale_mapping( + weight_mlc_name: str, + weight_hf_names: List[str], + weight_transform_func: Callable, + ): + if weight_mlc_name not in named_parameters: + return + + mlc_param = named_parameters[weight_mlc_name] + mapping.add_mapping( + weight_mlc_name, + weight_hf_names, + functools.partial(weight_transform_func, dtype=mlc_param.dtype), + ) + + if isinstance(quantization, BlockScaleQuantize): + scale_mlc_name = f"{weight_mlc_name}_scale_inv" + if scale_mlc_name in named_parameters: + scale_hf_names = [f"{name}_scale_inv" for name in weight_hf_names] + scale_param = named_parameters[scale_mlc_name] + mapping.add_mapping( + scale_mlc_name, + scale_hf_names, + functools.partial(weight_transform_func, dtype=scale_param.dtype), + ) + + # ========================== + # Text Model Mapping + # ========================== + prefix = "model.language_model" + + for i in range(model_config.text_config.num_hidden_layers): + # map attention weight + attn_mlc = f"{prefix}.layers.{i}.self_attn" + attn_hf = f"{prefix}.layers.{i}.self_attn" + + # Merge Q, K, V + add_weight_and_scale_mapping( + f"{attn_mlc}.c_attn.weight", + [ + f"{attn_hf}.q_proj.weight", + f"{attn_hf}.k_proj.weight", + f"{attn_hf}.v_proj.weight", + ], + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + ) + + if model_config.text_config.attention_bias: + mlc_name = f"{attn_mlc}.c_attn.bias" + if mlc_name in named_parameters: + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn_hf}.q_proj.bias", + f"{attn_hf}.k_proj.bias", + f"{attn_hf}.v_proj.bias", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # map mlp weight + mlp_mlc = f"{prefix}.layers.{i}.mlp" + mlp_hf = f"{prefix}.layers.{i}.mlp" + + # Merge Gate, Up + add_weight_and_scale_mapping( + f"{mlp_mlc}.gate_up_proj.weight", + [ + f"{mlp_hf}.gate_proj.weight", + f"{mlp_hf}.up_proj.weight", + ], + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + ) + + # ========================== + # Vision Model Mapping + # ========================== + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + # Check if this is a vision parameter or text parameter that maps 1:1 + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + return mapping \ No newline at end of file diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py index 5c7170025d..7615a76a86 100644 --- a/python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_model.py @@ -2,9 +2,364 @@ Minimal model for Qwen3-VL. 
""" from tvm.relax.frontend import nn -from .qwen3_vl_config import Qwen3VLConfig +from tvm.relax.frontend.nn import Tensor, op + +from .qwen3_vl_config import Qwen3VLConfig, Qwen3VLVisionConfig +from .qwen_2_5_vl import Qwen2_5_VLModel + +from .qwen3_vl_vision import Qwen3VLVisionModel +from .qwen3_vl_text import Qwen3VLTextModel + +from typing import Optional, Union +from tvm import tir +from mlc_llm.nn import PagedKVCache, RopeMode + + +class Qwen3VLModel(Qwen2_5_VLModel): + config: Qwen3VLConfig + base_model_prefix = "model" + _checkpoint_conversion_mapping = {} + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLVisionModel(config.vision_config) + self.language_model = Qwen3VLTextModel(config.text_config) + + def get_rope_index( + self, + input_ids: Optional[Tensor] = None, + image_grid_thw: Optional[Tensor] = None, + video_grid_thw: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + ) -> tuple[Tensor, Tensor]: + """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" + + # TODO translate pytorch to tvm + raise NotImplementedError + + # # Since we use timestamps to separate videos, like , the video_grid_thw should also be split + # if video_grid_thw is not None: + # video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + # video_grid_thw[:, 0] = 1 + + # spatial_merge_size = self.config.vision_config.spatial_merge_size + # image_token_id = self.config.image_token_id + # video_token_id = self.config.video_token_id + # vision_start_token_id = self.config.vision_start_token_id + # mrope_position_deltas = [] + # if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + # total_input_ids = input_ids + # if attention_mask is None: + # attention_mask = torch.ones_like(total_input_ids) + # position_ids = torch.ones( + # 3, + # input_ids.shape[0], + # input_ids.shape[1], + # dtype=input_ids.dtype, + # device=input_ids.device, + # ) + # image_index, video_index = 0, 0 + # attention_mask = attention_mask.to(total_input_ids.device) + # for i, input_ids in enumerate(total_input_ids): + # input_ids = input_ids[attention_mask[i] == 1] + # image_nums, video_nums = 0, 0 + # vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + # vision_tokens = input_ids[vision_start_indices + 1] + # image_nums = (vision_tokens == image_token_id).sum() + # video_nums = (vision_tokens == video_token_id).sum() + # input_tokens = input_ids.tolist() + # llm_pos_ids_list: list = [] + # st = 0 + # remain_images, remain_videos = image_nums, video_nums + # for _ in range(image_nums + video_nums): + # if image_token_id in input_tokens and remain_images > 0: + # ed_image = input_tokens.index(image_token_id, st) + # else: + # ed_image = len(input_tokens) + 1 + # if video_token_id in input_tokens and remain_videos > 0: + # ed_video = input_tokens.index(video_token_id, st) + # else: + # ed_video = len(input_tokens) + 1 + # if ed_image < ed_video: + # t, h, w = ( + # image_grid_thw[image_index][0], + # image_grid_thw[image_index][1], + # image_grid_thw[image_index][2], + # ) + # image_index += 1 + # remain_images -= 1 + # ed = ed_image + + # else: + # t, h, w = ( + # video_grid_thw[video_index][0], + # video_grid_thw[video_index][1], + # video_grid_thw[video_index][2], + # ) + # video_index += 1 + # remain_videos -= 1 + # ed = ed_video + # llm_grid_t, 
llm_grid_h, llm_grid_w = ( + # t.item(), + # h.item() // spatial_merge_size, + # w.item() // spatial_merge_size, + # ) + # text_len = ed - st + + # st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + # llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + # t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + # h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + # w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + # llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + # st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + # if st < len(input_tokens): + # st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + # text_len = len(input_tokens) - st + # llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + # position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + # mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + # mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + # return position_ids, mrope_position_deltas + # else: + # if attention_mask is not None: + # position_ids = attention_mask.long().cumsum(-1) - 1 + # position_ids.masked_fill_(attention_mask == 0, 1) + # position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + # max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + # mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + # else: + # position_ids = ( + # torch.arange(input_ids.shape[1], device=input_ids.device) + # .view(1, 1, -1) + # .expand(3, input_ids.shape[0], -1) + # ) + # mrope_position_deltas = torch.zeros( + # [input_ids.shape[0], 1], + # device=input_ids.device, + # dtype=input_ids.dtype, + # ) + + # return position_ids, mrope_position_deltas + + def get_image_features(self, pixel_values: Tensor, image_grid_thw: Optional[Tensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`Tensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + raise NotImplementedError + # pixel_values = pixel_values.type(self.visual.dtype) + # image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + # split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + # image_embeds = torch.split(image_embeds, split_sizes) + # return image_embeds, deepstack_image_embeds + + def get_video_features( + self, pixel_values_videos: Tensor, video_grid_thw: Optional[Tensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. 
+ + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + + def forward( + self, + input_ids: Tensor = None, + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_values: Optional["Cache"] = None, + inputs_embeds: Optional[Tensor] = None, + pixel_values: Optional[Tensor] = None, + pixel_values_videos: Optional[Tensor] = None, + image_grid_thw: Optional[Tensor] = None, + video_grid_thw: Optional[Tensor] = None, + cache_position: Optional[Tensor] = None, + **kwargs, + ): + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + raise NotImplementedError + # if (input_ids is None) ^ (inputs_embeds is not None): + # raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # if inputs_embeds is None: + # inputs_embeds = self.get_input_embeddings()(input_ids) + + # image_mask = None + # video_mask = None + + # if pixel_values is not None: + # image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + # image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + # image_mask, _ = self.get_placeholder_mask( + # input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + # ) + # inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + # if pixel_values_videos is not None: + # video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + # video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + # _, video_mask = self.get_placeholder_mask( + # input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + # ) + # inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # visual_pos_masks = None + # deepstack_visual_embeds = None + # if image_mask is not None and video_mask is not None: + # # aggregate visual_pos_masks and deepstack_visual_embeds + # image_mask = image_mask[..., 0] + # video_mask = video_mask[..., 0] + # visual_pos_masks = image_mask | video_mask + # deepstack_visual_embeds = [] + # image_mask_joint = image_mask[visual_pos_masks] + # video_mask_joint = video_mask[visual_pos_masks] + # for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + # embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + # embed_joint[image_mask_joint, :] = img_embed + # embed_joint[video_mask_joint, :] = vid_embed + # deepstack_visual_embeds.append(embed_joint) + # elif image_mask is not None: + # image_mask = image_mask[..., 0] + # visual_pos_masks = image_mask + # deepstack_visual_embeds = deepstack_image_embeds + # elif video_mask is not None: + # video_mask = video_mask[..., 0] + # visual_pos_masks = video_mask + # deepstack_visual_embeds = deepstack_video_embeds + + # if position_ids is None: + # past_key_values_length = 0 if past_key_values is 
None else past_key_values.get_seq_length() + # if self.rope_deltas is None or past_key_values_length == 0: + # position_ids, rope_deltas = self.get_rope_index( + # input_ids, + # image_grid_thw, + # video_grid_thw, + # attention_mask=attention_mask, + # ) + # self.rope_deltas = rope_deltas + # # then use the prev pre-calculated rope-deltas to get the correct position ids + # else: + # batch_size, seq_length, _ = inputs_embeds.shape + # delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device) + # position_ids = torch.arange(seq_length, device=inputs_embeds.device) + # position_ids = position_ids.view(1, -1).expand(batch_size, -1) + # if cache_position is not None: # otherwise `deltas` is an int `0` + # delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + # position_ids = position_ids.add(delta) + # position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + # outputs = self.language_model( + # input_ids=None, + # position_ids=position_ids, + # attention_mask=attention_mask, + # past_key_values=past_key_values, + # inputs_embeds=inputs_embeds, + # cache_position=cache_position, + # visual_pos_masks=visual_pos_masks, + # deepstack_visual_embeds=deepstack_visual_embeds, + # **kwargs, + # ) + + # return Qwen3VLModelOutputWithPast( + # last_hidden_state=outputs.last_hidden_state, + # past_key_values=outputs.past_key_values, + # rope_deltas=self.rope_deltas, + # ) + class Qwen3VLForConditionalGeneration(nn.Module): def __init__(self, config: Qwen3VLConfig): super().__init__() self.config = config + self.model = Qwen3VLModel(config) + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + b, s, d = input_embed.shape + return op.zeros((b, s, self.config.text_config.vocab_size), dtype="float32"), paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + b, s, d = input_embed.shape + return op.zeros((b, s, self.config.text_config.vocab_size), dtype="float32"), paged_kv_cache + + def create_paged_kv_cache( # pylint: disable=too-many-arguments + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var, + ) -> PagedKVCache: + return PagedKVCache.create_generic( + attn_kind="mha", + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.config.text_config.num_hidden_layers, + num_attention_heads=self.config.text_config.num_attention_heads // self.config.text_config.tensor_parallel_shards, + num_key_value_heads=self.config.text_config.num_key_value_heads // self.config.text_config.tensor_parallel_shards, + qk_head_dim=self.config.text_config.head_dim, + v_head_dim=self.config.text_config.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.config.text_config.rope_theta, + dtype=self.config.text_config.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.config.text_config.hidden_size], "float32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.config.text_config.hidden_size], "float32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": 
{ + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) + diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py index 4078205267..c27b772e2c 100644 --- a/python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_quantization.py @@ -1,9 +1,75 @@ """ Minimal quantization for Qwen3-VL. """ -from typing import Any, Dict, Tuple +from typing import Tuple + from tvm.relax.frontend import nn + from mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import ( + BlockScaleQuantize, + FTQuantize, + GroupQuantize, + NoQuantize, +) + +from .qwen3_vl_config import Qwen3VLConfig +from .qwen3_vl_model import Qwen3VLForConditionalGeneration + + +def group_quant( + model_config: Qwen3VLConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen3-VL model using group quantization.""" + model: nn.Module = Qwen3VLForConditionalGeneration(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + + quantization.tensor_parallel_shards = model_config.text_config.tensor_parallel_shards + + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: Qwen3VLConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen3-VL model using FasterTransformer quantization.""" + model: nn.Module = Qwen3VLForConditionalGeneration(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: Qwen3VLConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen3-VL model without quantization.""" + model = Qwen3VLForConditionalGeneration(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map + -def no_quant(model_config, quantization) -> Tuple[nn.Module, QuantizeMapping]: - return None, None +def block_scale_quant( + model_config: Qwen3VLConfig, + quantization: BlockScaleQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Qwen3-VL model using block-scale quantization.""" + model: nn.Module = Qwen3VLForConditionalGeneration(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model(model, quant_map, "") + return model, quant_map diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_text.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_text.py new file mode 100644 index 0000000000..bcf53fe2d6 --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_text.py @@ -0,0 +1,342 @@ +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from .qwen3_vl_config import Qwen3VLTextConfig + +from mlc_llm.model.qwen3.qwen3_model import Qwen3Attention, Qwen3DecoderLayer, Qwen3Model + +from typing import Optional + +class LlamaRotaryEmbedding(nn.Module): + inv_freq: Tensor + + # fyi config is of type LlamaConfig + def __init__(self, config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings 
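+        # Assumption about the config shape: `rope_parameters` mirrors the HuggingFace
+        # dict, roughly
+        #     {"rope_type": "default", "rope_theta": 500000.0, "mrope_section": [24, 20, 20]}
+        # and only rope_type == "default" is handled by compute_default_rope_parameters below.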
+ + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + + assert self.rope_type == "default", f"Unsupported rope type {self.rope_type}" + inv_freq, self.attention_scaling = self.compute_default_rope_parameters(self.config, device) + + self.original_inv_freq = inv_freq + + # fyi config is of type LlamaConfig + @staticmethod + def compute_default_rope_parameters( + config, + device: Optional[str] = None, + seq_len: Optional[int] = None, + ) -> tuple[Tensor, float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (op.arange(0, dim, 2, dtype="int64").to(device=device, dtype="float32") / dim) + ) + return inv_freq, attention_factor + + def forward(self, x, position_ids): + # TODO: translate from pytorch to tvm + raise NotImplementedError + + # inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + # position_ids_expanded = position_ids[:, None, :].float() + + # device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + # with torch.autocast(device_type=device_type, enabled=False): # Force float32 + # freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + # emb = torch.cat((freqs, freqs), dim=-1) + # cos = emb.cos() * self.attention_scaling + # sin = emb.sin() * self.attention_scaling + + # return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3VLTextRotaryEmbedding(LlamaRotaryEmbedding): + inv_freq: Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3VLTextConfig, device=None): + super().__init__(config, device=device) + + self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20]) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THWTHWTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + + """ + + # TODO: translate from pytorch to tvm + raise NotImplementedError + + # freqs_t = freqs[0] # just overwrite the first dimension T + # for dim, offset in enumerate((1, 2), start=1): # H, W + # length = mrope_section[dim] * 3 + # idx = slice(offset, length, 3) + # freqs_t[..., idx] = freqs[dim, ..., idx] + # return freqs_t + + def forward(self, x, position_ids): + + # TODO: translate from pytorch to tvm + raise NotImplementedError + + # In contrast to other models, Qwen3VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) 
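+        # Hedged illustration of apply_interleaved_mrope above, assuming
+        # mrope_section = [24, 20, 20] (i.e. head_dim // 2 == 64 frequency slots):
+        #   slots 1, 4, 7, ..., 58 take the H frequencies,
+        #   slots 2, 5, 8, ..., 59 take the W frequencies,
+        #   slots 0, 3, 6, ... and 60..63 keep the temporal (T) frequencies,
+        # which is the [THWTHW...TT] layout described in its docstring.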
+ + + # if position_ids.ndim == 2: + # position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + # inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + # position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + # device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + # with torch.autocast(device_type=device_type, enabled=False): # Force float32 + # freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + # freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + # emb = torch.cat((freqs, freqs), dim=-1) + # cos = emb.cos() * self.attention_scaling + # sin = emb.sin() * self.attention_scaling + + # return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3VLTextAttention(Qwen3Attention): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + + # TODO - pytorch qwen3attention takes a layer_idx but mlc doesnt, check if we really need it + # hazardous - qwen3attention expects qwen3config, we are passing qwen3vltextconfig + super().__init__(config) + + # no sliding window in mlc qwen3attention? + # del self.sliding_window + + def forward( + self, + hidden_states: Tensor, + position_embeddings: tuple[Tensor, Tensor], + attention_mask: Optional[Tensor], + past_key_values: Optional["Cache"] = None, + cache_position: Optional[Tensor] = None, + **kwargs, + ) -> tuple[Tensor, Optional[Tensor]]: + # TODO: translate from pytorch to tvm + raise NotImplementedError + + # input_shape = hidden_states.shape[:-1] + # hidden_shape = (*input_shape, -1, self.head_dim) + + # query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + # key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + # value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + # cos, sin = position_embeddings + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # if past_key_values is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # attention_interface: Callable = eager_attention_forward + # if self.config._attn_implementation != "eager": + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + # attn_output, attn_weights = attention_interface( + # self, + # query_states, + # key_states, + # value_states, + # attention_mask, + # dropout=0.0 if not self.training else self.attention_dropout, + # scaling=self.scaling, + # **kwargs, + # ) + + # attn_output = attn_output.reshape(*input_shape, -1).contiguous() + # attn_output = self.o_proj(attn_output) + # return attn_output, attn_weights + + +class Qwen3VLTextDecoderLayer(Qwen3DecoderLayer): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + # TODO - pytorch qwen3attention takes a layer_idx but mlc doesnt, check if we really need it + # hazardous - qwen3attention expects qwen3config, we are passing qwen3vltextconfig + super().__init__(config) + + # no attention_type in mlc qwen3decoderlayer? 
+ #del self.attention_type + + def forward( + self, + hidden_states: Tensor, + position_embeddings: tuple[Tensor, Tensor], + attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + past_key_values: Optional["Cache"] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + # TODO: translate from pytorch to tvm + + # do we even need this class if the forward is just a call to super().forward? + raise NotImplementedError + + # return super().forward( + # hidden_states=hidden_states, + # position_embeddings=position_embeddings, + # attention_mask=attention_mask, + # position_ids=position_ids, + # past_key_values=past_key_values, + # use_cache=use_cache, + # cache_position=cache_position, + # **kwargs, + # ) + + +class Qwen3VLTextModel(Qwen3Model): + config: Qwen3VLTextConfig + + def __init__(self, config: Qwen3VLTextConfig): + + # hazardous - qwen3model expects qwen3config, we are passing qwen3vltextconfig + super().__init__(config) + + # no has_sliding_layers in mlc qwen3model? + + #del self.has_sliding_layers + + def _deepstack_process( + self, hidden_states: Tensor, visual_pos_masks: Tensor, visual_embeds: Tensor + ): + raise NotImplementedError + + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + hidden_states = hidden_states.clone() + local_this = hidden_states[visual_pos_masks, :] + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + def forward(self): + raise NotImplementedError + + # def forward( + # self, + # input_ids: Optional[torch.LongTensor] = None, + # attention_mask: Optional[torch.Tensor] = None, + # position_ids: Optional[torch.LongTensor] = None, + # past_key_values: Optional[Cache] = None, + # inputs_embeds: Optional[torch.FloatTensor] = None, + # use_cache: Optional[bool] = None, + # cache_position: Optional[torch.LongTensor] = None, + # # args for deepstack + # visual_pos_masks: Optional[torch.Tensor] = None, + # deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + # **kwargs: Unpack[FlashAttentionKwargs], + # ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. 
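+        # Hedged illustration: for a text-only prompt all three planes are identical,
+        # e.g. a length-4 sequence yields [0, 1, 2, 3] on each of the t/h/w planes,
+        # while vision tokens get distinct t/h/w indices from get_rope_index.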
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) \ No newline at end of file diff --git a/python/mlc_llm/model/qwen3_vl/qwen3_vl_vision.py b/python/mlc_llm/model/qwen3_vl/qwen3_vl_vision.py new file mode 100644 index 0000000000..da537f8a5a --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/qwen3_vl_vision.py @@ -0,0 +1,272 @@ +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from .qwen3_vl_config import Qwen3VLVisionConfig + + +from mlc_llm.model.qwen3.qwen3_model import ACT2FN +from .qwen2_vl import PatchEmbed, VisionRotaryEmbedding, VisionAttention +from .qwen_2_5_vl import Qwen2_5_VLVisionBlock + +class Qwen3VLVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLVisionPatchEmbed(PatchEmbed): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + # TODO - i am assuming tvm has the same conv3d as pytorch + self.proj = nn.Conv3D(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + +class Qwen3VLVisionRotaryEmbedding(VisionRotaryEmbedding): + pass + + +class Qwen3VLVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = 
use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: Tensor) -> Tensor: + # TODO - translate pytorch to tvm + raise NotImplementedError + # x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + # x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + # return x + + +class Qwen3VLVisionAttention(VisionAttention): + def __init__(self, config: Qwen3VLVisionConfig) -> None: + + # fyi this is weird because the VisionAttention class expects a Qwen2VLVisionConfig param, but hf implementation passes nothing to it? + super().__init__(config) + self.dim = config.hidden_size + + +class Qwen3VLVisionBlock(Qwen2_5_VLVisionBlock): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__(config) + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3VLVisionAttention(config=config) + self.mlp = Qwen3VLVisionMLP(config=config) + + +class Qwen3VLVisionModel(nn.Module): + config: Qwen3VLVisionConfig + + def __init__(self, config, *inputs, **kwargs) -> None: + #super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=True, + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw: Tensor) -> Tensor: + # TODO - translate from pytorch to tvm + raise NotImplementedError + + # merge_size = self.spatial_merge_size + + # max_hw = int(grid_thw[:, 1:].max().item()) + # freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + # device = freq_table.device + + # total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + # pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + # offset = 0 + # for num_frames, height, width in grid_thw: + # merged_h, merged_w = height // merge_size, width // merge_size + + # block_rows = torch.arange(merged_h, device=device) # block row indices + # block_cols = torch.arange(merged_w, device=device) # block col indices + # intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + # intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # # Compute full-resolution positions + # row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + # col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, 
None, :] + + # row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + # col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + # coords = torch.stack((row_idx, col_idx), dim=-1) + + # if num_frames > 1: + # coords = coords.repeat(num_frames, 1) + + # num_tokens = coords.shape[0] + # pos_ids[offset : offset + num_tokens] = coords + # offset += num_tokens + + # embeddings = freq_table[pos_ids] # lookup rotary embeddings + # embeddings = embeddings.flatten(1) + # return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + # TODO - translate from pytorch to tvm + raise NotImplementedError + + # grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + # device = self.pos_embed.weight.device + + # idx_list = [[] for _ in range(4)] + # weight_list = [[] for _ in range(4)] + + # for t, h, w in zip(grid_ts, grid_hs, grid_ws): + # h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + # w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + # h_idxs_floor = h_idxs.int() + # w_idxs_floor = w_idxs.int() + # h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + # w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + # dh = h_idxs - h_idxs_floor + # dw = w_idxs - w_idxs_floor + + # base_h = h_idxs_floor * self.num_grid_per_side + # base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + # indices = [ + # (base_h[None].T + w_idxs_floor[None]).flatten(), + # (base_h[None].T + w_idxs_ceil[None]).flatten(), + # (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + # (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + # ] + + # weights = [ + # ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + # ((1 - dh)[None].T * dw[None]).flatten(), + # (dh[None].T * (1 - dw)[None]).flatten(), + # (dh[None].T * dw[None]).flatten(), + # ] + + # for i in range(4): + # idx_list[i].extend(indices[i].tolist()) + # weight_list[i].extend(weights[i].tolist()) + + # idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + # weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) + # pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + # patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + # patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + # patch_pos_embeds_permute = [] + # merge_size = self.config.spatial_merge_size + # for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + # pos_embed = pos_embed.repeat(t, 1) + # pos_embed = ( + # pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + # .permute(0, 1, 3, 2, 4, 5) + # .flatten(0, 4) + # ) + # patch_pos_embeds_permute.append(pos_embed) + # patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + # return patch_pos_embeds + + def forward(self, hidden_states: Tensor, grid_thw: Tensor, **kwargs) -> Tensor: + """ + Args: + hidden_states (`Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. 
+ """ + # TODO - translate from pytorch to tvm + raise NotImplementedError + + # hidden_states = self.patch_embed(hidden_states) + + # pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + # hidden_states = hidden_states + pos_embeds + + # rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # seq_len, _ = hidden_states.size() + # hidden_states = hidden_states.reshape(seq_len, -1) + # rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + # emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + # position_embeddings = (emb.cos(), emb.sin()) + + # cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + # dim=0, + # # Select dtype based on the following factors: + # # - FA2 requires that cu_seqlens_q must have dtype int32 + # # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # # See https://github.com/huggingface/transformers/pull/34852 for more information + # dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + # ) + # cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # deepstack_feature_lists = [] + # for layer_num, blk in enumerate(self.blocks): + # hidden_states = blk( + # hidden_states, + # cu_seqlens=cu_seqlens, + # position_embeddings=position_embeddings, + # **kwargs, + # ) + # if layer_num in self.deepstack_visual_indexes: + # deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + # hidden_states + # ) + # deepstack_feature_lists.append(deepstack_feature) + + # hidden_states = self.merger(hidden_states) + + # return hidden_states, deepstack_feature_lists \ No newline at end of file diff --git a/python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py b/python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py new file mode 100644 index 0000000000..44c89e8f6a --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py @@ -0,0 +1,67 @@ +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from .qwen2_vl import Qwen2RMSNorm, VisionAttention, Qwen2VLModel +from mlc_llm.model.qwen3.qwen3_model import ACT2FN + +from typing import Optional + +class Qwen2_5_VLVisionAttention(VisionAttention): + def __init__(self, config) -> None: + super().__init__(config) + self.dim = config.hidden_size + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + # TODO - translate pytorch to tvm + pass + # return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen2_5_VLVisionAttention(config=config) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward( + self, + hidden_states: Tensor, + cu_seqlens: Tensor, + rotary_pos_emb: Optional[Tensor] = None, + position_embeddings: Optional[tuple[Tensor, Tensor]] = None, + **kwargs, + ) -> Tensor: + # TODO - translate pytorch to tvm + + pass + # 
hidden_states = hidden_states + self.attn( + # self.norm1(hidden_states), + # cu_seqlens=cu_seqlens, + # rotary_pos_emb=rotary_pos_emb, + # position_embeddings=position_embeddings, + # **kwargs, + # ) + # hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + # return hidden_states + + +class Qwen2_5_VLModel(Qwen2VLModel): + # config + base_model_prefix = "model" + accepts_loss_kwargs = False + + def __init__(self, config): + super().__init__(config) \ No newline at end of file diff --git a/tests/python/model/data/qwen3_vl_2b_instruct/config.json b/tests/python/model/data/qwen3_vl_2b_instruct/config.json new file mode 100644 index 0000000000..0cd8c646eb --- /dev/null +++ b/tests/python/model/data/qwen3_vl_2b_instruct/config.json @@ -0,0 +1,63 @@ +{ + "architectures": [ + "Qwen3VLForConditionalGeneration" + ], + "image_token_id": 151655, + "model_type": "qwen3_vl", + "text_config": { + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 6144, + "max_position_embeddings": 262144, + "model_type": "qwen3_vl_text", + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_interleaved": true, + "mrope_section": [ + 24, + 20, + 20 + ], + "rope_type": "default" + }, + "rope_theta": 5000000, + "tie_word_embeddings": true, + "use_cache": true, + "vocab_size": 151936 + }, + "tie_word_embeddings": true, + "transformers_version": "4.57.0.dev0", + "video_token_id": 151656, + "vision_config": { + "deepstack_visual_indexes": [ + 5, + 11, + 17 + ], + "depth": 24, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1024, + "in_channels": 3, + "initializer_range": 0.02, + "intermediate_size": 4096, + "model_type": "qwen3_vl", + "num_heads": 16, + "num_position_embeddings": 2304, + "out_hidden_size": 2048, + "patch_size": 16, + "spatial_merge_size": 2, + "temporal_patch_size": 2 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652 +} diff --git a/tests/python/model/test_qwen3vl.py b/tests/python/model/test_qwen3vl.py new file mode 100644 index 0000000000..1565c2c71b --- /dev/null +++ b/tests/python/model/test_qwen3vl.py @@ -0,0 +1,48 @@ + +# pylint: disable=invalid-name,missing-docstring +import json +import os +import pytest +from tvm import relax + +from mlc_llm.model import MODELS +from mlc_llm.model.qwen3_vl.qwen3_vl_config import Qwen3VLConfig + +# Directory containing the test config +TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data", "qwen3_vl_2b_instruct") +CONFIG_PATH = os.path.join(TEST_DATA_DIR, "config.json") + +@pytest.mark.skipif(not os.path.exists(CONFIG_PATH), reason="Test config not found") +def test_qwen3vl_creation(): + # Load config from file + with open(CONFIG_PATH, "r", encoding="utf-8") as f: + config_dict = json.load(f) + + # Instantiate Qwen3VLConfig + config = Qwen3VLConfig.from_dict(config_dict) + + # Get model info and class + model_info = MODELS["qwen3_vl"] + model_class = model_info.model + + # Create model + model = model_class(config) + + # Export to TVM to verify structure and creation + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), + ) + + # Basic assertions + import tvm + assert isinstance(mod, tvm.IRModule) + + assert len(named_params) > 0 + + # Verify some parameter shapes/types if needed, or just that it didn't crash + 
print("Qwen3-VL Model created successfully.") + for name, param in named_params: + print(f"{name}: {param.shape} {param.dtype}") + +if __name__ == "__main__": + test_qwen3vl_creation() diff --git a/tests/python/model/test_qwen3vl_loader.py b/tests/python/model/test_qwen3vl_loader.py new file mode 100644 index 0000000000..3e642baefa --- /dev/null +++ b/tests/python/model/test_qwen3vl_loader.py @@ -0,0 +1,41 @@ +import os +import subprocess + +def run_command(command): + print(f"Running: {command}") + process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) + while True: + output = process.stdout.readline() + if output == '' and process.poll() is not None: + break + if output: + print(output.strip()) + rc = process.poll() + return rc + +def verify(): + local_model_path = "../mlc-models/Qwen3-VL-2B-Instruct/" + mlc_model_path = "../mlc-models/mlc-qwen/" + quantization = "q0f16" + conv_template = "qwen3_vl" + + # ensure output dir exists + if not os.path.exists(mlc_model_path): + os.makedirs(mlc_model_path) + + # 1. Gen Config + cmd_gen_config = f"python -m mlc_llm gen_config {local_model_path} --quantization {quantization} --conv-template {conv_template} -o {mlc_model_path}" + if run_command(cmd_gen_config) != 0: + print("Gen Config Failed") + return + + # 2. Convert Weight + cmd_convert = f"mlc_llm convert_weight {local_model_path} --quantization {quantization} -o {mlc_model_path}" + if run_command(cmd_convert) != 0: + print("Convert Weight Failed") + return + + print("Verification Successful!") + +if __name__ == "__main__": + verify() From 71ea960c37c0c46dd000c1376545c756be1285e4 Mon Sep 17 00:00:00 2001 From: Yash Agarwal Date: Mon, 8 Dec 2025 14:17:07 -0500 Subject: [PATCH 3/4] Implement forwars for qwen2vl --- python/mlc_llm/model/qwen3_vl/qwen2_vl.py | 184 ++++++++-------- python/mlc_llm/model/qwen3_vl/rope_utils.py | 66 ++++++ tests/python/model/test_qwen2_vl_others.py | 229 ++++++++++++++++++++ tests/python/model/test_qwen2_vl_vision.py | 153 +++++++++++++ 4 files changed, 534 insertions(+), 98 deletions(-) create mode 100644 python/mlc_llm/model/qwen3_vl/rope_utils.py create mode 100644 tests/python/model/test_qwen2_vl_others.py create mode 100644 tests/python/model/test_qwen2_vl_vision.py diff --git a/python/mlc_llm/model/qwen3_vl/qwen2_vl.py b/python/mlc_llm/model/qwen3_vl/qwen2_vl.py index 6b35fceedd..0186562d24 100644 --- a/python/mlc_llm/model/qwen3_vl/qwen2_vl.py +++ b/python/mlc_llm/model/qwen3_vl/qwen2_vl.py @@ -1,5 +1,17 @@ +import math from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op +from tvm.relax.frontend.nn.op import wrap_nested +from tvm import relax as rx + +def _wrap_op(f, *args): + args = [x._expr if isinstance(x, Tensor) else x for x in args] + return wrap_nested(f(*args), name=f.__name__) + +def op_cos(x): return _wrap_op(rx.op.cos, x) +def op_sin(x): return _wrap_op(rx.op.sin, x) +def op_power(a, b): return _wrap_op(rx.op.power, a, b) + from typing import Optional, Tuple @@ -19,22 +31,17 @@ def __init__( kernel_size = [temporal_patch_size, patch_size, patch_size] - # TODO - i am assuming tvm has the same conv3d as pytorch self.proj = nn.Conv3D(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) def forward(self, hidden_states: Tensor) -> Tensor: - - ''' - TODO - translate pytorch to tvm - ''' - - raise NotImplementedError - # target_dtype = self.proj.weight.dtype - # hidden_states = hidden_states.view( - # -1, self.in_channels, self.temporal_patch_size, 
self.patch_size, self.patch_size - # ) - # hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) - # return hidden_states + hidden_states = op.reshape( + hidden_states, + (-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size) + ) + + hidden_states = self.proj(hidden_states) + hidden_states = op.reshape(hidden_states, (-1, self.embed_dim)) + return hidden_states @@ -48,26 +55,26 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: self.theta = theta def forward(self, seqlen: int) -> Tensor: - # TODO - assuming op.arange syntax == torch.arange, changed dtype to the string literal float32, idk how tvm does dtypes - self.inv_freq = 1.0 / (self.theta ** (op.arange(0, self.dim, 2, dtype="float32") / self.dim)) - pass - - # TODO - translate pytorch to tvm - raise NotImplementedError - # seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - # freqs = torch.outer(seq, self.inv_freq) - # return freqs + theta_const = rx.const(self.theta, "float32") + inv_freq = op.divide(Tensor(_expr=rx.const(1.0, "float32")), op_power(theta_const, (op.arange(0, self.dim, 2, dtype="float32") / self.dim))) + + seq = op.arange(0, seqlen, dtype="float32") + + seq = op.reshape(seq, (seqlen, 1)) + inv_freq = op.reshape(inv_freq, (1, self.dim // 2)) + + freqs = seq * inv_freq + return freqs class VisionAttention(nn.Module): - # fyi this expects a Qwen2VLVisionConfig def __init__(self, config) -> None: super().__init__() self.dim = config.hidden_size self.num_heads = config.num_heads self.head_dim = self.dim // self.num_heads - self.num_key_value_groups = 1 # needed for eager attention + self.num_key_value_groups = 1 self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) self.proj = nn.Linear(self.dim, self.dim) self.scaling = self.head_dim**-0.5 @@ -84,69 +91,60 @@ def forward( **kwargs, ) -> Tensor: - # TODO - translate pytorch to tvm - - raise NotImplementedError - - # seq_length = hidden_states.shape[0] - # query_states, key_states, value_states = ( - # self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) - # ) - # cos, sin = position_embeddings - # query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) - - # query_states = query_states.transpose(0, 1).unsqueeze(0) - # key_states = key_states.transpose(0, 1).unsqueeze(0) - # value_states = value_states.transpose(0, 1).unsqueeze(0) - - # attention_interface: Callable = eager_attention_forward - # if self.config._attn_implementation != "eager": - # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - # if self.config._attn_implementation == "flash_attention_2": - # # Flash Attention 2: Use cu_seqlens for variable length attention - # max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - # attn_output, _ = attention_interface( - # self, - # query_states, - # key_states, - # value_states, - # attention_mask=None, - # scaling=self.scaling, - # dropout=0.0 if not self.training else self.attention_dropout, - # cu_seq_lens_q=cu_seqlens, - # cu_seq_lens_k=cu_seqlens, - # max_length_q=max_seqlen, - # max_length_k=max_seqlen, - # is_causal=False, - # **kwargs, - # ) - # else: - # # Other implementations: Process each chunk separately - # lengths = cu_seqlens[1:] - cu_seqlens[:-1] - # splits = [ - # torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) - # ] - - # attn_outputs = [ - # attention_interface( - # self, - # 
q, - # k, - # v, - # attention_mask=None, - # scaling=self.scaling, - # dropout=0.0 if not self.training else self.attention_dropout, - # is_causal=False, - # **kwargs, - # )[0] - # for q, k, v in zip(*splits) - # ] - # attn_output = torch.cat(attn_outputs, dim=1) - - # attn_output = attn_output.reshape(seq_length, -1).contiguous() - # attn_output = self.proj(attn_output) - # return attn_output + + b, s, _ = hidden_states.shape + qkv = self.qkv(hidden_states) + qkv = op.reshape(qkv, (b, s, 3, self.num_heads, self.head_dim)) + + q, k, v = op.split(qkv, 3, axis=2) + q = op.squeeze(q, axis=2) + k = op.squeeze(k, axis=2) + v = op.squeeze(v, axis=2) + + # Apply RoPE if provided + if rotary_pos_emb is not None: + freqs = rotary_pos_emb + cos = op_cos(freqs) + sin = op_sin(freqs) + + # Reshape for broadcasting: (1, s, 1, d/2) + cos = op.reshape(cos, (1, s, 1, self.head_dim // 2)) + sin = op.reshape(sin, (1, s, 1, self.head_dim // 2)) + + # Use repeat to match head_dim + cos = op.concat([cos, cos], dim=-1) + sin = op.concat([sin, sin], dim=-1) + + def rotate_half(x): + x1, x2 = op.split(x, 2, axis=-1) # split last dim + return op.concat([op.negative(x2), x1], dim=-1) + + q = (q * cos) + (rotate_half(q) * sin) + k = (k * cos) + (rotate_half(k) * sin) + + # Attention + q = op.permute_dims(q, (0, 2, 1, 3)) # (b, h, s, d) + k = op.permute_dims(k, (0, 2, 1, 3)) # (b, h, s, d) + v = op.permute_dims(v, (0, 2, 1, 3)) # (b, h, s, d) + + # k.T -> (b, h, d, s) + k_t = op.permute_dims(k, (0, 1, 3, 2)) + + attn_weights = op.matmul(q, k_t) # (b, h, s, s) + attn_weights = attn_weights * self.scaling + + attn_weights = op.softmax(attn_weights, axis=-1) + + attn_output = op.matmul(attn_weights, v) # (b, h, s, d) + + # Transpose back: (b, s, h, d) + attn_output = op.permute_dims(attn_output, (0, 2, 1, 3)) + + # Reshape to (b, s, dim) + attn_output = op.reshape(attn_output, (b, s, self.dim)) + + attn_output = self.proj(attn_output) + return attn_output class Qwen2RMSNorm(nn.Module): @@ -156,21 +154,11 @@ def __init__(self, hidden_size, eps: float = 1e-6) -> None: """ super().__init__() - # fyi assuming nn.Parameter is a thing - - # do i need to have nn.Parameter, or can it just be op.ones? 
self.weight = nn.Parameter((hidden_size,), dtype="float32") self.variance_epsilon = eps def forward(self, hidden_states: Tensor) -> Tensor: - # TODO - translate pytorch to tvm - - raise NotImplementedError - # input_dtype = hidden_states.dtype - # hidden_states = hidden_states.to("float32") - # variance = hidden_states.pow(2).mean(-1, keepdim=True) - # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - # return self.weight * hidden_states.to(input_dtype) + return op.rms_norm(hidden_states, self.weight, axes=-1, epsilon=self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/python/mlc_llm/model/qwen3_vl/rope_utils.py b/python/mlc_llm/model/qwen3_vl/rope_utils.py new file mode 100644 index 0000000000..1f24e8c986 --- /dev/null +++ b/python/mlc_llm/model/qwen3_vl/rope_utils.py @@ -0,0 +1,66 @@ + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed \ No newline at end of file diff --git a/tests/python/model/test_qwen2_vl_others.py b/tests/python/model/test_qwen2_vl_others.py new file mode 100644 index 0000000000..c3fe0a2e74 --- /dev/null +++ b/tests/python/model/test_qwen2_vl_others.py @@ -0,0 +1,229 @@ + +import torch +import tvm +import tvm.testing +from tvm import runtime +from tvm.relax.frontend import nn +from tvm import relax +import numpy as np + +from mlc_llm.model.qwen3_vl.qwen2_vl import VisionAttention, Qwen2RMSNorm + +def test_qwen2_rms_norm(): + hidden_size = 32 + eps = 1e-6 + + # TVM + class TVMNorm(nn.Module): + def __init__(self): + super().__init__() + self.model = Qwen2RMSNorm(hidden_size, eps=eps) + + def forward(self, x): + return self.model(x) + + # PyTorch + class PyTorchNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + input_data = np.random.randn(2, 10, hidden_size).astype("float32") + + tvm_model = TVMNorm() + # Export + mod, params = tvm_model.export_tvm( + spec={"forward": {"x": nn.spec.Tensor(input_data.shape, "float32")}} + ) + + # Initialize PyTorch + torch_model = PyTorchNorm(hidden_size, eps) + + # Copy weights (ones) + weight_np = torch_model.weight.detach().numpy() + param_dict = dict(params) + param_dict["model.weight"] = runtime.tensor(weight_np) + + # Build + ex = tvm.relax.build(mod, target="llvm") + vm = tvm.relax.VirtualMachine(ex, tvm.cpu()) + + # Run TVM + tvm_out = vm["forward"](runtime.tensor(input_data), *param_dict.values()) + + # Run PyTorch + torch_out = torch_model(torch.from_numpy(input_data)).detach().numpy() + + np.testing.assert_allclose(torch_out, tvm_out.numpy(), rtol=1e-5, atol=1e-5) + print("Qwen2RMSNorm test passed!") + +def test_vision_attention(): + # Mock config + class Config: + hidden_size = 64 + num_heads = 4 + + config = Config() + + # --- Helper Definitions (Exact copy from rope_utils.py for the test) --- + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + 
orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + # ----------------------------------------------------- + + # TVM Wrapper + class TVMAttn(nn.Module): + def __init__(self): + super().__init__() + self.model = VisionAttention(config) + + # We match signature of VisionAttention.forward + def forward(self, x, cu_seqlens, rotary_pos_emb): + return self.model(x, cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + # PyTorch Reference + class PyTorchAttn(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.qkv = torch.nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = torch.nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + + def forward(self, hidden_states, rotary_pos_emb=None): + # rotary_pos_emb here corresponds to 'freqs' passed to the TVM model + # But apply_rotary_pos_emb_vision needs cos, sin. + # So in this reference model, we will compute cos/sin from the passed freqs + # to mimic exactly what happens. + + b, s, _ = hidden_states.shape + qkv = self.qkv(hidden_states) + qkv = qkv.reshape(b, s, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + + # apply_rotary_pos_emb_vision expects q, k in shape (batch, seqlen, num_heads, head_dim) + # which matches current q, k shape. + + if rotary_pos_emb is not None: + # rotary_pos_emb is 'freqs' (seqlen, head_dim) or (seqlen, head_dim/2) ? + # VisionRotaryEmbedding returns freqs as implicit angles. + # In PyTorch VisionRotaryEmbedding (from qwen2_vl.py comments): + # freqs = torch.outer(seq, self.inv_freq) + # Shape: (seqlen, dim/2). + # But typically for RoPE we duplicate them to (seqlen, dim) for cos/sin? + # Or apply_rotary_pos_emb_vision handles it? + + # Look at apply_rotary_pos_emb_vision: + # q_embed = (q * cos) + ... + # q: (b, s, h, d) + # cos: unsqueezed to (s, 1, d). + # So cos must have last dim 'd' (head_dim). + + # But freqs from outer product is head_dim/2. + # So we must repeat/cat to get full head_dim. 
+ + freqs = rotary_pos_emb + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + q = q.transpose(1, 2) # (b, h, s, d) + k = k.transpose(1, 2) # (b, h, s, d) + v = v.transpose(1, 2) # (b, h, s, d) + + attn = (q @ k.transpose(-2, -1)) * self.scaling + attn = attn.softmax(dim=-1) + out = attn @ v + + out = out.transpose(1, 2).reshape(b, s, self.dim) + out = self.proj(out) + return out + + # Input Setup + b, s, d = 1, 10, config.hidden_size + head_dim = d // config.num_heads + input_data = np.random.randn(b, s, d).astype("float32") + cu_seqlens_np = np.array([0, s], dtype="int32") + + # Generate constant freqs for testing RoPE + # shape: (s, head_dim // 2) + freqs_np = np.random.randn(s, head_dim // 2).astype("float32") + + tvm_model = TVMAttn() + torch_model = PyTorchAttn(config) + + # Export TVM + mod, params = tvm_model.export_tvm( + spec={ + "forward": { + "x": nn.spec.Tensor([b, s, d], "float32"), + "cu_seqlens": nn.spec.Tensor([2], "int32"), + "rotary_pos_emb": nn.spec.Tensor([s, head_dim // 2], "float32") + } + } + ) + + # Sync weights + param_dict = dict(params) + qkv_w = torch_model.qkv.weight.detach().numpy() + qkv_b = torch_model.qkv.bias.detach().numpy() + proj_w = torch_model.proj.weight.detach().numpy() + proj_b = torch_model.proj.bias.detach().numpy() + + param_dict["model.qkv.weight"] = runtime.tensor(qkv_w) + param_dict["model.qkv.bias"] = runtime.tensor(qkv_b) + param_dict["model.proj.weight"] = runtime.tensor(proj_w) + param_dict["model.proj.bias"] = runtime.tensor(proj_b) + + # Build + ex = tvm.relax.build(mod, target="llvm") + vm = tvm.relax.VirtualMachine(ex, tvm.cpu()) + + # Run PyTorch + torch_in = torch.from_numpy(input_data) + rotary_in = torch.from_numpy(freqs_np) + with torch.no_grad(): + torch_out = torch_model(torch_in, rotary_pos_emb=rotary_in).numpy() + + # Run TVM + tvm_out = vm["forward"]( + runtime.tensor(input_data), + runtime.tensor(cu_seqlens_np), + runtime.tensor(freqs_np), + *param_dict.values() + ) + + np.testing.assert_allclose(torch_out, tvm_out.numpy(), rtol=1e-5, atol=1e-5) + print("VisionAttention test passed!") + + +if __name__ == "__main__": + test_qwen2_rms_norm() + test_vision_attention() diff --git a/tests/python/model/test_qwen2_vl_vision.py b/tests/python/model/test_qwen2_vl_vision.py new file mode 100644 index 0000000000..4c284bc6af --- /dev/null +++ b/tests/python/model/test_qwen2_vl_vision.py @@ -0,0 +1,153 @@ + +import torch +import tvm +import tvm.testing +from tvm import runtime +from tvm.relax.frontend import nn +from tvm import relax +import numpy as np + +from mlc_llm.model.qwen3_vl.qwen2_vl import PatchEmbed, VisionRotaryEmbedding + +def test_patch_embed(): + # Configuration + patch_size = 14 + temporal_patch_size = 2 + in_channels = 3 + embed_dim = 32 # Small dim for testing + + # TVM Model + class TVMPatchEmbed(nn.Module): + def __init__(self): + super().__init__() + self.model = PatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim + ) + + def forward(self, x): + return self.model(x) + + # PyTorch Reference + class PyTorchPatchEmbed(torch.nn.Module): + def __init__(self): + super().__init__() + target_kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = torch.nn.Conv3d( + in_channels, + embed_dim, + kernel_size=target_kernel_size, + stride=target_kernel_size, + bias=False + ) + self.in_channels = in_channels + self.temporal_patch_size = 
temporal_patch_size + self.patch_size = patch_size + self.embed_dim = embed_dim + + def forward(self, hidden_states): + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + # Inputs + num_patches = 5 + input_flat_size = in_channels * temporal_patch_size * patch_size * patch_size + input_data_np = np.random.randn(num_patches, input_flat_size).astype("float32") + + # Setup PyTorch + torch_model = PyTorchPatchEmbed() + torch_input = torch.from_numpy(input_data_np) + + # Setup TVM + tvm_model = TVMPatchEmbed() + + # Run PyTorch + with torch.no_grad(): + torch_output = torch_model(torch_input).numpy() + + # Run TVM + mod, params = tvm_model.export_tvm( + spec={ + "forward": { + "x": nn.spec.Tensor([num_patches, input_flat_size], "float32") + } + } + ) + + param_dict = dict(params) + w_key = list(param_dict.keys())[0] + weight_np = torch_model.proj.weight.detach().numpy() + param_dict[w_key] = runtime.tensor(weight_np) + + # Build + ex = tvm.relax.build(mod, target="llvm") + vm = tvm.relax.VirtualMachine(ex, tvm.cpu()) + + # Run + tvm_output = vm["forward"](runtime.tensor(input_data_np), *param_dict.values()) + + # Verify + np.testing.assert_allclose(torch_output, tvm_output.numpy(), rtol=1e-5, atol=1e-5) + print("PatchEmbed test passed!") + + +def test_vision_rotary_embedding(): + dim = 32 + theta = 10000.0 + seqlen = 10 + + # TVM + class TVMRotary(nn.Module): + def __init__(self): + super().__init__() + self.model = VisionRotaryEmbedding(dim, theta) + + def forward(self, seqlen: int): + return self.model(seqlen) + + # PyTorch + class PyTorchRotary(torch.nn.Module): + def __init__(self, dim, theta=10000.0): + super().__init__() + self.dim = dim + self.theta = theta + self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + + def forward(self, seqlen): + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + torch_model = PyTorchRotary(dim, theta) + torch_out = torch_model(seqlen).numpy() + + tvm_model = TVMRotary() + + # We pass seqlen as int. + mod, params = tvm_model.export_tvm( + spec={ + "forward": { + "seqlen": int + } + } + ) + + ex = tvm.relax.build(mod, target="llvm") + vm = tvm.relax.VirtualMachine(ex, tvm.cpu()) + + # Run. + # For int spec, we generally pass ShapeTuple for the VM to unpack if it expects scalar args mapped from shape. 
+ tvm_out = vm["forward"](tvm.runtime.ShapeTuple([seqlen])) + + np.testing.assert_allclose(torch_out, tvm_out.numpy(), rtol=1e-5, atol=1e-5) + print("VisionRotaryEmbedding test passed!") + +if __name__ == "__main__": + test_patch_embed() + test_vision_rotary_embedding() From 6b34e75fbccea68884cabe1865a415e9920e3592 Mon Sep 17 00:00:00 2001 From: Yash Agarwal Date: Mon, 8 Dec 2025 14:23:14 -0500 Subject: [PATCH 4/4] Implement forward for qwen2.5vl --- python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py | 27 +-- tests/python/model/test_qwen2_5_vl.py | 214 +++++++++++++++++++ 2 files changed, 225 insertions(+), 16 deletions(-) create mode 100644 tests/python/model/test_qwen2_5_vl.py diff --git a/python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py b/python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py index 44c89e8f6a..5a0f34daee 100644 --- a/python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py +++ b/python/mlc_llm/model/qwen3_vl/qwen_2_5_vl.py @@ -22,10 +22,8 @@ def __init__(self, config, bias: bool = False): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - # TODO - translate pytorch to tvm - pass - # return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, hidden_state: Tensor) -> Tensor: + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) class Qwen2_5_VLVisionBlock(nn.Module): @@ -44,18 +42,15 @@ def forward( position_embeddings: Optional[tuple[Tensor, Tensor]] = None, **kwargs, ) -> Tensor: - # TODO - translate pytorch to tvm - - pass - # hidden_states = hidden_states + self.attn( - # self.norm1(hidden_states), - # cu_seqlens=cu_seqlens, - # rotary_pos_emb=rotary_pos_emb, - # position_embeddings=position_embeddings, - # **kwargs, - # ) - # hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) - # return hidden_states + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states class Qwen2_5_VLModel(Qwen2VLModel): diff --git a/tests/python/model/test_qwen2_5_vl.py b/tests/python/model/test_qwen2_5_vl.py new file mode 100644 index 0000000000..aae6ba6754 --- /dev/null +++ b/tests/python/model/test_qwen2_5_vl.py @@ -0,0 +1,214 @@ + +import torch +import tvm +import tvm.testing +from tvm import runtime +from tvm.relax.frontend import nn +from tvm import relax +import numpy as np + +from mlc_llm.model.qwen3_vl.qwen_2_5_vl import Qwen2_5_VLMLP, Qwen2_5_VLVisionBlock + +# Helper for RoPE (needed for Attention inside Block) +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb_vision(q, k, cos, sin): + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype) + +class Config: + hidden_size = 64 + intermediate_size = 128 + num_heads = 4 + hidden_act = "silu" + +config = Config() + +def test_mlp(): + # TVM + class TVMMLP(nn.Module): + def __init__(self): + super().__init__() + self.model = Qwen2_5_VLMLP(config, bias=True) + + def forward(self, x): + return 
self.model(x) + + # PyTorch + class PyTorchMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = torch.nn.Linear(config.hidden_size, config.intermediate_size, bias=True) + self.up_proj = torch.nn.Linear(config.hidden_size, config.intermediate_size, bias=True) + self.down_proj = torch.nn.Linear(config.intermediate_size, config.hidden_size, bias=True) + self.act_fn = torch.nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + # Inputs + b, s, d = 1, 10, config.hidden_size + input_data = np.random.randn(b, s, d).astype("float32") + + tvm_model = TVMMLP() + torch_model = PyTorchMLP() + + mod, params = tvm_model.export_tvm( + spec={"forward": {"x": nn.spec.Tensor([b, s, d], "float32")}} + ) + + # Sync weights + param_dict = dict(params) + param_dict["model.gate_proj.weight"] = runtime.tensor(torch_model.gate_proj.weight.detach().numpy()) + param_dict["model.gate_proj.bias"] = runtime.tensor(torch_model.gate_proj.bias.detach().numpy()) + param_dict["model.up_proj.weight"] = runtime.tensor(torch_model.up_proj.weight.detach().numpy()) + param_dict["model.up_proj.bias"] = runtime.tensor(torch_model.up_proj.bias.detach().numpy()) + param_dict["model.down_proj.weight"] = runtime.tensor(torch_model.down_proj.weight.detach().numpy()) + param_dict["model.down_proj.bias"] = runtime.tensor(torch_model.down_proj.bias.detach().numpy()) + + ex = tvm.relax.build(mod, target="llvm") + vm = tvm.relax.VirtualMachine(ex, tvm.cpu()) + + tvm_out = vm["forward"](runtime.tensor(input_data), *param_dict.values()) + torch_out = torch_model(torch.from_numpy(input_data)).detach().numpy() + + np.testing.assert_allclose(torch_out, tvm_out.numpy(), rtol=1e-5, atol=1e-5) + print("Qwen2_5_VLMLP test passed!") + +def test_vision_block(): + # TVM + class TVMBlock(nn.Module): + def __init__(self): + super().__init__() + self.model = Qwen2_5_VLVisionBlock(config) + + def forward(self, x, cu_seqlens, rotary_pos_emb): + return self.model(x, cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + # PyTorch Reference + class PyTorchBlock(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm1 = torch.nn.RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = torch.nn.RMSNorm(config.hidden_size, eps=1e-6) + + # Attn + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.qkv = torch.nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = torch.nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + + # MLP + self.gate_proj = torch.nn.Linear(config.hidden_size, config.intermediate_size, bias=True) + self.up_proj = torch.nn.Linear(config.hidden_size, config.intermediate_size, bias=True) + self.down_proj = torch.nn.Linear(config.intermediate_size, config.hidden_size, bias=True) + self.act_fn = torch.nn.SiLU() + + def forward(self, hidden_states, rotary_pos_emb=None): + # Attention block + residual = hidden_states + hidden_states = self.norm1(hidden_states) + + b, s, _ = hidden_states.shape + qkv = self.qkv(hidden_states) + qkv = qkv.reshape(b, s, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + + if rotary_pos_emb is not None: + freqs = rotary_pos_emb + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn = (q @ k.transpose(-2, -1)) * self.scaling + attn = attn.softmax(dim=-1) + out = attn @ v + 
out = out.transpose(1, 2).reshape(b, s, self.dim) + out = self.proj(out) + + hidden_states = residual + out + + # MLP Block + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + hidden_states = residual + hidden_states + + return hidden_states + + # Inputs + b, s, d = 1, 10, config.hidden_size + head_dim = d // config.num_heads + input_data = np.random.randn(b, s, d).astype("float32") + cu_seqlens_np = np.array([0, s], dtype="int32") + freqs_np = np.random.randn(s, head_dim // 2).astype("float32") + + tvm_model = TVMBlock() + torch_model = PyTorchBlock() + + mod, params = tvm_model.export_tvm( + spec={ + "forward": { + "x": nn.spec.Tensor([b, s, d], "float32"), + "cu_seqlens": nn.spec.Tensor([2], "int32"), + "rotary_pos_emb": nn.spec.Tensor([s, head_dim // 2], "float32") + } + } + ) + + # Sync weights + param_dict = dict(params) + + def sync_linear(tvm_name, torch_module): + param_dict[f"{tvm_name}.weight"] = runtime.tensor(torch_module.weight.detach().numpy()) + if torch_module.bias is not None: + param_dict[f"{tvm_name}.bias"] = runtime.tensor(torch_module.bias.detach().numpy()) + + + # Wait, Qwen2RMSNorm uses nn.Parameter name 'weight'. + param_dict["model.norm1.weight"] = runtime.tensor(torch_model.norm1.weight.detach().numpy()) + param_dict["model.norm2.weight"] = runtime.tensor(torch_model.norm2.weight.detach().numpy()) + + sync_linear("model.attn.qkv", torch_model.qkv) + sync_linear("model.attn.proj", torch_model.proj) + + sync_linear("model.mlp.gate_proj", torch_model.gate_proj) + sync_linear("model.mlp.up_proj", torch_model.up_proj) + sync_linear("model.mlp.down_proj", torch_model.down_proj) + + ex = tvm.relax.build(mod, target="llvm") + vm = tvm.relax.VirtualMachine(ex, tvm.cpu()) + + tvm_out = vm["forward"]( + runtime.tensor(input_data), + runtime.tensor(cu_seqlens_np), + runtime.tensor(freqs_np), + *param_dict.values() + ) + + torch_in = torch.from_numpy(input_data) + rotary_in = torch.from_numpy(freqs_np) + with torch.no_grad(): + torch_out = torch_model(torch_in, rotary_pos_emb=rotary_in).numpy() + + np.testing.assert_allclose(torch_out, tvm_out.numpy(), rtol=1e-5, atol=1e-5) + print("Qwen2_5_VLVisionBlock test passed!") + +if __name__ == "__main__": + test_mlp() + test_vision_block()
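
Note (not part of the patches above): Qwen3VLVisionPatchMerger.forward in qwen3_vl_vision.py is still a `raise NotImplementedError` stub with the PyTorch reference left as comments. As a hedged sketch only — assuming the same tvm.relax.frontend.nn API the series already uses (op.reshape with -1, nn.LayerNorm, nn.GELU, nn.Linear), and using a hypothetical free-function name `patch_merger_forward` purely for illustration — the commented-out PyTorch could translate roughly like this; it is untested against real vision inputs.

# Hypothetical sketch, NOT part of the patch: a possible TVM translation of
# Qwen3VLVisionPatchMerger.forward, mirroring the commented-out PyTorch reference.
# It only touches attributes the patch already defines on the merger
# (norm, linear_fc1, act_fn, linear_fc2, hidden_size, use_postshuffle_norm).
from tvm.relax.frontend.nn import Tensor, op


def patch_merger_forward(merger, x: Tensor) -> Tensor:
    if merger.use_postshuffle_norm:
        # post-shuffle norm: flatten to (num_merged_tokens, merged_hidden) first,
        # then normalize over the merged hidden size
        # (the PyTorch reference's trailing .view(-1, hidden_size) is a no-op here)
        x = merger.norm(op.reshape(x, (-1, merger.hidden_size)))
    else:
        # pre-shuffle norm: normalize in the per-patch hidden size,
        # then flatten to the merged hidden size
        x = op.reshape(merger.norm(x), (-1, merger.hidden_size))
    return merger.linear_fc2(merger.act_fn(merger.linear_fc1(x)))

If adopted, the body would replace the `raise NotImplementedError` inside Qwen3VLVisionPatchMerger.forward, and a small PyTorch-vs-TVM comparison in the style of test_qwen2_5_vl.py could pin down its numerics.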