diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 6d04cf573..ea0e155f7 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -13,8 +13,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +import torch import transformers -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import padding_check_and_fix @@ -313,7 +314,10 @@ def calculate_latency(total_decoded_tokens, loop_start, start, end, decode_pause def cloud_ai_100_exec_kv( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - qpc_path: str, + lang_qpc_path: str, + processor: Optional[AutoImageProcessor] = None, + vision_qpc_path: Optional[str] = None, + images: Optional[str] = None, prompt: Optional[str] = None, prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, @@ -372,7 +376,7 @@ def cloud_ai_100_exec_kv( exec_info = QEfficient.cloud_ai_100_exec_kv(tokenizer=tokenizer, qpc_path=qpc_path, prompt="Hi there!!", device_id=[0]) """ - batch_size, ctx_len, full_batch_size = get_compilation_dims(qpc_path) + batch_size, ctx_len, full_batch_size = get_compilation_dims(lang_qpc_path) prompt: List[str] = get_input_prompts(prompt, prompts_txt_file_path) prompt = fix_prompts(prompt, batch_size, full_batch_size) if prompt_to_lora_id_mapping is not None: @@ -381,7 +385,9 @@ def cloud_ai_100_exec_kv( ) generate_text = TextGeneration( tokenizer=tokenizer, - qpc_path=qpc_path, + processor=processor, + lang_qpc_path=lang_qpc_path, + vision_qpc_path=vision_qpc_path, device_id=device_id, ctx_len=ctx_len, enable_debug_logs=enable_debug_logs, @@ -393,6 +399,37 @@ def cloud_ai_100_exec_kv( sampling_params=sampling_params, ) + if full_batch_size is None: + exec_info = [ + generate_text.generate( + prompt=prompt[i : i + batch_size], + generation_len=generation_len, + stream=stream, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, + ) + for i in range(0, len(prompt), batch_size) + ] + prefill_time = np.average([info.perf_metrics.prefill_time for info in exec_info]) + decode_perf = np.average([info.perf_metrics.decode_perf for info in exec_info]) + total_perf = np.average([info.perf_metrics.total_perf for info in exec_info]) + total_time = np.average([info.perf_metrics.total_time for info in exec_info]) + generated_texts = [info.generated_texts for info in exec_info] + generated_ids = [info.generated_ids for info in exec_info] + + exec_info = CloudAI100ExecInfo( + batch_size=batch_size, + generated_texts=generated_texts, + generated_ids=generated_ids, + perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time), + ) + else: + exec_info = generate_text.generate( + prompt=prompt, + images=images, + generation_len=generation_len, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, + ) + for _ in range(0, int(iteration)): if full_batch_size is None: exec_info = [ @@ -427,7 +464,9 @@ class QEffTextGenerationBase: def __init__( self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - qpc_path: str, + lang_qpc_path: str, + processor: Optional[AutoImageProcessor] = None, + vision_qpc_path: Optional[str] = None, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ 
-445,11 +484,17 @@ def __init__( self.sampling_params = sampling_params # Load QPC - self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs) + self._lang_session = None + self._vision_session = None + if not lang_qpc_path: + raise TypeError("Please run compile API for language model first!") + self._lang_session = QAICInferenceSession(lang_qpc_path, device_id, activate=False) + if vision_qpc_path: + self._vision_session = QAICInferenceSession(vision_qpc_path, device_id, activate=False) # Validate sampler inputs for On-Device Sampling self.include_sampler = validate_sampler_inputs( - session_inputs=set(self._session.input_names), include_sampler=include_sampler + session_inputs=set(self._lang_session.input_names), include_sampler=include_sampler ) # Fetch the variables from the QPC @@ -474,10 +519,23 @@ def __init__( self.generation_len = None self.tokenizer = tokenizer + self.processor = processor self._set_tokenizer_params() # set tokenizer params # Skip inputs/outputs - self._session.skip_buffers( - [x for x in self._session.input_names + self._session.output_names if x.startswith("past_")] + if self._vision_session: + self._vision_session.skip_buffers( + [ + x + for x in self._vision_session.input_names + self._vision_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + self._lang_session.skip_buffers( + [ + x + for x in self._lang_session.input_names + self._lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] ) def _set_tokenizer_params(self): @@ -502,13 +560,16 @@ def _fetch_full_batch_size( """ full_batch_size = None - if "batch_index" in self._session.binding_index_map: - if self._session.allowed_shapes: + if "batch_index" in self._lang_session.binding_index_map: + if self._lang_session.allowed_shapes: full_batch_size, _ = [ - x[self._session.binding_index_map["batch_index"]][1][0] for x in self._session.allowed_shapes + x[self._lang_session.binding_index_map["batch_index"]][1][0] + for x in self._lang_session.allowed_shapes ] else: - full_batch_size, _ = self._session.bindings[self._session.binding_index_map["batch_index"]].dims + full_batch_size, _ = self._lang_session.bindings[ + self._lang_session.binding_index_map["batch_index"] + ].dims return full_batch_size def _fetch_batch_size_prefill_seq_len( @@ -521,15 +582,17 @@ def _fetch_batch_size_prefill_seq_len( batch_size: The batch size fetched from the session's bindings or allowed shapes. prefill_seq_len: The prefill sequence length fetched from the session's bindings or allowed shapes. 
""" - if self._session.allowed_shapes: + if self._lang_session.allowed_shapes: batch_size = max( - [x[self._session.binding_index_map["input_ids"]][1][0] for x in self._session.allowed_shapes] + [x[self._lang_session.binding_index_map["input_ids"]][1][0] for x in self._lang_session.allowed_shapes] ) prefill_seq_len = max( - [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] + [x[self._lang_session.binding_index_map["input_ids"]][1][1] for x in self._lang_session.allowed_shapes] ) else: - batch_size, prefill_seq_len = self._session.bindings[self._session.binding_index_map["input_ids"]].dims + batch_size, prefill_seq_len = self._lang_session.bindings[ + self._lang_session.binding_index_map["input_ids"] + ].dims return batch_size, prefill_seq_len def _fetch_decode_seq_len( @@ -542,9 +605,9 @@ def _fetch_decode_seq_len( decode_seq_len: The decode sequence length fetched from the session's bindings or allowed shapes. """ decode_seq_len = None - if self._session.allowed_shapes: + if self._lang_session.allowed_shapes: decode_seq_len = min( - [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] + [x[self._lang_session.binding_index_map["input_ids"]][1][1] for x in self._lang_session.allowed_shapes] ) return decode_seq_len @@ -563,10 +626,10 @@ def _fetch_vocab_size( if self.include_sampler else "logits" ) - if self._session.allowed_shapes: - return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2] + if self._lang_session.allowed_shapes: + return [x[self._lang_session.binding_index_map[key]] for x in self._lang_session.allowed_shapes][0][1][2] - return self._session.bindings[self._session.binding_index_map[key]].dims[2] + return self._lang_session.bindings[self._lang_session.binding_index_map[key]].dims[2] def _fetch_generation_len(self, generation_len, max_gen_len): """ @@ -702,7 +765,7 @@ def update_decode_input(self, outputs, position_ids, generation_len, decode_batc self.generation_len[decode_batch_id or slice(None)] = generation_len return next_token_id - def run_prefill_for_all_inputs(self, prompt_queue, generation_len): + def run_prefill_for_all_inputs(self, image_queue, prompt_queue, generation_len): """ Runs prefill for all inputs in the prompt queue and updates the decode input. @@ -713,12 +776,20 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len): generation_len (int): The generation length. 
""" + next_prompt = None + next_image = None for decode_batch_id in range(self.full_batch_size): - next_prompt = prompt_queue.popleft() + if prompt_queue: + next_prompt = prompt_queue.popleft() + if image_queue: + next_image = image_queue.popleft() # run prefill for num_chunks outputs, position_ids, generation_len = self.run_prefill( - next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) + next_prompt, + next_image, + generation_len, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), ) _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) @@ -733,14 +804,45 @@ def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1): if self.include_sampler: if self.return_pdfs: probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) - self._session.set_buffers({"probs": probs_out_placeholder}) + self._lang_session.set_buffers({"probs": probs_out_placeholder}) next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64) - self._session.set_buffers({"next_tokens": next_tokens_out_placeholder}) + self._lang_session.set_buffers({"next_tokens": next_tokens_out_placeholder}) else: logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32) - self._session.set_buffers({"logits": logits_out_placeholder}) + self._lang_session.set_buffers({"logits": logits_out_placeholder}) + + if self._vision_session: + vision_embeds_out_placeholder = np.zeros((2448, 5120), dtype=np.float16) + self._vision_session.set_buffers({"vision_embeds": vision_embeds_out_placeholder}) + + def prepare_vision_language_inputs(self, prompt, image_url): + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": prompt}, + ], + }, + ] + inputs = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + return inputs - def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + def run_prefill( + self, + prompt: str, + image: Optional[str] = None, + generation_len: Optional[int] = None, + prefill_logit_bs=1, + decode_batch_id=None, + ): """ Runs prefill for a given prompt and generation length. @@ -757,8 +859,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i position_ids (array): The position IDs. generation_len (int): The generation length. 
""" + # Run prefill - inputs = self.tokenizer(prompt, return_tensors="np", padding=True) + if image: + inputs = self.prepare_vision_language_inputs(prompt, image) + else: + inputs = self.tokenizer(prompt, return_tensors="np", padding=True) + position_ids = inputs["attention_mask"].sum(1, keepdims=True) padded_len = inputs["input_ids"].shape[1] num_chunks = -(padded_len // -self._prefill_seq_len) # ceil divide without float @@ -772,44 +879,108 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i # Set the prefill output buffers self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) - inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) - inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) - inputs.pop("token_type_ids", None) + vision_inputs = {} + vision_outputs = {} + if image: + pad_token_id = 1 + input_ids_length = inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -self._prefill_seq_len) # ceil divide without float + padded_len = num_chunks * self._prefill_seq_len # Convert to a multiple of prompt_len + + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + pad_token_id, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + for k, v in inputs.items(): + inputs[k] = np.array(v) + + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + } + if vision_inputs: + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + + # Run vision prefill + if vision_inputs: + self._vision_session.activate() + vision_outputs = self._vision_session.run(vision_inputs) + self._vision_session.deactivate() + else: + inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs.pop("token_type_ids", None) + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + + # not_mllama = hasattr(self.model.config, "model_type") and self.model.config.model_type != "mllama" + # if not_mllama: + if image: + lang_inputs["image_idx"] = np.array([[0]]) + + self._lang_session.activate() + self._lang_session.set_buffers(vision_outputs) if decode_batch_id is not None: - inputs["batch_index"] = decode_batch_id + lang_inputs["batch_index"] = decode_batch_id if self.is_tlm: - inputs["num_logits_to_keep"] = np.zeros((1, 1)) + lang_inputs["num_logits_to_keep"] = np.zeros((1, 1)) if self.include_sampler: - inputs["last_accepted_output_tokens"] = inputs["input_ids"] + lang_inputs["last_accepted_output_tokens"] = lang_inputs["input_ids"] for op in Constants.SAMPLER_OPS: if decode_batch_id is not None: - inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] + lang_inputs[op] = self.sampling_params[op][decode_batch_id.flatten()] else: - inputs[op] = self.sampling_params[op] + lang_inputs[op] = self.sampling_params[op] if self._prompt_to_lora_id_mapping_prefill: if self.full_batch_size: - inputs["lora_ids"] = np.array( + 
lang_inputs["lora_ids"] = np.array( self._prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64 ).reshape(1, 1) else: batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] - inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + lang_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + + # Run language prefill for i in range(num_chunks): - chunk_inputs = inputs.copy() - chunk_inputs["input_ids"] = inputs["input_ids"][ + chunk_inputs = lang_inputs.copy() + chunk_inputs["input_ids"] = lang_inputs["input_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] - chunk_inputs["position_ids"] = inputs["position_ids"][ + chunk_inputs["position_ids"] = lang_inputs["position_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len ] if self.include_sampler: chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"] - outputs = self._session.run(chunk_inputs) + outputs = self._lang_session.run(chunk_inputs) + if image: + chunk_inputs["image_idx"] = outputs["image_idx_output"] if self._write_io_dir is not None: write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + # Skip inputs/outputs again + self._lang_session.skip_buffers( + [ + x + for x in self._lang_session.input_names + self._lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + self._lang_session.deactivate() + return ( outputs, position_ids, @@ -848,7 +1019,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): decode_inputs = self.prepare_decode_inputs() while prompt_queue or current_decode_ongoing.any(): - outputs = self._session.run(decode_inputs) + self._lang_session.activate() + outputs = self._lang_session.run(decode_inputs) # Prepare inputs for next iteration next_token_id = self._fetch_next_token_id(outputs) @@ -862,8 +1034,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): start = perf_counter() # run prefill for next prompt input. 
outputs, position_ids, generation_len = self.run_prefill( - prompt_queue.popleft(), - generation_len, + prompt=prompt_queue.popleft(), + generation_len=generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), ) @@ -898,6 +1070,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): generated_id_current_index[decode_batch_id] += 1 + self._lang_session.deactivate() + return decode_pause_time def run_decode( @@ -919,13 +1093,14 @@ def run_decode( logits_out_placeholder = np.zeros( (self.batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32 ) - self._session.set_buffers({"logits": logits_out_placeholder}) + self._lang_session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 + self._lang_session.activate() for num_token in range(1, generation_len): if streamer: streamer.put(decode_inputs["input_ids"][0]) - outputs = self._session.run(decode_inputs) + outputs = self._lang_session.run(decode_inputs) if self._write_io_dir is not None: write_io_files(decode_inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) @@ -941,6 +1116,7 @@ def run_decode( if finished_sequences.all() and not automation: break + self._lang_session.deactivate() return num_token def generate_decode_stream(self, decode_inputs, generation_len, automation): @@ -957,9 +1133,10 @@ def generate_decode_stream(self, decode_inputs, generation_len, automation): token_id (int): The token generated in the decoding process. """ finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id + self._lang_session.activate() for num_token in range(1, generation_len): yield decode_inputs["input_ids"] - outputs = self._session.run(decode_inputs) + outputs = self._lang_session.run(decode_inputs) if self._write_io_dir is not None: write_io_files(decode_inputs, outputs, self._write_io_dir, "decode", "aic_batch_io", True, False) @@ -973,6 +1150,7 @@ def generate_decode_stream(self, decode_inputs, generation_len, automation): if finished_sequences.all() and not automation: break + self._lang_session.deactivate() yield decode_inputs["input_ids"] # yield the last token @@ -980,7 +1158,9 @@ class TextGeneration: def __init__( self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - qpc_path: str, + lang_qpc_path: str, + processor: Optional[AutoImageProcessor] = None, + vision_qpc_path: Optional[str] = None, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ -993,7 +1173,9 @@ def __init__( ) -> None: self._qaic_model = QEffTextGenerationBase( tokenizer=tokenizer, - qpc_path=qpc_path, + lang_qpc_path=lang_qpc_path, + processor=processor, + vision_qpc_path=vision_qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, device_id=device_id, @@ -1006,9 +1188,11 @@ def __init__( ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer + self._processor = self._qaic_model.processor self._ctx_len = ctx_len self._perf_metrics = None self._prompt_queue = None + self._image_queue = None self._text_streamer = None @property @@ -1018,6 +1202,7 @@ def perf_metrics(self): def _setup_model_execution_inputs( self, prompt: List[str], + images: Optional[List[str]] = None, generation_len: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, ): @@ -1035,6 +1220,8 @@ def _setup_model_execution_inputs( # Create a prompt queue. 
self._prompt_queue = deque(prompt) + if images: + self._image_queue = deque(images) # Initialize np arrays for storing the prefill output for all the decode batch size. num_prompts = len(self._prompt_queue) @@ -1064,12 +1251,14 @@ def _regular_model_execution( :tuple: A tuple containing performance metrics and generated texts. """ - self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) + self._setup_model_execution_inputs( + prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping + ) if stream and self._text_streamer is None: self._text_streamer = transformers.TextStreamer(self._tokenizer) start = perf_counter() outputs, position_ids, generation_len = self._qaic_model.run_prefill( - prompt, generation_len, prefill_logit_bs=self._qaic_model.batch_size + prompt=prompt, generation_len=generation_len, prefill_logit_bs=self._qaic_model.batch_size ) self._qaic_model.update_decode_input(outputs, position_ids, generation_len) @@ -1090,6 +1279,7 @@ def _regular_model_execution( def _continuous_batching_execution( self, prompt: List[str], + images: Optional[List[str]] = None, generation_len: Optional[int] = None, prompt_to_lora_id_mapping: Optional[List[int]] = None, ): @@ -1105,10 +1295,10 @@ def _continuous_batching_execution( Returns: :tuple: A tuple containing performance metrics and generated texts. """ - self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) + self._setup_model_execution_inputs(prompt, images, generation_len, prompt_to_lora_id_mapping) self._qaic_model.batch_index = np.arange(self._full_batch_size).reshape(-1, 1) start = perf_counter() - self._qaic_model.run_prefill_for_all_inputs(self._prompt_queue, generation_len) + self._qaic_model.run_prefill_for_all_inputs(self._image_queue, self._prompt_queue, generation_len) loop_start = perf_counter() # Start decode loop timer decode_pause_time = self._qaic_model.run_continuous_batching_decode(self._prompt_queue, generation_len) @@ -1152,7 +1342,7 @@ def generate_stream_tokens( self._setup_model_execution_inputs(prompt, generation_len, prompt_to_lora_id_mapping) start = perf_counter() outputs, position_ids, generation_len = self._qaic_model.run_prefill( - prompt, generation_len, prefill_logit_bs=self._qaic_model.batch_size + prompt=prompt, generation_len=generation_len, prefill_logit_bs=self._qaic_model.batch_size ) self._qaic_model.update_decode_input(outputs, position_ids, generation_len) @@ -1177,6 +1367,7 @@ def generate_stream_tokens( def generate( self, prompt: List[str], + images: Optional[List[str]] = None, generation_len: Optional[int] = None, stream: bool = True, automation: Optional[bool] = False, @@ -1197,7 +1388,7 @@ def generate( if self._full_batch_size is not None: logger.warning("Streamer is currently unavailable for continuous batch execution.") perf_metrics, generated_texts = self._continuous_batching_execution( - prompt, generation_len, prompt_to_lora_id_mapping + prompt, images, generation_len, prompt_to_lora_id_mapping ) else: if stream: diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 212fe16ae..b7b951101 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -820,7 +820,7 @@ def forward(self, pixel_values): ) vision_flat = image_features.view(-1, image_features.size(-1)) projected_vision_flat = self.model.multi_modal_projector(vision_flat) - 
return projected_vision_flat + return projected_vision_flat # , pixel_values # This wrapper utilizes the 'vision_embeds', which contains vision embeddings, and an 'image_idx' index starting at 0. @@ -836,7 +836,15 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + batch_index: Optional[torch.LongTensor] = None, + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -846,7 +854,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -893,6 +905,9 @@ def get_specializations( ctx_len: int, img_size: int, kv_offload: bool = False, + continuous_batching: bool = False, + kv_cache_batch_size: Optional[int] = None, + full_batch_size: Optional[int] = None, **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) @@ -941,28 +956,42 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - ] + + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -971,18 +1000,22 @@ def get_specializations( specializations["lang"] = lang return specializations, compiler_options else: + lang[0].pop("vision_size") + lang[1].pop("vision_size") return 
lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["vision_embeds"] = {0: "vision_size"} + if continuous_batching: + lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "max_num_tiles", 2: "img_size", 3: "img_size"} - pkv_dynamic_axes = {0: "batch_size"} + pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size"} for i in range(self.language_model.config.num_hidden_layers): # switch between chunk_ctx_len and ctx_len for RoPE and NoPE layers. if int((i + 1) % 4 != 0): @@ -1011,6 +1044,7 @@ def get_output_names(self, kv_offload: bool = False): output_names = {} if kv_offload: + # vision_output_names.insert(1, "pixel_values_RetainedState") lang_output_names.insert(1, "vision_embeds_RetainedState") lang_output_names.insert(2, "image_idx_output") output_names["vision"] = vision_output_names @@ -1045,7 +1079,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1090,10 +1124,14 @@ def get_dummy_inputs(self, kv_offload: bool = False): .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) ) lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + # Add data for KV past_key_values = self.get_dummy_pkv_cache( config=self.language_model.config, - batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + batch_size=fbs if continuous_batching else bs, seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) @@ -1102,6 +1140,8 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index cdacc7760..1fd3f58f7 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn from transformers import ( + AutoImageProcessor, AutoModel, AutoModelForCausalLM, AutoModelForImageTextToText, @@ -851,6 +852,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + continuous_batching, **kwargs, ): """ @@ -874,6 +876,7 @@ def __init__( self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None @property @@ -973,8 +976,8 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. 
""" - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + inputs = self.model.get_dummy_inputs(kv_offload=True, continuous_batching=self.continuous_batching) + dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True, continuous_batching=self.continuous_batching) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1063,14 +1066,20 @@ def compile( If `full_batch_size`, `kv_cache_batch_size`, or `num_speculative_tokens` are not None. If both `skip_lang` and `skip_vision` are True. """ - if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]): + if skip_lang and skip_vision: + raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + + if kv_cache_batch_size and not full_batch_size: raise ValueError( - f"Expected 'full_batch_size', 'kv_cache_batch_size', 'num_speculative_tokens' to be None but got: " - f"full_batch_size={full_batch_size}, kv_cache_batch_size={kv_cache_batch_size}, num_speculative_tokens={num_speculative_tokens}, " + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." ) - if skip_lang and skip_vision: - raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False") + # Infer kv_cache_batch_size if not provided + kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size output_names = self.model.get_output_names(kv_offload=True) @@ -1080,6 +1089,9 @@ def compile( ctx_len=ctx_len, img_size=img_size, kv_offload=True, + continuous_batching=self.continuous_batching, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, **compiler_options, ) @@ -1147,7 +1159,11 @@ def compile( def generate( self, - inputs: torch.Tensor, + inputs: Optional[torch.Tensor] = None, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer] = None, + processor: Optional[AutoImageProcessor] = None, + images: List[str] = None, + prompts: List[str] = None, streamer: Optional[TextStreamer] = None, device_ids: List[int] = None, runtime_ai100: bool = True, @@ -1187,6 +1203,17 @@ def generate( if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + if (processor and images) or (tokenizer and prompts): + return QEfficient.cloud_ai_100_exec_kv( + tokenizer=tokenizer, + processor=processor, + lang_qpc_path=self.lang_model.qpc_path, + vision_qpc_path=self.vision_model.qpc_path, + images=images, + prompt=prompts, + device_id=device_ids, + generation_len=generation_len, + ) return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) @@ -1314,9 +1341,7 @@ def kv_offload_generate( lang_session.set_buffers(vision_outputs) - # Prepare inputs for prefill - chunk_inputs = lang_inputs.copy() - prefill_start = perf_counter() + lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() @@ -1328,7 +1353,7 @@ def kv_offload_generate( outputs = lang_session.run(chunk_inputs) chunk_inputs["image_idx"] = outputs["image_idx_output"] - prefill_time = perf_counter() - prefill_start + vision_end - vision_start + prefill_time = perf_counter() - lang_start + vision_end - vision_start # Skip inputs/outputs again 
         lang_session.skip_buffers(
             [
@@ -1909,7 +1934,7 @@ class QEFFAutoModelForImageTextToText:
 
     _hf_auto_class = AutoModelForImageTextToText
 
-    def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs):
+    def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, continuous_batching: bool = False, **kwargs):
         """
         Instantiate the appropriate internal class for single or dual QPC mode.
 
@@ -1930,13 +1955,19 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs)
            The wrapped model instance, configured for either dual or single QPC.
         """
         if kv_offload:
-            return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs)
+            return _QEffAutoModelForImageTextToTextDualQPC(model, continuous_batching, **kwargs)
         else:
             return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs)
 
     @classmethod
     @with_replaced_quantizers
-    def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs):
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str,
+        kv_offload: Optional[bool] = None,
+        continuous_batching: bool = False,
+        **kwargs,
+    ):
         """
         Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path.
 
@@ -1971,12 +2002,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
         if kwargs.get("low_cpu_mem_usage", None):
             logger.warning("Updating low_cpu_mem_usage=False")
 
-        if kwargs.pop("continuous_batching", None):
-            NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+        if continuous_batching and not kv_offload:
+            raise NotImplementedError("Continuous batching is not supported with kv_offload=False.")
 
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+        return cls(
+            model,
+            kv_offload=kv_offload,
+            continuous_batching=continuous_batching,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            **kwargs,
+        )
 
 
 MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText}
@@ -2681,8 +2718,8 @@ def generate(
             raise TypeError("Please run compile API first!")
         generation_len = kwargs.pop("generation_len", None)
         return QEfficient.cloud_ai_100_exec_kv(
-            tokenizer,
-            self.qpc_path,
+            tokenizer=tokenizer,
+            lang_qpc_path=self.qpc_path,
             prompt=prompts,
             device_id=device_id,
             generation_len=generation_len,
diff --git a/examples/llama4_CB_example_vision_lang.py b/examples/llama4_CB_example_vision_lang.py
new file mode 100644
index 000000000..ebe65bf82
--- /dev/null
+++ b/examples/llama4_CB_example_vision_lang.py
@@ -0,0 +1,65 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import transformers
+from transformers import AutoConfig, AutoProcessor
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+config = AutoConfig.from_pretrained(model_id)
+# For Testing Purpose Only
+config.text_config.num_hidden_layers = 4
+config.vision_config.num_hidden_layers = 2
+
+qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+    model_id,
+    attn_implementation="eager",
+    kv_offload=True,
+    config=config,
+    continuous_batching=True,
+)
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+processor = AutoProcessor.from_pretrained(model_id)
+
+qeff_model.compile(
+    prefill_seq_len=128,
+    ctx_len=3072,
+    img_size=336,
+    num_cores=16,
+    num_devices=4,
+    max_num_tiles=17,
+    batch_size=1,
+    full_batch_size=4,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+    aic_enable_depth_first=True,
+    mos=1,
+)
+
+image_urls = [
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg",
+]
+
+prompts = [
+    "Can you describe the image in detail?",
+    "What are the objects in the image?",
+    "What is the main subject of the image?",
+    "What colors are predominant in the image?",
+]
+
+output = qeff_model.generate(
+    tokenizer=tokenizer,
+    prompts=prompts,
+    processor=processor,
+    images=image_urls,
+    device_ids=[0, 1, 2, 3],
+    generation_len=100,
+)
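
A minimal sketch of inspecting the example's result, not part of the patch itself; it assumes `generate` in this path returns the `CloudAI100ExecInfo` produced by `QEfficient.cloud_ai_100_exec_kv` (as wired up in modeling_auto.py above):

# Illustrative only; assumes `output` is a CloudAI100ExecInfo with
# `generated_texts` and `perf_metrics` fields, as returned by
# QEfficient.cloud_ai_100_exec_kv above.
print(output.generated_texts)  # decoded text for each prompt
print(output.perf_metrics)     # prefill_time, decode_perf, total_perf, total_time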