Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions QEfficient/transformers/models/internvl/modeling_internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def __init__(self, model):

def forward(self, pixel_values):
vision_embeds = self.model.extract_feature(pixel_values)
# Reshape from [num_patches, 256, hidden_dim] -> [1, num_patches*256, head_dim]
# To enable prefill chunking for num_patches > 1
vision_embeds = vision_embeds.reshape(1, -1, vision_embeds.shape[-1])
return vision_embeds


Expand All @@ -35,14 +38,22 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
image_input_embeds = input_embeds.reshape(B * N, C)
input_embeds = input_embeds.reshape(B * N, C)
image_input_ids = input_ids.reshape(B * N)
selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
# TODO: Find a better way to decide which token value to use
image_context_token = (
constants.INTERN_3_5_IMG_CONTEXT_TOKEN
if "Qwen3" in self.config.architectures[0]
else constants.INTERN_IMG_CONTEXT_TOKEN
)
selected = image_input_ids == image_context_token
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
inputs_embeds = inputs_embeds.reshape(B, N, C)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
)
Expand Down Expand Up @@ -84,12 +95,13 @@ def get_specializations(
raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")

per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
vision_size = int(num_patches * per_patch_embed_size)
vision_size = int(batch_size * num_patches * per_patch_embed_size)
vision = [
{
"batch_size": batch_size,
"num_patches": num_patches,
"img_size": img_size,
"batched_num_patches": batch_size * num_patches,
}
]
lang = [
Expand Down Expand Up @@ -126,8 +138,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
lang_dynamic_axes = {}
lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "vision_size"}
vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
lang_dynamic_axes["vision_embeds"] = {1: "vision_size"}
vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"}

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
for i in range(self.language_model.config.num_hidden_layers):
Expand Down Expand Up @@ -182,16 +194,16 @@ def get_dummy_inputs(self, kv_offload: bool = False):
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["vision_embeds"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
computed_feature_size,
1,
computed_feature_size * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
self.language_model.config.hidden_size,
)
inputs_shapes["position_ids"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
inputs_shapes["pixel_values"] = (
constants.INTERN_NUM_PATCHES,
constants.INTERN_NUM_PATCHES * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
constants.INTERN_NUM_CHANNELS,
img_size,
img_size,
Expand Down Expand Up @@ -237,14 +249,22 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val
vision_embeds = self.extract_feature(pixel_values)
B, N, C = input_embeds.shape
image_input_embeds = input_embeds.reshape(B * N, C)
input_embeds = input_embeds.reshape(B * N, C)
image_input_ids = input_ids.reshape(B * N)
selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
# TODO: Find a better way to decide which token value to use
image_context_token = (
constants.INTERN_3_5_IMG_CONTEXT_TOKEN
if "Qwen3" in self.config.architectures[0]
else constants.INTERN_IMG_CONTEXT_TOKEN
)
selected = image_input_ids == image_context_token
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
inputs_embeds = inputs_embeds.reshape(B, N, C)
outputs = self.language_model(
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
)
Expand Down
20 changes: 20 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@
Qwen2Model,
Qwen2RMSNorm,
)
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3Attention,
Qwen3DecoderLayer,
Qwen3ForCausalLM,
Qwen3Model,
Qwen3RMSNorm,
)
from transformers.models.starcoder2.modeling_starcoder2 import (
Starcoder2Attention,
Starcoder2DecoderLayer,
Expand Down Expand Up @@ -303,6 +310,12 @@
QEffQwen2ForCausalLM,
QEffQwen2Model,
)
from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
QEffQwen3Attention,
QEffQwen3DecoderLayer,
QEffQwen3ForCausalLM,
QEffQwen3Model,
)
from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import (
QEffStarcoder2Attention,
QEFFStarcoder2DecoderLayer,
Expand Down Expand Up @@ -335,6 +348,7 @@ class CustomOpsTransform(ModuleMappingTransform):
MixtralRMSNorm: CustomRMSNormAIC,
Phi3RMSNorm: CustomRMSNormAIC,
Qwen2RMSNorm: CustomRMSNormAIC,
Qwen3RMSNorm: CustomRMSNormAIC,
MllamaTextRMSNorm: CustomRMSNormAIC,
GraniteRMSNorm: CustomRMSNormAIC,
GraniteMoeRMSNorm: CustomRMSNormAIC,
Expand Down Expand Up @@ -452,6 +466,11 @@ class KVCacheTransform(ModuleMappingTransform):
Qwen2DecoderLayer: QEffQwen2DecoderLayer,
Qwen2Model: QEffQwen2Model,
Qwen2ForCausalLM: QEffQwen2ForCausalLM,
# Qwen3
Qwen3Attention: QEffQwen3Attention,
Qwen3DecoderLayer: QEffQwen3DecoderLayer,
Qwen3Model: QEffQwen3Model,
Qwen3ForCausalLM: QEffQwen3ForCausalLM,
# Starcoder2
Starcoder2Attention: QEffStarcoder2Attention,
Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
Expand Down Expand Up @@ -498,6 +517,7 @@ class SpDTransform:
# Llama
QEffLlamaForCausalLM,
QEffQwen2ForCausalLM,
QEffQwen3ForCausalLM,
}

@classmethod
Expand Down
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/qwen3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
Loading