diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py
index b6fb9fd38..38d0fe167 100644
--- a/QEfficient/transformers/models/internvl/modeling_internvl.py
+++ b/QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -21,6 +21,9 @@ def __init__(self, model):
 
     def forward(self, pixel_values):
         vision_embeds = self.model.extract_feature(pixel_values)
+        # Reshape from [num_patches, 256, hidden_dim] -> [1, num_patches*256, hidden_dim]
+        # to enable prefill chunking for num_patches > 1
+        vision_embeds = vision_embeds.reshape(1, -1, vision_embeds.shape[-1])
         return vision_embeds
 
@@ -35,14 +38,22 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
+        input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
-        selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
+        # TODO: Find a better way to decide which token value to use
+        image_context_token = (
+            constants.INTERN_3_5_IMG_CONTEXT_TOKEN
+            if "Qwen3" in self.config.architectures[0]
+            else constants.INTERN_IMG_CONTEXT_TOKEN
+        )
+        selected = image_input_ids == image_context_token
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
@@ -84,12 +95,13 @@ def get_specializations(
             raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")
 
         per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
-        vision_size = int(num_patches * per_patch_embed_size)
+        vision_size = int(batch_size * num_patches * per_patch_embed_size)
         vision = [
             {
                 "batch_size": batch_size,
                 "num_patches": num_patches,
                 "img_size": img_size,
+                "batched_num_patches": batch_size * num_patches,
             }
         ]
         lang = [
@@ -126,8 +138,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "vision_size"}
-        vision_dynamic_axes["pixel_values"] = {0: "num_patches", 2: "img_size", 3: "img_size"}
+        lang_dynamic_axes["vision_embeds"] = {1: "vision_size"}
+        vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"}
 
         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
@@ -182,8 +194,8 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
-            constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
-            computed_feature_size,
+            1,
+            computed_feature_size * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             self.language_model.config.hidden_size,
         )
         inputs_shapes["position_ids"] = (
@@ -191,7 +203,7 @@
             constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
         inputs_shapes["pixel_values"] = (
-            constants.INTERN_NUM_PATCHES,
+            constants.INTERN_NUM_PATCHES * constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
             constants.INTERN_NUM_CHANNELS,
             img_size,
             img_size,
         )
@@ -237,14 +249,22 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val
         vision_embeds = self.extract_feature(pixel_values)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
+        input_embeds = input_embeds.reshape(B * N, C)
         image_input_ids = input_ids.reshape(B * N)
-        selected = image_input_ids == constants.INTERN_IMG_CONTEXT_TOKEN
+        # TODO: Find a better way to decide which token value to use
+        image_context_token = (
+            constants.INTERN_3_5_IMG_CONTEXT_TOKEN
+            if "Qwen3" in self.config.architectures[0]
+            else constants.INTERN_IMG_CONTEXT_TOKEN
+        )
+        selected = image_input_ids == image_context_token
         indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
         indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
         indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
         image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
         image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
+        inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
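# --- Editor's illustration (not part of the patch) ---------------------------------------
# A minimal sketch of the reshape contract the InternVL changes above rely on: the vision
# encoder output of shape [num_patches, 256, hidden_dim] is flattened into a single batch row
# so the language model can consume it in prefill chunks even when num_patches > 1.
# The sizes below are illustrative assumptions.
import torch

num_patches, tokens_per_patch, hidden_dim = 2, 256, 64
vision_embeds = torch.randn(num_patches, tokens_per_patch, hidden_dim)

# Same reshape as in the vision wrapper's forward above: keep the last (hidden) axis,
# fold num_patches and tokens_per_patch into axis 1.
vision_embeds = vision_embeds.reshape(1, -1, vision_embeds.shape[-1])
assert vision_embeds.shape == (1, num_patches * tokens_per_patch, hidden_dim)
# -----------------------------------------------------------------------------------------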
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index ca74c0ddd..9c4e93778 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -136,6 +136,13 @@
     Qwen2Model,
     Qwen2RMSNorm,
 )
+from transformers.models.qwen3.modeling_qwen3 import (
+    Qwen3Attention,
+    Qwen3DecoderLayer,
+    Qwen3ForCausalLM,
+    Qwen3Model,
+    Qwen3RMSNorm,
+)
 from transformers.models.starcoder2.modeling_starcoder2 import (
     Starcoder2Attention,
     Starcoder2DecoderLayer,
@@ -303,6 +310,12 @@
     QEffQwen2ForCausalLM,
     QEffQwen2Model,
 )
+from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
+    QEffQwen3Attention,
+    QEffQwen3DecoderLayer,
+    QEffQwen3ForCausalLM,
+    QEffQwen3Model,
+)
 from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import (
     QEffStarcoder2Attention,
     QEFFStarcoder2DecoderLayer,
@@ -335,6 +348,7 @@ class CustomOpsTransform(ModuleMappingTransform):
         MixtralRMSNorm: CustomRMSNormAIC,
         Phi3RMSNorm: CustomRMSNormAIC,
         Qwen2RMSNorm: CustomRMSNormAIC,
+        Qwen3RMSNorm: CustomRMSNormAIC,
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
         GraniteMoeRMSNorm: CustomRMSNormAIC,
@@ -452,6 +466,11 @@ class KVCacheTransform(ModuleMappingTransform):
         Qwen2DecoderLayer: QEffQwen2DecoderLayer,
         Qwen2Model: QEffQwen2Model,
         Qwen2ForCausalLM: QEffQwen2ForCausalLM,
+        # Qwen3
+        Qwen3Attention: QEffQwen3Attention,
+        Qwen3DecoderLayer: QEffQwen3DecoderLayer,
+        Qwen3Model: QEffQwen3Model,
+        Qwen3ForCausalLM: QEffQwen3ForCausalLM,
         # Starcoder2
         Starcoder2Attention: QEffStarcoder2Attention,
         Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
@@ -498,6 +517,7 @@ class SpDTransform:
         # Llama
         QEffLlamaForCausalLM,
         QEffQwen2ForCausalLM,
+        QEffQwen3ForCausalLM,
     }
 
     @classmethod
diff --git a/QEfficient/transformers/models/qwen3/__init__.py b/QEfficient/transformers/models/qwen3/__init__.py
new file mode 100644
index 000000000..d647b73a6
--- /dev/null
+++ b/QEfficient/transformers/models/qwen3/__init__.py
@@ -0,0 +1,6 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py
new file mode 100644
index 000000000..03fd478de
--- /dev/null
+++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py
@@ -0,0 +1,427 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""PyTorch Qwen3 model."""
+
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
+from transformers.models.qwen3.modeling_qwen3 import (
+    Qwen3Attention,
+    Qwen3Config,
+    Qwen3DecoderLayer,
+    Qwen3ForCausalLM,
+    Qwen3Model,
+    Qwen3RotaryEmbedding,
+    repeat_kv,
+    rotate_half,
+)
+
+from QEfficient.transformers.cache_utils import QEffDynamicCache
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
+
+
+# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding, but kept here to follow the transformers per-model layout
+class QEffQwen3RotaryEmbedding(Qwen3RotaryEmbedding):
+    """
+    Copied from Qwen3RotaryEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
+    The only differences are:
+    - Add static sin/cos computations.
+    """
+
+    def __init__(self, config: Qwen3Config, device=None):
+        super().__init__(config=config)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+            self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+        )
+
+
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offset position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class QEffQwen3Attention(Qwen3Attention):
+    """
+    Copied from Qwen3Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
+    The only differences are:
+    - add new args position idx for the cache_kwargs for kv retention
+    """
+
+    def __qeff_init__(self):
+        self.rotary_emb = QEffQwen3RotaryEmbedding(config=self.config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+
+        kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        attention_interface: Callable = eager_attention_forward
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights, past_key_value
+
+
+class QEffQwen3DecoderLayer(Qwen3DecoderLayer):
+    """
+    Copied from Qwen3DecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
+    The only differences are:
+    - add new args position idx for the cache_kwargs for kv retention
+    - update the hidden_states, and fix for onnx model
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            batch_index=batch_index,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class QEffQwen3Model(Qwen3Model):
+    """
+    Copied from Qwen3Model: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
+    The only differences are:
+    - add new args position idx for the cache_kwargs for kv retention
+    - update causal attention mask
+    """
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens
+        causal_mask = _create_causal_mask(
+            position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window
+        )
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                batch_index=batch_index,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if return_legacy_cache:
+            past_key_values = past_key_values.to_legacy_cache()
+
+        output = BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values if use_cache else None,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+        return output if return_dict else output.to_tuple()
+
+
+class QEffQwen3ForCausalLM(Qwen3ForCausalLM):
+    """
+    Copied from Qwen3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
+    The only differences are:
+    - add new args position idx for the cache_kwargs for kv retention
+    - update the hidden_states, and fix for onnx model
+    """
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        batch_index: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        # Cast to INT32 to avoid issue while running in ONNXRT
+        logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
+        hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
+
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        return CausalLMOutputWithPast(
+            loss=None,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
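# --- Editor's illustration (not part of the patch) ---------------------------------------
# With the Qwen3 mappings registered in pytorch_transforms.py, a Qwen3 checkpoint flows through
# the same high-level API the repository already uses for other decoder-only models. A minimal
# usage sketch; the compile arguments below are illustrative assumptions, not recommended settings.
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = QEFFAutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")  # KV-cache / custom-op transforms applied on load
model.compile(num_cores=16, num_devices=1, prefill_seq_len=128, ctx_len=4096)
model.generate(prompts=["Hello"], tokenizer=tokenizer)
# -----------------------------------------------------------------------------------------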
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 8228b7c0e..f8552b169 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -97,7 +97,10 @@ def get_models_dir():
 INTERN_CTX_LEN = 4096
 INTERN_PREFILL_SEQ_LEN = INTERN_CTX_LEN - 256  # 4096-256
 INTERN_NUM_CHANNELS = 3
+
 INTERN_IMG_CONTEXT_TOKEN = 151667
+# Specific to the InternVL3_5 series; this token does not work for the InternVL2_5 series
+INTERN_3_5_IMG_CONTEXT_TOKEN = 151671
 
 # Granite Vision Constants
 # Fixing the feature size with reference to ibm-granite/granite-vision-3.2-2b
diff --git a/QEfficient/utils/test_utils.py b/QEfficient/utils/test_utils.py
index 3d70ac4f3..eb68d0d41 100644
--- a/QEfficient/utils/test_utils.py
+++ b/QEfficient/utils/test_utils.py
@@ -132,24 +132,29 @@ def __call__(
         IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
         verbose=False,
     ) -> str:
-        if history is None and pixel_values is not None and "<image>" not in question:
-            question = "<image>\n" + question
         if num_patches_list is None:
             num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
         assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
         img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
         self.model.img_context_token_id = img_context_token_id
-        messages.append([roles[0], question])
-        messages.append([roles[1], None])
-        query = self.get_prompt(messages)
         if verbose and pixel_values is not None:
             image_bs = pixel_values.shape[0]
             print(f"dynamic ViT batch size: {image_bs}")
 
-        for num_patches in num_patches_list:
+        queries = []
+        for idx, num_patches in enumerate(num_patches_list):
+            query = question[idx]
+            if history is None and pixel_values is not None and "<image>" not in query:
+                query = "<image>\n" + query
+            template = messages.copy()
+            template.append([roles[0], query])
+            template.append([roles[1], None])
+            query = self.get_prompt(template)
+
             image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
             query = query.replace("<image>", image_tokens, 1)
-        return query
+            queries.append(query)
+        return queries
 
 
 class ModelConfig:
diff --git a/examples/intern_example/internvl_inference.py b/examples/intern_example/internvl_inference.py
index eba8c10d5..3d8027ba4 100644
--- a/examples/intern_example/internvl_inference.py
+++ b/examples/intern_example/internvl_inference.py
@@ -10,168 +10,24 @@
 import requests
 import torch
-import torch.nn as nn
-import torchvision.transforms as T
 from PIL import Image
-from torchvision.transforms.functional import InterpolationMode
 from transformers import AutoTokenizer, TextStreamer
 
 from QEfficient import QEFFAutoModelForCausalLM
-from QEfficient.utils.logging_utils import logger
-
-IMAGENET_MEAN = (0.485, 0.456, 0.406)
-IMAGENET_STD = (0.229, 0.224, 0.225)
-
-
-# Process the input messages to generate prompt for the model.
-def get_prompt(messages) -> str:
-    """Get the prompt for generation."""
-    ## Chat template used for InternVL
-    system_prompt = "<|im_start|>system\n你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
-    sep = "<|im_end|>\n"
-
-    ret = system_prompt + sep
-    for role, message in messages:
-        if message:
-            if type(message) is tuple:
-                message, _, _ = message
-            ret += role + message + sep
-        else:
-            ret += role
-    return ret
-
-
-# Processor class for InternVL models
-class InternProcessor:
-    """
-    InternVL model only has an AutoTokenizer so this class performs the processing tasks similar to an AutoProcessor.
-    The methods used here are borrowed from the original InternVL modelling files.
-    "https://huggingface.co/OpenGVLab/InternVL2_5-1B/"
-    """
-
-    def __init__(self, model: nn.Module, tokenizer):
-        self.model = model
-        image_size = self.model.config.force_image_size or self.model.config.vision_config.image_size
-        patch_size = self.model.config.vision_config.patch_size
-        self.template = model.config.template
-        self.num_image_token = int((image_size // patch_size) ** 2 * (self.model.config.downsample_ratio**2))
-        self.tokenizer = tokenizer
-
-    def build_transform(self, input_size):
-        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
-        transform = T.Compose(
-            [
-                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
-                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
-                T.ToTensor(),
-                T.Normalize(mean=MEAN, std=STD),
-            ]
-        )
-        return transform
-
-    def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
-        best_ratio_diff = float("inf")
-        best_ratio = (1, 1)
-        area = width * height
-        for ratio in target_ratios:
-            target_aspect_ratio = ratio[0] / ratio[1]
-            ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-            if ratio_diff < best_ratio_diff:
-                best_ratio_diff = ratio_diff
-                best_ratio = ratio
-            elif ratio_diff == best_ratio_diff:
-                if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                    best_ratio = ratio
-        return best_ratio
-
-    def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
-        orig_width, orig_height = image.size
-        aspect_ratio = orig_width / orig_height
-        # calculate the existing image aspect ratio
-        target_ratios = set(
-            (i, j)
-            for n in range(min_num, max_num + 1)
-            for i in range(1, n + 1)
-            for j in range(1, n + 1)
-            if i * j <= max_num and i * j >= min_num
-        )
-        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-        # find the closest aspect ratio to the target
-        target_aspect_ratio = self.find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, orig_width, orig_height, image_size
-        )
-        # calculate the target width and height
-        target_width = image_size * target_aspect_ratio[0]
-        target_height = image_size * target_aspect_ratio[1]
-        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-        # resize the image
-        resized_img = image.resize((target_width, target_height))
-        processed_images = []
-        for i in range(blocks):
-            box = (
-                (i % (target_width // image_size)) * image_size,
-                (i // (target_width // image_size)) * image_size,
-                ((i % (target_width // image_size)) + 1) * image_size,
-                ((i // (target_width // image_size)) + 1) * image_size,
-            )
-            # split the image
-            split_img = resized_img.crop(box)
-            processed_images.append(split_img)
-        assert len(processed_images) == blocks
-        if use_thumbnail and len(processed_images) != 1:
-            thumbnail_img = image.resize((image_size, image_size))
-            processed_images.append(thumbnail_img)
-        return processed_images
-
-    def load_image(self, image, input_size=448, max_num=12):
-        transform = self.build_transform(input_size=input_size)
-        images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
-        pixel_values = [transform(image) for image in images]
-        pixel_values = torch.stack(pixel_values)
-        return pixel_values
-
-    def __call__(
-        self,
-        pixel_values,
-        question,
-        messages,
-        roles,
-        history=None,
-        num_patches_list=None,
-        IMG_START_TOKEN="<img>",
-        IMG_END_TOKEN="</img>",
-        IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
-        verbose=False,
-    ) -> str:
-        if history is None and pixel_values is not None and "<image>" not in question:
-            question = "<image>\n" + question
-        if num_patches_list is None:
-            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
-        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
-        img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
-        self.model.img_context_token_id = img_context_token_id
-
-        messages.append([roles[0], question])
-        messages.append([roles[1], None])
-        query = get_prompt(messages)
-        if verbose and pixel_values is not None:
-            image_bs = pixel_values.shape[0]
-            logger.info(f"dynamic ViT batch size: {image_bs}")
-
-        for num_patches in num_patches_list:
-            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
-            query = query.replace("<image>", image_tokens, 1)
-        return query
+from QEfficient.utils.test_utils import InternProcessor
 
 
 def run_intern_on_aic(
     model_name,
-    prompt,
-    image_url,
+    prompts,
+    image_urls,
     messages,
     roles,
     kv_offload=False,
-    prefill_seq_len=3840,
+    prefill_seq_len=128,
+    ctx_len=4096,
+    batch_size=1,
+    num_patches=1,
     num_devices=1,
     num_cores=16,
 ):
@@ -187,8 +43,12 @@
     model.compile(
         num_cores=num_cores,
         num_devices=num_devices,
+        num_patches=num_patches,
+        batch_size=batch_size,
         prefill_seq_len=prefill_seq_len,
-        mxfp6_matmul=False,
+        ctx_len=ctx_len,
+        mxfp6_matmul=True,
+        mxint8_kv_cache=True,
     )
 
     ## STEP 3 -- SETUP THE PROCESSOR
@@ -199,16 +59,30 @@
 
     ## STEP 4 -- PREPROCESS THE INPUTS
-    img = requests.get(image_url, stream=True)
-    image = Image.open(BytesIO(img.content)).convert("RGB")
+    pixel_values = []
+    num_patches_list = []
+    questions = []
+    for i in range(len(prompts)):
+        img = requests.get(image_urls[i], stream=True)
+        image = Image.open(BytesIO(img.content)).convert("RGB")
 
-    # Images are resized to (1000, 747) for inference
-    image = image.resize((1000, 747))
+        # Images are resized to (1000, 747) for inference
+        image = image.resize((1000, 747))
 
-    # preprocess the resized image
-    pixel_values = internProcessor.load_image(image, max_num=12)
-    question = "<image>\n" + prompt
-    query = internProcessor(pixel_values, question, messages, roles)
+        # preprocess the resized image
+        pixel_value = internProcessor.load_image(image, max_num=12)
+        num_patches_list.append(pixel_value.shape[0])
+        pixel_values.append(pixel_value)
+
+        question = "<image>\n" + prompts[i]
+        questions.append(question)
+
+    pixel_values = torch.cat(pixel_values, dim=0)
+
+    # Chat Template information for prompt preprocessing
+    messages: List[List[str]] = []
+    roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
+    query = internProcessor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list)
 
     inputs = tokenizer(
         query, return_tensors="pt", padding="max_length", max_length=prefill_seq_len, padding_side="right"
     )
@@ -217,7 +91,8 @@
 
     ## STEP 5 -- RUN INFERENCE VIA GENERATE FUNCTION
     streamer = TextStreamer(tokenizer)
-    model.generate(inputs=inputs, streamer=streamer, generation_len=128)
+    output = model.generate(inputs=inputs, streamer=streamer, generation_len=128)
+    return output
 
 
 if __name__ == "__main__":
@@ -228,8 +103,12 @@ def run_intern_on_aic(
     roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
 
     # Inputs for the model
-    prompt = "Please describe the image in detail."
-    image_url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg"
+
+    # Add additional prompts and image urls to the respective lists for multi batch compilation and inference
+    prompts = ["Please describe the image in detail."]
+    image_urls = [
+        "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg"
+    ]
 
     ## Compilation parameters
@@ -237,24 +116,29 @@
     # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs.
     # The outputs of the Vision Encoder are then passed to the Language model via host in this case.
-    kv_offload = False
+    kv_offload = True
 
     # InternVL is an Early-Fusion model that uses placeholder tokens within the input_ids to interleave text_embeddings with
    # Image embeddings and generate final input_embeds for outout generation. Hence we need very large prefill_seq_len (3840 in this case) to
    # incorporate the memory for the merged embeddings.
-    prefill_seq_len = 3840
+    prefill_seq_len = 128
+    ctx_len = 4096
     num_devices = 4
     num_cores = 16
+    num_patches = 13
 
-    run_intern_on_aic(
+    output = run_intern_on_aic(
         model_name=model_name,
-        prompt=prompt,
-        image_url=image_url,
+        prompts=prompts,
+        image_urls=image_urls,
         messages=messages,
         roles=roles,
         kv_offload=kv_offload,
         prefill_seq_len=prefill_seq_len,
+        ctx_len=ctx_len,
+        batch_size=len(prompts),
+        num_patches=num_patches,
        num_devices=num_devices,
         num_cores=num_cores,
     )
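# --- Editor's illustration (not part of the patch) ---------------------------------------
# How the batched example above lines up its inputs: load_image() yields a
# [num_patches_i, 3, 448, 448] tensor per image, num_patches_list records each count so the
# processor can expand the right number of <IMG_CONTEXT> placeholders per prompt, and the
# concatenated first axis is the "batched_num_patches" dimension fed to the vision encoder.
# The patch counts below are illustrative assumptions.
import torch

num_patches_list = [7, 6]
pixel_values = torch.cat([torch.zeros(n, 3, 448, 448) for n in num_patches_list], dim=0)
assert pixel_values.shape[0] == sum(num_patches_list)
# -----------------------------------------------------------------------------------------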
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index 49d2ccf8c..3e5b5ed50 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -31,6 +31,7 @@
     "microsoft/Phi-3-mini-4k-instruct",
     "tiiuae/falcon-7b",
     "Qwen/Qwen2-0.5B",
+    "Qwen/Qwen3-0.6B",
     "bigcode/starcoder2-3b",
     "Felladrin/Minueza-32M-Base",
     "wtang06/mpt-125m-c4",
diff --git a/tests/transformers/models/test_image_text_to_text_models.py b/tests/transformers/models/test_image_text_to_text_models.py
index be0e84d23..c25af2b9b 100644
--- a/tests/transformers/models/test_image_text_to_text_models.py
+++ b/tests/transformers/models/test_image_text_to_text_models.py
@@ -134,6 +134,16 @@
         "Please describe the image in detail.",
         2,
     ),
+    (
+        "OpenGVLab/InternVL3_5-1B",
+        True,
+        1,
+        384,
+        512,
+        "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg",
+        "Please describe the image in detail.",
+        2,
+    ),
     # (
     #     "OpenGVLab/InternVL2_5-1B",
     #     False,
@@ -300,10 +310,39 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
     processor = InternProcessor(model_hf, tokenizer)
-    img = requests.get(img_url, stream=True)
-    image = Image.open(BytesIO(img.content)).convert("RGB")
-    image = image.resize((448, 448))
+    prompt = [query]
+    img_url = [img_url]
+    pixel_values = []
+    num_patches_list = []
+    questions = []
+    for i in range(len(prompt)):
+        img = requests.get(img_url[i], stream=True)
+        image = Image.open(BytesIO(img.content)).convert("RGB")
+
+        image = image.resize((448, 448))
+
+        # preprocess the resized image
+        pixel_value = processor.load_image(image, max_num=12)
+        num_patches_list.append(pixel_value.shape[0])
+        pixel_values.append(pixel_value)
+
+        question = "<image>\n" + prompt[i]
+        questions.append(question)
+
+    pixel_values = torch.cat(pixel_values, dim=0)
+
+    # Chat Template information for prompt preprocessing
+    messages: List[List[str]] = []
+    roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
+    prompt = processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list)
+
+    inputs = tokenizer(prompt, return_tensors="pt")
+    batch_size, prompt_len = inputs["input_ids"].shape
+    inputs["pixel_values"] = pixel_values.clone()
+
+    generation_config = dict(max_new_tokens=max_gen_len, do_sample=False)
+    generation_config["eos_token_id"] = tokenizer.convert_tokens_to_ids("<|im_end|>\n".strip())
     api_runner = ApiRunnerInternVL(
         batch_size,
         processor,
@@ -315,19 +354,6 @@
         max_gen_len,
         n_layer,
     )
-    pixel_values = processor.load_image(image, max_num=12)
-    question = "<image>\n" + query
-    # Chat Template information for prompt preprocessing
-    messages: List[List[str]] = []
-    roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
-    prompt = processor(pixel_values, question, messages, roles)
-
-    inputs = tokenizer(prompt, return_tensors="pt")
-    batch_size, prompt_len = inputs["input_ids"].shape
-    inputs["pixel_values"] = pixel_values.clone()
-
-    generation_config = dict(max_new_tokens=max_gen_len, do_sample=False)
-    generation_config["eos_token_id"] = tokenizer.convert_tokens_to_ids("<|im_end|>\n".strip())
     pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch(model_hf, inputs, generation_config)
 
     qeff_model = QEFFAutoModelForCausalLM.from_pretrained(