diff --git a/examples/dummy_onboarding/causallm/README.md b/examples/dummy_onboarding/causallm/README.md
new file mode 100644
index 000000000..52dfafcf6
--- /dev/null
+++ b/examples/dummy_onboarding/causallm/README.md
@@ -0,0 +1,64 @@
+# Onboarding a CausalLM Model
+
+## Prerequisites
+
+Install the `efficient-transformers` library in editable mode.
+```sh
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install -e .
+```
+
+## Introduction
+
+This README provides a step-by-step guide on how to onboard a CausalLM model. The process includes setting up the environment, modifying configuration files, and running the model.
+We will use a dummy model named `Blueprint` as an example and walk through the changes that have to be made in the `efficient-transformers` library to enable such a model.
+
+
+## Step 1: Checking the classes in the modeling file from the `transformers` library
+
+1. **Look into the original modeling files in the `transformers` library:**
+   - Locate the original model in the `transformers` library.
+     - `/src/transformers/models/blueprint/modeling_blueprint.py` has all the modeling classes used to construct the model.
+   - Locate the `pytorch_transforms.py` file in `efficient-transformers/QEfficient/transformers/models/` to see if the corresponding classes are already implemented in `efficient-transformers`.
+     - It is a good reference point to see whether some of the required functionality has already been implemented for a prior model.
+   - Check the architecture class of the model that you want to onboard.
+     - If it is not in `pytorch_transforms.py`, you will need to implement the modified classes and then map them, along with the other required classes, in the `pytorch_transforms.py` file.
+
+## Step 2: Creating the custom modeling file and mappings in `pytorch_transforms.py`
+
+1. **Add the required modified modeling file to the `efficient-transformers` library:**
+   - For our example we will create the following directory:
+     `/QEfficient/transformers/models/blueprint`
+   - Then we will add the modeling file and an `__init__.py` file in this directory.
+   - The modeling file `modeling_blueprint.py` will have all the necessary modified modeling classes.
+   - The file has been annotated to explain where and why the changes are required for the model.
+
+2. **Add the mappings for the corresponding classes in the `pytorch_transforms.py` file:**
+   - You will need to map the original classes of the model to the modified ones in the `pytorch_transforms.py` file.
+   - If you look into the `dummy_pytorch_transforms.py` file, you can see an example case for our `Blueprint` model.
+   - Every mapping class serves a specific purpose:
+     - **CustomOpsTransform**
+       - This class holds the mapping for the RMSNorm class that we use for a model.
+       - Most models share the same RMSNorm; if you need a different RMSNorm class, make the change in this mapping, as we do for the Gemma models.
+       - If you need your own custom RMSNorm class, you can add it in the `QEfficient.customop` module.
+     - **KVCacheTransform**
+       - This class handles mappings for almost all models that use a KV cache and thus generate text.
+       - For every custom class that we define for a model, we add a mapping from the corresponding `transformers` class in this section.
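+       - For example, the `Blueprint` entries in this example's `dummy_pytorch_transforms.py` boil down to the excerpt below (a condensed sketch; `Blueprint` is the dummy model used throughout this guide):
+
+         ```python
+         class CustomOpsTransform(ModuleMappingTransform):
+             _module_mapping = {
+                 # Replace the original RMSNorm with the AIC custom op
+                 BlueprintRMSNorm: CustomRMSNormAIC,
+             }
+
+
+         class KVCacheTransform(ModuleMappingTransform):
+             _module_mapping = {
+                 # Blueprint
+                 BlueprintAttention: QEffBlueprintAttention,
+                 BlueprintDecoderLayer: QEffBlueprintDecoderLayer,
+                 BlueprintModel: QEffBlueprintModel,
+                 BlueprintForCausalLM: QEffBlueprintForCausalLM,
+             }
+         ```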
+       - For models that don't have their modeling files in the `transformers` library, we create this mapping via a different mapping class called **KVCacheExternalModuleMapperTransform**.
+     - **KVCacheExternalModuleMapperTransform**
+       - This class is used to map to a class that is not in the `transformers` library.
+       - Here we don't perform a class replacement as we do in the other mapping classes.
+       - We simply match on the class name of the original model and then map the methods of that class to the custom methods defined in our custom modeling file in `efficient-transformers`.
+
+3. **Testing the implementation:**
+   - Once the implementation is complete, you can test it with the following steps:
+     - Go to the file `/efficient-transformers/tests/transformers/models/test_causal_lm_models.py` and add the appropriate model card for your model.
+     - For example, for the `Blueprint` model, we might have a model card like `Blueprint/Blueprint-70B` on Hugging Face.
+     - Add the model card to the list `test_models_causal` and then run the test to ensure that the classes are mapped correctly.
+
+
+## References
+- [Hugging Face Transformers GitHub Repository](https://github.com/huggingface/transformers)
+- [QEfficient Transformers GitHub Repository](https://github.com/quic/efficient-transformers)
+
diff --git a/examples/dummy_onboarding/causallm/dummy_pytorch_transforms.py b/examples/dummy_onboarding/causallm/dummy_pytorch_transforms.py
new file mode 100644
index 000000000..a67264198
--- /dev/null
+++ b/examples/dummy_onboarding/causallm/dummy_pytorch_transforms.py
@@ -0,0 +1,659 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import warnings
+from types import MethodType
+from typing import Callable, Optional, Tuple, Union
+
+from torch import nn
+
+# from transformers.models.codegen.modeling_codegen import (
+#     CodeGenAttention,
+#     CodeGenBlock,
+#     CodeGenForCausalLM,
+#     CodeGenModel,
+# )
+# from transformers.models.falcon.modeling_falcon import (
+#     FalconAttention,
+#     FalconDecoderLayer,
+#     FalconForCausalLM,
+#     FalconModel,
+# )
+# from transformers.models.gemma.modeling_gemma import (
+#     GemmaAttention,
+#     GemmaDecoderLayer,
+#     GemmaForCausalLM,
+#     GemmaModel,
+#     GemmaRMSNorm,
+# )
+# from transformers.models.gemma2.modeling_gemma2 import (
+#     Gemma2Attention,
+#     Gemma2DecoderLayer,
+#     Gemma2ForCausalLM,
+#     Gemma2Model,
+#     Gemma2RMSNorm,
+# )
+# from transformers.models.gemma3.modeling_gemma3 import (
+#     Gemma3Attention,
+#     Gemma3DecoderLayer,
+#     Gemma3ForCausalLM,
+#     Gemma3ForConditionalGeneration,
+#     Gemma3RMSNorm,
+#     Gemma3TextModel,
+# )
+# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
+# from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
+#     GPTBigCodeAttention,
+#     GPTBigCodeBlock,
+#     GPTBigCodeForCausalLM,
+#     GPTBigCodeModel,
+# )
+# from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel
+# from transformers.models.granite.modeling_granite import (
+#     GraniteAttention,
+#     GraniteForCausalLM,
+#     GraniteModel,
+#     GraniteRMSNorm,
+# )
+# from transformers.models.granitemoe.modeling_granitemoe import (
+#     GraniteMoeAttention,
+#     GraniteMoeForCausalLM,
+#     GraniteMoeModel,
+#     GraniteMoeMoE,
+#     GraniteMoeParallelExperts,
+#
GraniteMoeRMSNorm, +# GraniteMoeRotaryEmbedding, +# GraniteMoeTopKGating, +# ) +# from transformers.models.llama.modeling_llama import ( +# LlamaAttention, +# LlamaDecoderLayer, +# LlamaForCausalLM, +# LlamaModel, +# LlamaRMSNorm, +# ) +# from transformers.models.llama4.modeling_llama4 import ( +# Llama4ForCausalLM, +# Llama4ForConditionalGeneration, +# Llama4TextAttention, +# Llama4TextDecoderLayer, +# Llama4TextExperts, +# Llama4TextModel, +# Llama4TextMoe, +# Llama4TextRMSNorm, +# Llama4VisionAttention, +# Llama4VisionModel, +# ) +# from transformers.models.llava.modeling_llava import ( +# LlavaForConditionalGeneration, +# ) +# from transformers.models.llava_next.modeling_llava_next import ( +# LlavaNextForConditionalGeneration, +# ) +# from transformers.models.mistral.modeling_mistral import ( +# MistralAttention, +# MistralDecoderLayer, +# MistralForCausalLM, +# MistralModel, +# MistralRMSNorm, +# ) +# from transformers.models.mixtral.modeling_mixtral import ( +# MixtralAttention, +# MixtralDecoderLayer, +# MixtralForCausalLM, +# MixtralModel, +# MixtralRMSNorm, +# MixtralSparseMoeBlock, +# ) +# from transformers.models.mllama.modeling_mllama import ( +# MllamaCrossAttentionDecoderLayer, +# MllamaForCausalLM, +# MllamaForConditionalGeneration, +# MllamaRotaryEmbedding, +# MllamaSelfAttentionDecoderLayer, +# MllamaTextCrossAttention, +# MllamaTextModel, +# MllamaTextRMSNorm, +# MllamaTextSelfAttention, +# MllamaVisionModel, +# ) +# from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel +# from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel +# from transformers.models.phi3.modeling_phi3 import ( +# Phi3Attention, +# Phi3DecoderLayer, +# Phi3ForCausalLM, +# Phi3Model, +# Phi3RMSNorm, +# ) +# from transformers.models.qwen2.modeling_qwen2 import ( +# Qwen2Attention, +# Qwen2DecoderLayer, +# Qwen2ForCausalLM, +# Qwen2Model, +# Qwen2RMSNorm, +# ) +from transformers.models.blueprint.modeling_blueprint import ( + BlueprintAttention, + BlueprintDecoderLayer, + BlueprintForCausalLM, + BlueprintModel, + BlueprintRMSNorm, +) + +# from transformers.models.starcoder2.modeling_starcoder2 import ( +# Starcoder2Attention, +# Starcoder2DecoderLayer, +# Starcoder2ForCausalLM, +# Starcoder2Model, +# ) +# from transformers.models.whisper.modeling_whisper import ( +# WhisperAttention, +# WhisperDecoder, +# WhisperDecoderLayer, +# WhisperEncoder, +# WhisperForConditionalGeneration, +# WhisperModel, +# WhisperPositionalEmbedding, +# ) +from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform +from QEfficient.customop import CustomRMSNormAIC +from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function + +# from QEfficient.transformers.models.codegen.modeling_codegen import ( +# QEffCodeGenAttention, +# QeffCodeGenBlock, +# QEffCodeGenForCausalLM, +# QEffCodeGenModel, +# ) +# from QEfficient.transformers.models.falcon.modeling_falcon import ( +# QEffFalconAttention, +# QEffFalconDecoderLayer, +# QEffFalconForCausalLM, +# QEffFalconModel, +# ) +# from QEfficient.transformers.models.gemma.modeling_gemma import ( +# QEffGemmaAttention, +# QEffGemmaDecoderLayer, +# QEffGemmaForCausalLM, +# QEffGemmaModel, +# ) +# from QEfficient.transformers.models.gemma2.modeling_gemma2 import ( +# QEffGemma2Attention, +# QEffGemma2DecoderLayer, +# QEffGemma2ForCausalLM, +# QEffGemma2Model, +# ) +# from 
QEfficient.transformers.models.gemma3.modeling_gemma3 import ( +# QEffGemma3Attention, +# QEffGemma3CustomRMSNormAIC, +# QEffGemma3DecoderLayer, +# QEffGemma3ForCausalLMModel, +# QEffGemma3ForConditionalGeneration, +# QEffGemma3TextModel, +# ) +# from QEfficient.transformers.models.gpt2.modeling_gpt2 import ( +# QEffGPT2Attention, +# QEffGPT2Block, +# QEffGPT2LMHeadModel, +# QEffGPT2Model, +# ) +# from QEfficient.transformers.models.gpt_bigcode.modeling_gpt_bigcode import ( +# QEffGPTBigCodeAttention, +# QEffGPTBigCodeBlock, +# QEffGPTBigCodeForCausalLM, +# QEffGPTBigCodeModel, +# ) +# from QEfficient.transformers.models.gptj.modeling_gptj import ( +# QEffGPTJAttention, +# QEffGPTJBlock, +# QEffGPTJForCausalLM, +# QEffGPTJModel, +# ) +# from QEfficient.transformers.models.granite.modeling_granite import ( +# QEffGraniteAttention, +# QEffGraniteForCausalLM, +# QEffGraniteModel, +# ) +# from QEfficient.transformers.models.granitemoe.modeling_granitemoe import ( +# QEffGraniteMoeAttention, +# QEffGraniteMoeForCausalLM, +# QEffGraniteMoeModel, +# QEffGraniteMoeMoE, +# QEffGraniteMoeParallelExperts, +# QEffGraniteMoeRotaryEmbedding, +# QEffGraniteMoeTopKGating, +# ) +# from QEfficient.transformers.models.grok_1.modeling_grok1 import ( +# QEFFGrok1CustomRMSNormAIC, +# QEffGrok1DecoderLayer, +# QEffGrok1Model, +# QEffGrok1ModelForCausalLM, +# QEffGrok1MoeBlock, +# QEffGrok1MultiHeadAttention, +# ) +# from QEfficient.transformers.models.internvl.modeling_internvl import ( +# QEffInternVisionEmbeddings, +# QEffInternVLModel, +# ) +# from QEfficient.transformers.models.llama.modeling_llama import ( +# QEffLlamaAttention, +# QEffLlamaDecoderLayer, +# QEffLlamaForCausalLM, +# QEffLlamaModel, +# ) +# from QEfficient.transformers.models.llama4.modeling_llama4 import ( +# QEffLlama4ForCausalLM, +# QEffLlama4ForConditionalGeneration, +# QEffLlama4TextAttention, +# QEffLlama4TextDecoderLayer, +# QEffLlama4TextExperts, +# QEffLlama4TextModel, +# QEffLlama4TextMoe, +# QEffLlama4VisionAttention, +# QEffLlama4VisionModel, +# ) +# from QEfficient.transformers.models.llava.modeling_llava import ( +# QEffLlavaForConditionalGeneration, +# ) +# from QEfficient.transformers.models.llava_next.modeling_llava_next import ( +# QEffLlavaNextForConditionalGeneration, +# ) +# from QEfficient.transformers.models.mistral.modeling_mistral import ( +# QEffMistralAttention, +# QEffMistralDecoderLayer, +# QEffMistralForCausalLM, +# QEffMistralModel, +# ) +# from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( +# QEffMixtralAttention, +# QeffMixtralDecoderLayer, +# QEffMixtralForCausalLM, +# QEffMixtralModel, +# QEffMixtralSparseMoeBlock, +# ) +# from QEfficient.transformers.models.mllama.modeling_mllama import ( +# QEffMllamaCrossAttentionDecoderLayer, +# QEffMllamaForCausalLM, +# QEffMllamaForConditionalGeneration, +# QEffMllamaRotaryEmbedding, +# QEffMllamaSelfAttentionDecoderLayer, +# QEffMllamaTextCrossAttentionSingleQPC, +# QEffMllamaTextCrossAttentionTwoQPC, +# QEffMllamaTextModel, +# QEffMllamaTextSelfAttention, +# QEffMllamaVisionModel, +# ) +# from QEfficient.transformers.models.mpt.modeling_mpt import ( +# QEffMptAttention, +# QEffMptBlock, +# QEffMptForCausalLM, +# QEFfMptModel, +# ) +# from QEfficient.transformers.models.phi.modeling_phi import ( +# QEffPhiAttention, +# QEffPhiDecoderLayer, +# QEffPhiForCausalLM, +# QEffPhiModel, +# ) +# from QEfficient.transformers.models.phi3.modeling_phi3 import ( +# QEffPhi3Attention, +# QEffPhi3DecoderLayer, +# QEffPhi3ForCausalLM, +# QEffPhi3Model, +# ) +# 
from QEfficient.transformers.models.qwen2.modeling_qwen2 import ( +# QEffQwen2Attention, +# QEffQwen2DecoderLayer, +# QEffQwen2ForCausalLM, +# QEffQwen2Model, +# ) +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) + +# from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import ( +# QEffStarcoder2Attention, +# QEFFStarcoder2DecoderLayer, +# QEffStarcoder2ForCausalLM, +# QEffStarcoder2Model, +# ) +# from QEfficient.transformers.models.whisper.modeling_whisper import ( +# QEffWhisperAttention, +# QEffWhisperDecoder, +# QEffWhisperDecoderLayer, +# QEffWhisperEncoder, +# QEffWhisperForConditionalGeneration, +# QEffWhisperModel, +# QEffWhisperPositionalEmbedding, +# ) +from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.sampler.sampler import sampler_forward +from QEfficient.transformers.spd.spd_transform_forward import tlm_forward + +SPD_TARGET = "target" + + +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = { + # GemmaRMSNorm: GemmaCustomRMSNormAIC, + # Gemma2RMSNorm: GemmaCustomRMSNormAIC, + # LlamaRMSNorm: CustomRMSNormAIC, + # Llama4TextRMSNorm: CustomRMSNormAIC, + # MistralRMSNorm: CustomRMSNormAIC, + # MixtralRMSNorm: CustomRMSNormAIC, + # Phi3RMSNorm: CustomRMSNormAIC, + # Qwen2RMSNorm: CustomRMSNormAIC, + BlueprintRMSNorm: CustomRMSNormAIC, + # MllamaTextRMSNorm: CustomRMSNormAIC, + # GraniteRMSNorm: CustomRMSNormAIC, + # GraniteMoeRMSNorm: CustomRMSNormAIC, + # Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC, + } + + +class KVCacheTransform(ModuleMappingTransform): + _module_mapping = { + # # CodeGen + # CodeGenAttention: QEffCodeGenAttention, + # CodeGenBlock: QeffCodeGenBlock, + # CodeGenModel: QEffCodeGenModel, + # CodeGenForCausalLM: QEffCodeGenForCausalLM, + # # Falcon + # FalconAttention: QEffFalconAttention, + # FalconDecoderLayer: QEffFalconDecoderLayer, + # FalconModel: QEffFalconModel, + # FalconForCausalLM: QEffFalconForCausalLM, + # # GPT2 + # GPT2Attention: QEffGPT2Attention, + # GPT2Block: QEffGPT2Block, + # GPT2Model: QEffGPT2Model, + # GPT2LMHeadModel: QEffGPT2LMHeadModel, + # # GPTJ + # GPTJAttention: QEffGPTJAttention, + # GPTJBlock: QEffGPTJBlock, + # GPTJModel: QEffGPTJModel, + # GPTJForCausalLM: QEffGPTJForCausalLM, + # # Llama + # LlamaAttention: QEffLlamaAttention, + # LlamaDecoderLayer: QEffLlamaDecoderLayer, + # LlamaModel: QEffLlamaModel, + # LlamaForCausalLM: QEffLlamaForCausalLM, + # # Llama4 + # Llama4TextAttention: QEffLlama4TextAttention, + # Llama4ForCausalLM: QEffLlama4ForCausalLM, + # Llama4TextDecoderLayer: QEffLlama4TextDecoderLayer, + # Llama4TextModel: QEffLlama4TextModel, + # Llama4TextMoe: QEffLlama4TextMoe, + # Llama4ForConditionalGeneration: QEffLlama4ForConditionalGeneration, + # Llama4VisionAttention: QEffLlama4VisionAttention, + # Llama4VisionModel: QEffLlama4VisionModel, + # Llama4TextExperts: QEffLlama4TextExperts, + # # Llava + # LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration, + # # Llava Next + # LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration, + # # Gemma + # GemmaAttention: QEffGemmaAttention, + # GemmaDecoderLayer: QEffGemmaDecoderLayer, + # GemmaModel: QEffGemmaModel, + # GemmaForCausalLM: QEffGemmaForCausalLM, + # # Gemma2 + # Gemma2Attention: QEffGemma2Attention, + # Gemma2DecoderLayer: QEffGemma2DecoderLayer, + # Gemma2Model: QEffGemma2Model, + # 
Gemma2ForCausalLM: QEffGemma2ForCausalLM, + # # Gemma3 + # Gemma3Attention: QEffGemma3Attention, + # Gemma3DecoderLayer: QEffGemma3DecoderLayer, + # Gemma3TextModel: QEffGemma3TextModel, + # Gemma3ForCausalLM: QEffGemma3ForCausalLMModel, + # Gemma3ForConditionalGeneration: QEffGemma3ForConditionalGeneration, + # # Granite + # GraniteModel: QEffGraniteModel, + # GraniteForCausalLM: QEffGraniteForCausalLM, + # GraniteAttention: QEffGraniteAttention, + # # GraniteMoe + # GraniteMoeModel: QEffGraniteMoeModel, + # GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM, + # GraniteMoeAttention: QEffGraniteMoeAttention, + # GraniteMoeRotaryEmbedding: QEffGraniteMoeRotaryEmbedding, + # GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts, + # GraniteMoeTopKGating: QEffGraniteMoeTopKGating, + # GraniteMoeMoE: QEffGraniteMoeMoE, + # # mllama + # MllamaTextRMSNorm: CustomRMSNormAIC, + # MllamaTextSelfAttention: QEffMllamaTextSelfAttention, + # MllamaSelfAttentionDecoderLayer: QEffMllamaSelfAttentionDecoderLayer, + # MllamaCrossAttentionDecoderLayer: QEffMllamaCrossAttentionDecoderLayer, + # MllamaRotaryEmbedding: QEffMllamaRotaryEmbedding, + # MllamaVisionModel: QEffMllamaVisionModel, + # MllamaTextModel: QEffMllamaTextModel, + # MllamaForCausalLM: QEffMllamaForCausalLM, + # MllamaForConditionalGeneration: QEffMllamaForConditionalGeneration, + # # Mistral + # MistralAttention: QEffMistralAttention, + # MistralDecoderLayer: QEffMistralDecoderLayer, + # MistralModel: QEffMistralModel, + # MistralForCausalLM: QEffMistralForCausalLM, + # # Mixtral + # MixtralAttention: QEffMixtralAttention, + # MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, + # MixtralDecoderLayer: QeffMixtralDecoderLayer, + # MixtralModel: QEffMixtralModel, + # MixtralForCausalLM: QEffMixtralForCausalLM, + # # Mpt + # MptAttention: QEffMptAttention, + # MptBlock: QEffMptBlock, + # MptModel: QEFfMptModel, + # MptForCausalLM: QEffMptForCausalLM, + # # Phi3 + # Phi3Attention: QEffPhi3Attention, + # Phi3DecoderLayer: QEffPhi3DecoderLayer, + # Phi3Model: QEffPhi3Model, + # Phi3ForCausalLM: QEffPhi3ForCausalLM, + # # Phi + # PhiAttention: QEffPhiAttention, + # PhiDecoderLayer: QEffPhiDecoderLayer, + # PhiModel: QEffPhiModel, + # PhiForCausalLM: QEffPhiForCausalLM, + # # Qwen2 + # Qwen2Attention: QEffQwen2Attention, + # Qwen2DecoderLayer: QEffQwen2DecoderLayer, + # Qwen2Model: QEffQwen2Model, + # Qwen2ForCausalLM: QEffQwen2ForCausalLM, + # Blueprint + BlueprintAttention: QEffBlueprintAttention, + BlueprintDecoderLayer: QEffBlueprintDecoderLayer, + BlueprintModel: QEffBlueprintModel, + BlueprintForCausalLM: QEffBlueprintForCausalLM, + # # Starcoder2 + # Starcoder2Attention: QEffStarcoder2Attention, + # Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer, + # Starcoder2Model: QEffStarcoder2Model, + # Starcoder2ForCausalLM: QEffStarcoder2ForCausalLM, + # # GptBigcode + # GPTBigCodeAttention: QEffGPTBigCodeAttention, + # GPTBigCodeBlock: QEffGPTBigCodeBlock, + # GPTBigCodeModel: QEffGPTBigCodeModel, + # GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM, + # # Whisper encoder and decoder layers + # WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding, + # WhisperAttention: QEffWhisperAttention, + # WhisperDecoderLayer: QEffWhisperDecoderLayer, + # WhisperEncoder: QEffWhisperEncoder, + # WhisperDecoder: QEffWhisperDecoder, + # WhisperModel: QEffWhisperModel, + # WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = 
super().apply(model)
+        return model, transformed
+
+
+class SpDTransform:
+    """
+    Apply the generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before
+    computing logits during the decode phase, and to extract the last predicted token during prefill.
+    This is only needed if the user is exporting a Target Language Model (TLM) for Speculative Decoding, to validate
+    output logits against the speculated tokens from a smaller model.
+    Other than the computed logits, there should be no difference between the SpD-transformed model and its
+    corresponding counterpart.
+
+    ``Mandatory`` Args:
+        :model (nn.Module): PyTorch model.
+
+    Returns:
+        :model (nn.Module): PyTorch model.
+        :transformed (bool): whether transformation was applied successfully.
+    """
+
+    # supported architectures
+    _module_mapping = {
+        # Llama
+        # QEffLlamaForCausalLM,
+        # QEffQwen2ForCausalLM,
+        QEffBlueprintForCausalLM,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]:
+        transformed = False
+        pretrained_model_name_or_path_temp = kwargs.pop("pretrained_model_name_or_path", None)
+        if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None:
+            return model, transformed
+        elif speculative_model_type not in (
+            supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys())
+        ):
+            raise ValueError(
+                f"Speculative model type {speculative_model_type} is not supported. We currently only support {supported_spd_model_types}"
+            )
+        elif (model_class := model.__class__) in cls._module_mapping:
+            model.forward = MethodType(tlm_forward, model)
+            if speculative_model_type != SPD_TARGET:
+                # build and attach draft mlp
+                pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"]
+                model = build_and_attach_mlp(
+                    model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs
+                )
+            transformed = True
+        else:
+            raise NotImplementedError(
+                f"model class {model_class} does not yet support returning multiple logits to keep."
+            )
+        kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path_temp
+        return model, transformed
+
+
+class SamplerTransform:
+    """
+    Add nodes at the output of any generic QEffForCausalLM model to enable the
+    sampling of next tokens at the device (instead of the host) and return the
+    next tokens and/or probability distributions.
+
+    Note: To achieve this, the generic QEffForCausalLM model must provide the
+    logits as output.
+
+    ``Mandatory`` Args:
+        :model (nn.Module): PyTorch model.
+
+    Returns:
+        :model (nn.Module): PyTorch model.
+        :transformed (bool): whether transformation was applied successfully.
+ """ + + # supported architectures + _module_mapping = { + # Llama + # QEffLlamaForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + transformed = False + if qaic_config is None or not qaic_config.get("include_sampler", False): + return model, transformed + elif (model_class := model.__class__) in cls._module_mapping: + model.old_forward = model.forward + model.forward = MethodType(sampler_forward, model) + transformed = True + else: + raise NotImplementedError(f"Model class {model_class} does not support on device sampling.") + return model, transformed + + +class VlmKVOffloadTransform(ModuleMappingTransform): + # supported architectures + _module_mapping = { + # # Llama + # MllamaTextCrossAttention: QEffMllamaTextCrossAttentionTwoQPC, + } + + +class VlmNoKVOffloadTransform(ModuleMappingTransform): + # supported architectures + _module_mapping = { + # # Llama + # MllamaTextCrossAttention: QEffMllamaTextCrossAttentionSingleQPC, + } + + +class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): + _match_string_replace_method = { + # "InternVLChatModel": { + # "forward": QEffInternVLModel.forward, + # "get_blueprint_inputs": QEffInternVLModel.get_blueprint_inputs, + # "get_specializations": QEffInternVLModel.get_specializations, + # "get_onnx_dynamic_axes": QEffInternVLModel.get_onnx_dynamic_axes, + # "get_output_names": QEffInternVLModel.get_output_names, + # "get_inputs_info": QEffInternVLModel.get_inputs_info, + # "get_qeff_vision_encoder": QEffInternVLModel.get_qeff_vision_encoder, + # "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, + # }, + # "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, + # # Mapping for grok1 model + # "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward}, + # "Grok1Model": { + # "forward": QEffGrok1Model.forward, + # "__qeff_init__": QEffGrok1Model.__qeff_init__, + # }, + # "DecoderLayer": { + # "forward": QEffGrok1DecoderLayer.forward, + # "__qeff_init__": QEffGrok1DecoderLayer.__qeff_init__, + # }, + # "MoeBlock": {"forward": QEffGrok1MoeBlock.forward}, + # "MultiHeadAttention": { + # "forward": QEffGrok1MultiHeadAttention.forward, + # }, + # "RMSNorm": { + # "forward": QEFFGrok1CustomRMSNormAIC.forward, + # }, + } + + _match_class_replace_method = {} + + +class PoolingTransform: + """ + Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. + The pooling layer can be configured to use different pooling methods, such as max pooling or average pooling. + """ + + @classmethod + def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]: + transformed = False + pooling_method = ( + POOLING_MAP[pooling] + if isinstance(pooling, str) and pooling in POOLING_MAP + else validate_user_pooling_function(pooling) + ) + model = PooledModel(model, pooling_method) + warnings.warn("Pooling is applied to the model.") + return model, transformed diff --git a/examples/dummy_onboarding/causallm/modeling_dummy.py b/examples/dummy_onboarding/causallm/modeling_dummy.py new file mode 100644 index 000000000..195c9d7db --- /dev/null +++ b/examples/dummy_onboarding/causallm/modeling_dummy.py @@ -0,0 +1,394 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""PyTorch Blueprint model."""
+
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
+from transformers.models.blueprint.modeling_blueprint import (
+    BlueprintAttention,
+    BlueprintConfig,
+    BlueprintDecoderLayer,
+    BlueprintForCausalLM,
+    BlueprintModel,
+    BlueprintRotaryEmbedding,
+    rotate_half,
+)
+
+from QEfficient.transformers.cache_utils import QEffDynamicCache
+from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+
+
+class QEffBlueprintRotaryEmbedding(BlueprintRotaryEmbedding):
+    """
+    Adds the required rotary embedding functionality to the model, based on the corresponding class in the `transformers` modeling file.
+    The purpose of this class is to precompute the sin and cos values for the rotary embedding and cache them for faster inference.
+    This class is more or less the same for all models that are onboarded.
+    """
+
+    def __init__(self, config: BlueprintConfig, device=None):
+        super().__init__(config=config)
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+            self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+        )
+
+
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    We modify this method so that the rotary embedding is applied based on position_ids
+    instead of seq_len. This is needed because our modified modeling accepts position_ids, and not
+    the attention_mask, as an input.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    **kwargs,
+):
+    """
+    Implements the forward pass of eager attention for the model.
+    We explicitly support eager-mode attention on our device.
+    The method is mostly generic, so we don't expect it to need many changes.
+    MIN_MASKED_ATTENTION_VALUE is a special value which helps our compiler know what -inf should be represented by.
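+    A typical implementation follows the standard eager-attention recipe: repeat the key/value heads to match
+    the query heads, compute the scaled `query @ key^T` scores, add the causal `attention_mask` (whose masked
+    positions are filled with MIN_MASKED_ATTENTION_VALUE rather than -inf), apply softmax, and multiply the
+    resulting weights with `value`. The body is intentionally left as a stub in this dummy example; see an
+    already-onboarded model such as Llama under `QEfficient/transformers/models/` for a complete reference
+    implementation.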
+ """ + pass + + +class QEffBlueprintAttention(BlueprintAttention): + """ + Here we'll setup the forward pass of the Attention module as implemented in the original model. + We initialize our own RotaryEmbedding module via __qeff_init__ method call. + + """ + + # < We load our own custom class for the rotary embedding to enable supporting position_ids> + # Since we map the custom classes to the original classes, __init__ method wouldn't work as expected, + # Hence we use __qeff_init__ method to initialize something while the mapping happens. + + def __qeff_init__(self): + self.rotary_emb = QEffBlueprintRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Most of the implementation remains the same as original forward method. + The parts where difference occurs are the way we apply the rotary embeddings. + Also, we return the past_key_values instead of storing it in the default transformers cache. + """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states, **kwargs) + key_states = self.k_proj(hidden_states, **kwargs) + value_states = self.v_proj(hidden_states, **kwargs) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + + # We build the rotary embeddings different from the transformers method. + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # Application of the rotary embeddings requires position_ids as well. + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # < We add all the required items for cache kwargs which would enable updating QEffDynamicCache > + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # < We override the attention_interface method with our own to enable Eager Attention> + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffBlueprintDecoderLayer(BlueprintDecoderLayer): + """ + Overrides the forward method of the original BlueprintDecoderLayer. + Only changes being that the past_key_value is returned and `self.self_attn` method + is now an object of QEffBlueprintAttention instead. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + The modified forward function also stores and returns the past_key_value. + Every other operation remains the same. + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # < Self attention would also have to return the past_key_value as well and we capture it here> + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffBlueprintModel(BlueprintModel): + """ + Replaces the original BlueprintModel with a modified version. + We initialize the custom `QEffDynamicCache` for past_key_values here instead of the DynamicCache class. + This custom Cache class has all the required custom ops to perform CtxScatter/CtxGather as well as other required operations. + This enables us to cache the past key values in the way we want for AIC. The component won't require any changes mostly. 
+ """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + # < We create the custom QEffDynamicCache here to be used during the AIC execution> + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class QEffBlueprintForCausalLM(BlueprintForCausalLM): + """ + No major changes are needed in the forward method of this class, it is the entry point for the model during inference. + We add the additionally required parameters and pass those down the line as well. 
+ """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # < We add the additional parameters that we use for our models here and pass them down the line > + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + )