diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index c4f5a7bbd..269ccb0be 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
- ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+ data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+ # Build the target index shape (batch, num_heads, comp_ctx_len) from the first two dims of data and comp_ctx_len
+ shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+ # Expand ctx_indices to the compute-context-length (CCL) window
+ ctx_indices = ops.Expand(ctx_indices, shape_tensor)
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
return ops.GatherND(data, ctx_indices, batch_dims=2)
@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
"""
@staticmethod
- def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+ def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
pass
@staticmethod
- def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
- return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
+ return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
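
For reference, a minimal eager-mode sketch of what the patched CtxGatherFunc now computes: gather only the first comp_ctx_len cache positions per batch and head. The shapes and the comp_ctx_len value below are illustrative, not taken from the patch.

    import torch

    batch, heads, ctx_len, head_dim = 1, 2, 8, 4
    comp_ctx_len = 4  # hypothetical compute-context-length (CCL), <= ctx_len

    data = torch.randn(batch, heads, ctx_len, head_dim)             # full KV cache
    ctx_indices = torch.arange(ctx_len)[None, None, :comp_ctx_len]  # indices limited to the CCL window

    batch_indices = torch.arange(batch).view(-1, 1, 1)
    head_indices = torch.arange(heads).view(1, -1, 1)
    gathered = data[batch_indices, head_indices, ctx_indices]

    assert gathered.shape == (batch, heads, comp_ctx_len, head_dim)

The exported CtxGather op builds the same (batch, num_heads, comp_ctx_len) index shape with Concat/Expand instead of relying on broadcasting.
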
diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py
index 75d9a12ef..cc9693716 100644
--- a/QEfficient/customop/ctx_scatter_gather_cb.py
+++ b/QEfficient/customop/ctx_scatter_gather_cb.py
@@ -97,16 +97,20 @@ def symbolic(
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherCB(
- data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+ data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
batch_size = ops.Gather(ops.Shape(batch_index), [0])
num_heads = ops.Gather(ops.Shape(data), [1])
- ctx_len = ops.Gather(ops.Shape(data), [2])
+ # Use the compute-context-length (CCL) instead of the full context length, so the gather and the subsequent attention computation both operate on the CCL window.
+ ctx_len = ops.Reshape(comp_ctx_len, [1])
# Expanded shape to create indices
zero = ops.Constant(value_ints=[0])
one = ops.Constant(value_ints=[1])
- exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+ exp_shape = ops.Concat(
+ ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
+ )
# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
@@ -119,7 +123,7 @@ def CtxGatherCB(
class CtxGatherFuncCB(torch.autograd.Function):
@staticmethod
- def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+ def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
pass
@staticmethod
- def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
- return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
+ def symbolic(
+ g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
+ ) -> torch.Value:
+ return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index cf9cbcacc..1ed3f793b 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
prompts_txt_file_path: Optional[str] = None,
device_id: Optional[List[int]] = None,
generation_len: Optional[int] = None,
+ comp_ctx_lengths_prefill: Optional[List[int]] = None,
+ comp_ctx_lengths_decode: Optional[List[int]] = None,
enable_debug_logs: bool = False,
stream: bool = True,
write_io_dir: Optional[str] = None,
@@ -382,6 +384,8 @@ def cloud_ai_100_exec_kv(
qpc_path=qpc_path,
device_id=device_id,
ctx_len=ctx_len,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
@@ -424,6 +428,8 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
+ comp_ctx_lengths_prefill: Optional[List[int]] = None,
+ comp_ctx_lengths_decode: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
@@ -433,6 +439,8 @@ def __init__(
sampling_params: Optional[Dict[str, Any]] = None,
) -> None:
self._ctx_len = ctx_len
+ self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+ self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm
self.return_pdfs = return_pdfs
@@ -791,7 +799,16 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
+ if self.comp_ctx_lengths_prefill is not None:
+ # Zero-filled buffers whose length encodes each CCL value; only the buffer's shape is consumed downstream
+ self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+ prefill_ccl_id = 0
+ inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
for i in range(num_chunks):
+ if self.comp_ctx_lengths_prefill is not None and (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
+ prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+ inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][
:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
@@ -810,6 +827,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
generation_len,
)
+ def initialize_ccl(self, decode_inputs):
+ self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+ max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
+ max_position_id = np.max(decode_inputs["position_ids"])
+ ccl_id_initial = 0
+ ccl_id = ccl_id_initial
+ for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+ if max_position_id < self.comp_ctx_lengths_decode[i]:
+ ccl_id = i
+ break
+
+ return ccl_id, max_ccl_id
+
def run_continuous_batching_decode(self, prompt_queue, generation_len):
"""
Runs continuous batching decode for the given prompt queue and generation length.
@@ -841,6 +871,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
# Prepare decode inputs inputs.
decode_inputs = self.prepare_decode_inputs()
+ if self.comp_ctx_lengths_decode is not None:
+ ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+ decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
while prompt_queue or current_decode_ongoing.any():
outputs = self._session.run(decode_inputs)
@@ -878,6 +912,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
batch_id_map[decode_batch_id]
]
+ if self.comp_ctx_lengths_decode is not None:
+ # Recalculate ccl_id based on position_ids
+ # Determine the maximum value of position_ids across all batch elements
+ max_position_id = np.max(decode_inputs["position_ids"])
+
+ # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
+ ccl_id_initial = 0
+ ccl_id = ccl_id_initial
+ for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+ if max_position_id < self.comp_ctx_lengths_decode[i]:
+ ccl_id = i
+ break
+ decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
else:
current_decode_ongoing[decode_batch_id] = False
else:
@@ -890,6 +938,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
if self.include_sampler:
decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
+ if self.comp_ctx_lengths_decode is not None:
+ # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
+ if (
+ decode_inputs["position_ids"][decode_batch_id, -1]
+ >= self.comp_ctx_lengths_decode[ccl_id] - 1
+ ):
+ ccl_id = min(ccl_id + 1, max_ccl_id)
+ decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
generated_id_current_index[decode_batch_id] += 1
return decode_pause_time
@@ -914,7 +971,18 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
self._session.set_buffers({"logits": logits_out_placeholder})
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
num_token = 0
+
+ if self.comp_ctx_lengths_decode is not None:
+ ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+ decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
+ cache_index = np.max(decode_inputs["position_ids"])
for num_token in range(1, generation_len):
+ if self.comp_ctx_lengths_decode is not None:
+ if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1:
+ ccl_id = min(ccl_id + 1, max_ccl_id)
+ decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
if streamer:
streamer.put(decode_inputs["input_ids"][0])
outputs = self._session.run(decode_inputs)
@@ -926,6 +994,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
# Prepare inputs for next iteration
decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
decode_inputs["position_ids"][:, -1] += 1
+ cache_index += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
if self.include_sampler:
@@ -975,6 +1044,8 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
+ comp_ctx_lengths_prefill: Optional[List[int]] = None,
+ comp_ctx_lengths_decode: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
@@ -988,6 +1059,8 @@ def __init__(
qpc_path=qpc_path,
full_batch_size=full_batch_size,
ctx_len=ctx_len,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
device_id=device_id,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
@@ -999,6 +1072,8 @@ def __init__(
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
self._ctx_len = ctx_len
+ self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+ self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
self._perf_metrics = None
self._prompt_queue = None
self._text_streamer = None
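
The decode-side bucket selection added to this file can be summarised by the standalone sketch below; the bucket values are hypothetical and the helper name pick_initial_ccl is ours, it simply mirrors initialize_ccl plus the in-loop promotion to the next bucket.

    comp_ctx_lengths_decode = [512, 1024, 2048]         # hypothetical CCL buckets
    max_ccl_id = len(comp_ctx_lengths_decode) - 1

    def pick_initial_ccl(max_position_id: int) -> int:
        # Smallest bucket that still covers the current maximum position id
        ccl_id = 0
        for i, ccl in enumerate(comp_ctx_lengths_decode):
            if max_position_id < ccl:
                ccl_id = i
                break
        return ccl_id

    ccl_id = pick_initial_ccl(max_position_id=600)       # -> 1: the 1024 bucket covers position 600
    for cache_index in range(600, 1100):                  # simulated decode steps
        if cache_index >= comp_ctx_lengths_decode[ccl_id] - 1:
            ccl_id = min(ccl_id + 1, max_ccl_id)           # promote to the next bucket at its boundary
    assert ccl_id == 2                                     # the 2048 specialization would now be active
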
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 16767fbe2..ed9f9b142 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -91,6 +91,8 @@ def read_only(self, layer_idx, cache_kwargs):
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)
+ comp_ctx_len = cache_kwargs.get("CCL")
+
ctx_len = k_out.shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
@@ -101,15 +103,19 @@ def read_only(self, layer_idx, cache_kwargs):
else:
invalid_idx_value = 0
+ # Restrict the gather window (and the mask derived from it) to the compute-context-length
+ ctx_indices = ctx_indices[:, :, :comp_ctx_len]
+ invalid_mask = ctx_indices > gather_limit
+
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
if batch_index is not None:
- k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
- v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+ k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
+ v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
else:
- k_out = CtxGatherFunc.apply(k_out, ctx_indices)
- v_out = CtxGatherFunc.apply(v_out, ctx_indices)
-
+ k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len)
+ v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out
@@ -144,6 +150,7 @@ def update(
else:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs
+ comp_ctx_len = cache_kwargs.get("CCL")
# Scatter
if batch_index is not None:
@@ -166,23 +173,28 @@ def update(
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
# Gather
- ctx_len = k_out.shape[2]
+ ctx_len = self.key_cache[layer_idx].shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
- invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
+ # Restrict the gather window (and the mask derived from it) to the compute-context-length
+ ctx_indices = ctx_indices[:, :, :comp_ctx_len]
+ invalid_mask = ctx_indices > gather_limit
+
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
if batch_index is not None:
- k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
- v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+ k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
+ v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
else:
- k_out = CtxGatherFunc.apply(k_out, ctx_indices)
- v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+ k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len)
+ v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out
@@ -344,6 +356,8 @@ def update(
else:
position_ids = cache_kwargs.get("position_ids")
sliding_window_pattern = cache_kwargs.get("sliding_window_pattern")
+ comp_ctx_len = cache_kwargs.get("CCL")
+
is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern))
layer_ctx_len = self.key_cache[layer_idx].shape[2]
kv_position_ids = torch.where(
@@ -369,20 +383,26 @@ def update(
ctx_len = self.key_cache[layer_idx].shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
- invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
+
+ # Restrict the gather window (and the mask derived from it) to the compute-context-length
+ ctx_indices = ctx_indices[:, :, :comp_ctx_len]
+ invalid_mask = ctx_indices > gather_limit
+
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices)
+ rolling_indices = rolling_indices[:comp_ctx_len]
final_indices = torch.where(
(is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices
)
- k_out = CtxGatherFunc.apply(k_out, final_indices)
- v_out = CtxGatherFunc.apply(v_out, final_indices)
+ k_out = CtxGatherFunc.apply(k_out, final_indices, comp_ctx_len)
+ v_out = CtxGatherFunc.apply(v_out, final_indices, comp_ctx_len)
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
return k_out, v_out
@@ -443,6 +463,8 @@ def update(
else:
position_ids = cache_kwargs.get("position_ids")
+ comp_ctx_len = cache_kwargs.get("CCL")
+
is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx]))
# Update the position_ids to handle the sliding window
@@ -470,21 +492,27 @@ def update(
ctx_len = min(layer_ctx_len, k_out.shape[2])
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
- invalid_mask = ctx_indices > gather_limit
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
+
+ # Restrict the gather window (and the mask derived from it) to the compute-context-length
+ ctx_indices = ctx_indices[:, :, :comp_ctx_len]
+ invalid_mask = ctx_indices > gather_limit
+
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
# Rolling indices for sliding window
all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices)
+ rolling_indices = rolling_indices[:comp_ctx_len]
final_indices = torch.where(
(is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices
)
- k_out = CtxGatherFunc.apply(k_out, final_indices)
- v_out = CtxGatherFunc.apply(v_out, final_indices)
+ k_out = CtxGatherFunc.apply(k_out, final_indices, comp_ctx_len)
+ v_out = CtxGatherFunc.apply(v_out, final_indices, comp_ctx_len)
ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
return k_out, v_out
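
A toy eager reproduction of the CCL-windowed gather used in update()/read_only() above, with made-up sizes: indices past the current fill level are redirected to slot 0 (the non-export branch; export uses an int32 sentinel instead) and the corresponding value rows are zeroed, so attention over the shortened window still ignores unwritten cache entries.

    import torch

    ctx_len, comp_ctx_len, head_dim = 8, 4, 2
    k_cache = torch.randn(1, 1, ctx_len, head_dim)
    v_cache = torch.randn(1, 1, ctx_len, head_dim)
    position_ids = torch.tensor([[2]])                       # cache filled up to position 2

    ctx_indices = torch.arange(ctx_len)[None, None, :comp_ctx_len]
    gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
    invalid_mask = ctx_indices > gather_limit                # True for window positions beyond the fill level
    ctx_indices = torch.where(invalid_mask, 0, ctx_indices)

    b = torch.arange(1).view(-1, 1, 1)
    h = torch.arange(1).view(1, -1, 1)
    k_out = k_cache[b, h, ctx_indices]
    v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0), v_cache[b, h, ctx_indices])
    assert k_out.shape == v_out.shape == (1, 1, comp_ctx_len, head_dim)
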
diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py
index 0cefbcfee..552c44442 100644
--- a/QEfficient/transformers/models/gemma/modeling_gemma.py
+++ b/QEfficient/transformers/models/gemma/modeling_gemma.py
@@ -138,6 +138,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
@@ -156,8 +157,16 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
@@ -190,6 +199,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -222,6 +232,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -260,6 +271,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -321,6 +333,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -363,6 +376,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -386,6 +400,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
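
The same two-line hook recurs in every model file below: when comp_ctx_lengths is provided, the causal mask is cut down to the CCL window and the window size is handed to the cache through cache_kwargs["CCL"]. A toy illustration (values made up; only the tensor's length is used):

    import torch

    ctx_len = 8
    comp_ctx_lengths = torch.zeros(4, dtype=torch.long)    # dummy input; only its length (CCL = 4) matters
    attention_mask = torch.zeros(1, 1, 1, ctx_len, dtype=torch.bool)

    attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
    cache_kwargs = {"CCL": attention_mask.shape[-1]}        # -> 4, consumed by the cache update above
    assert cache_kwargs["CCL"] == 4
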
diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
index 173da1798..0fdf397cc 100644
--- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py
+++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
@@ -144,6 +144,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
@@ -162,8 +163,16 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
@@ -196,6 +205,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -228,6 +238,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -268,6 +279,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -340,6 +352,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -383,6 +396,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -406,6 +420,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
index 851bb9436..ff76b73ed 100644
--- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
+++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -6,7 +6,7 @@
# -----------------------------------------------------------------------------
import copy
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
import torch
from torch import nn
@@ -215,6 +215,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
@@ -247,6 +248,8 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
@@ -255,6 +258,7 @@ def forward(
"position_ids": position_ids,
"is_sliding": self.is_sliding,
"sliding_window_pattern": self.config.sliding_window_pattern,
+ "CCL": attention_mask.shape[-1],
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -299,6 +303,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -325,6 +330,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -365,6 +371,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -445,6 +452,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -482,6 +490,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -541,6 +550,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
@@ -607,7 +617,7 @@ def __init__(self, model):
self.language_model = self.model.language_model
self.config = self.model.config
- def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
B, N, C = inputs_embeds.shape
selected = input_ids == self.model.config.image_token_index
@@ -618,7 +628,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
outputs = self.model.language_model(
- inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ use_cache=True,
)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
return outputs.logits, vision_embeds, image_idx, outputs.past_key_values
@@ -631,7 +645,7 @@ def get_qeff_vision_encoder(self):
def get_qeff_language_decoder(self):
return QEffGemma3DecoderWrapper(self)
- def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values):
+ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths):
image_features = self.get_image_features(pixel_values=pixel_values)
inputs_embeds = self.get_input_embeddings()(input_ids)
B, N, C = inputs_embeds.shape
@@ -643,7 +657,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val
image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
outputs = self.language_model(
- inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ use_cache=True,
)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
return outputs.logits, pixel_values, image_idx, outputs.past_key_values
@@ -654,6 +672,8 @@ def get_specializations(
prefill_seq_len: int,
ctx_len: int,
img_size: int,
+ comp_ctx_lengths_prefill: List[int] = None,
+ comp_ctx_lengths_decode: List[int] = None,
kv_offload: bool = False,
**compiler_options,
):
@@ -674,24 +694,55 @@ def get_specializations(
"ctx_len": ctx_len,
}
]
- lang = [
- {
- "batch_size": batch_size,
- "seq_len": prefill_seq_len,
- "ctx_len": ctx_len,
- "sliding_window": self.language_model.config.sliding_window,
- "img_size": img_size,
- "mm_tokens_per_image": mm_tokens_per_image,
- },
- {
- "batch_size": batch_size,
- "seq_len": "1",
- "ctx_len": ctx_len,
- "sliding_window": self.language_model.config.sliding_window,
- "img_size": img_size,
- "mm_tokens_per_image": mm_tokens_per_image,
- },
- ]
+ if comp_ctx_lengths_prefill is not None:
+ lang = []
+
+ for i in range(0, len(comp_ctx_lengths_prefill)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
+ "sliding_window": self.language_model.config.sliding_window,
+ "img_size": img_size,
+ "mm_tokens_per_image": mm_tokens_per_image,
+ }
+ )
+
+ for i in range(0, len(comp_ctx_lengths_decode)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_decode[i],
+ "sliding_window": self.language_model.config.sliding_window,
+ "img_size": img_size,
+ "mm_tokens_per_image": mm_tokens_per_image,
+ }
+ )
+
+ else:
+ lang = [
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "sliding_window": self.language_model.config.sliding_window,
+ "img_size": img_size,
+ "mm_tokens_per_image": mm_tokens_per_image,
+ },
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "sliding_window": self.language_model.config.sliding_window,
+ "img_size": img_size,
+ "mm_tokens_per_image": mm_tokens_per_image,
+ },
+ ]
+
specializations = {}
if kv_offload:
@@ -701,7 +752,7 @@ def get_specializations(
else:
return lang, compiler_options
- def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+ def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
# Define dynamic axes
vision_dynamic_axes = {}
lang_dynamic_axes = {}
@@ -726,6 +777,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
)
lang_dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes
+ if comp_ctx_lengths is not None:
+ lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
+
dynamic_axes = {}
if kv_offload:
dynamic_axes["vision"] = vision_dynamic_axes
@@ -774,7 +828,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
past_key_values.append(pkv)
return past_key_values
- def get_dummy_inputs(self, kv_offload: bool = False):
+ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", 896)
else:
@@ -820,6 +874,9 @@ def get_dummy_inputs(self, kv_offload: bool = False):
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
+ if comp_ctx_lengths is not None:
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+
inputs = {}
if kv_offload:
inputs["vision"] = vision_inputs
diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py
index 13b308547..2ae94aac8 100644
--- a/QEfficient/transformers/models/granite/modeling_granite.py
+++ b/QEfficient/transformers/models/granite/modeling_granite.py
@@ -129,6 +129,7 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
@@ -147,8 +148,16 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
@@ -173,6 +182,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -228,6 +238,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -269,6 +280,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -321,6 +333,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
index 8f840b4b4..c08ddf2b8 100644
--- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
@@ -122,6 +122,7 @@ def forward(
position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
@@ -139,6 +140,8 @@ def forward(
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
@@ -146,6 +149,7 @@ def forward(
"cache_position": cache_position,
"batch_index": batch_index,
"position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -183,6 +187,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -268,6 +273,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -279,6 +285,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
@@ -477,6 +484,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -530,6 +538,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py
index b6fb9fd38..a1b19ca54 100644
--- a/QEfficient/transformers/models/internvl/modeling_internvl.py
+++ b/QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -5,6 +5,8 @@
#
# -----------------------------------------------------------------------------
+from typing import List
+
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -31,7 +33,7 @@ def __init__(self, model):
self.config = self.model.language_model.config
self.language_model = self.model.language_model
- def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths):
input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
image_input_embeds = input_embeds.reshape(B * N, C)
@@ -44,7 +46,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
outputs = self.model.language_model(
- inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ use_cache=True,
)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
return outputs.logits, vision_embeds, image_idx, outputs.past_key_values
@@ -63,6 +69,8 @@ def get_specializations(
prefill_seq_len: int,
ctx_len: int,
img_size: int,
+ comp_ctx_lengths_prefill: List[int] = None,
+ comp_ctx_lengths_decode: List[int] = None,
kv_offload: bool = False,
**compiler_options,
):
@@ -92,24 +100,54 @@ def get_specializations(
"img_size": img_size,
}
]
- lang = [
- {
- "batch_size": batch_size,
- "seq_len": prefill_seq_len,
- "ctx_len": ctx_len,
- "num_patches": num_patches,
- "img_size": img_size,
- "vision_size": vision_size,
- },
- {
- "batch_size": batch_size,
- "seq_len": "1",
- "ctx_len": ctx_len,
- "num_patches": num_patches,
- "img_size": img_size,
- "vision_size": vision_size,
- },
- ]
+ if comp_ctx_lengths_prefill is not None:
+ lang = []
+
+ for i in range(0, len(comp_ctx_lengths_prefill)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
+ "num_patches": num_patches,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ }
+ )
+
+ for i in range(0, len(comp_ctx_lengths_decode)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_decode[i],
+ "num_patches": num_patches,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ }
+ )
+
+ else:
+ lang = [
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "num_patches": num_patches,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ },
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "num_patches": num_patches,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ },
+ ]
specializations = {}
@@ -120,7 +158,7 @@ def get_specializations(
else:
return lang, compiler_options
- def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+ def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
# Define dynamic axes
vision_dynamic_axes = {}
lang_dynamic_axes = {}
@@ -134,6 +172,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
for kv in ["key", "value"]:
lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+ if comp_ctx_lengths is not None:
+ lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
dynamic_axes = {}
if kv_offload:
dynamic_axes["vision"] = vision_dynamic_axes
@@ -161,7 +201,7 @@ def get_output_names(self, kv_offload: bool = False):
return lang_output_names
return output_names
- def get_dummy_inputs(self, kv_offload: bool = False):
+ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE)
else:
@@ -222,6 +262,9 @@ def get_dummy_inputs(self, kv_offload: bool = False):
for kv in ["key", "value"]:
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
+ if comp_ctx_lengths is not None:
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+
inputs = {}
if kv_offload:
inputs["vision"] = vision_inputs
@@ -232,7 +275,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
return inputs
- def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values):
+ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values, comp_ctx_lengths):
input_embeds = self.language_model.get_input_embeddings()(input_ids)
vision_embeds = self.extract_feature(pixel_values)
B, N, C = input_embeds.shape
@@ -246,7 +289,11 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
outputs = self.language_model(
- inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ use_cache=True,
)
next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx)
diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py
index a285f00dc..a23bed857 100644
--- a/QEfficient/transformers/models/llama/modeling_llama.py
+++ b/QEfficient/transformers/models/llama/modeling_llama.py
@@ -133,6 +133,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
@@ -157,8 +158,16 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
@@ -191,6 +200,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -207,6 +217,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -244,6 +255,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -297,6 +309,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -340,6 +353,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -363,6 +377,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index 4b957ebec..78074a51b 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -464,6 +464,7 @@ def forward(
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
@@ -501,6 +502,8 @@ def forward(
if past_key_value is not None:
chunk_position_ids = position_ids
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
if self.use_rope:
chunk_position_ids = torch.where(
@@ -508,7 +511,11 @@ def forward(
)
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids}
+ cache_kwargs = {
+ "batch_index": batch_index,
+ "position_ids": chunk_position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
@@ -542,6 +549,7 @@ def forward(
chunk_causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
@@ -565,6 +573,7 @@ def forward(
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -618,6 +627,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -682,6 +692,7 @@ def forward(
chunk_causal_mask=chunk_causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -726,6 +737,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -749,6 +761,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
@@ -831,7 +844,7 @@ def __init__(self, model):
self.language_model = self.model.language_model
self.config = self.model.config
- def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths):
inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids)
selected = input_ids == self.model.config.image_token_index
indices1 = selected.to(torch.int64).cumsum(1) - 1
@@ -841,7 +854,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds)
outputs = self.model.language_model(
- inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ use_cache=True,
)
next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
image_idx = torch.where(image_idx < next_idx, next_idx, image_idx)
@@ -855,7 +872,7 @@ def get_qeff_vision_encoder(self):
def get_qeff_language_decoder(self):
return QEffLlama4DecoderWrapper(self)
- def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values):
+ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths):
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
vision_feature_layer = self.config.vision_config.vision_feature_layer
vision_feature_select_strategy = self.config.vision_config.vision_feature_select_strategy
@@ -875,7 +892,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val
image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds)
outputs = self.language_model(
- inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+ inputs_embeds=inputs_embeds,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
+ use_cache=True,
)
next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
image_idx = torch.where(image_idx < next_idx, next_idx, image_idx)
@@ -887,6 +908,8 @@ def get_specializations(
prefill_seq_len: int,
ctx_len: int,
img_size: int,
+ comp_ctx_lengths_prefill: List[int] = None,
+ comp_ctx_lengths_decode: List[int] = None,
kv_offload: bool = False,
**compiler_options,
):
@@ -936,28 +959,62 @@ def get_specializations(
"img_size": img_size,
}
]
- lang = [
- {
- "batch_size": batch_size,
- "seq_len": prefill_seq_len,
- "ctx_len": ctx_len,
- "max_num_tiles": max_num_tiles,
- "img_size": img_size,
- "vision_size": vision_size,
- "chunk_length": prefill_seq_len,
- "chunk_ctx_len": chunk_ctx_len,
- },
- {
- "batch_size": batch_size,
- "seq_len": "1",
- "ctx_len": ctx_len,
- "max_num_tiles": max_num_tiles,
- "img_size": img_size,
- "vision_size": vision_size,
- "chunk_length": prefill_seq_len,
- "chunk_ctx_len": chunk_ctx_len,
- },
- ]
+ if comp_ctx_lengths_prefill is not None:
+ lang = []
+
+ for i in range(0, len(comp_ctx_lengths_prefill)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
+ "max_num_tiles": max_num_tiles,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ "chunk_length": prefill_seq_len,
+ "chunk_ctx_len": chunk_ctx_len,
+ }
+ )
+
+ for i in range(0, len(comp_ctx_lengths_decode)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_decode[i],
+ "max_num_tiles": max_num_tiles,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ "chunk_length": prefill_seq_len,
+ "chunk_ctx_len": chunk_ctx_len,
+ }
+ )
+
+ else:
+ lang = [
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "max_num_tiles": max_num_tiles,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ "chunk_length": prefill_seq_len,
+ "chunk_ctx_len": chunk_ctx_len,
+ },
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "max_num_tiles": max_num_tiles,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ "chunk_length": prefill_seq_len,
+ "chunk_ctx_len": chunk_ctx_len,
+ },
+ ]
specializations = {}
@@ -968,7 +1025,7 @@ def get_specializations(
else:
return lang, compiler_options
- def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+ def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
# Define dynamic axes
vision_dynamic_axes = {}
lang_dynamic_axes = {}
@@ -988,6 +1045,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
for kv in ["key", "value"]:
lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
+ if comp_ctx_lengths is not None:
+ lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
+
dynamic_axes = {}
if kv_offload:
dynamic_axes["vision"] = vision_dynamic_axes
@@ -1040,7 +1100,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len):
past_key_values.append(pkv)
return past_key_values
- def get_dummy_inputs(self, kv_offload: bool = False):
+ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
if vis_cfg := getattr(self.config, "vision_config", None):
img_size = getattr(vis_cfg, "image_size", 336)
else:
@@ -1097,6 +1157,9 @@ def get_dummy_inputs(self, kv_offload: bool = False):
for kv in ["key", "value"]:
lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32))
+ if comp_ctx_lengths is not None:
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+
inputs = {}
if kv_offload:
inputs["vision"] = vision_inputs
diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
index f5e60c5de..ad4f18e11 100644
--- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
+++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
@@ -89,6 +89,7 @@ def forward(
hidden_states: torch.Tensor,
position_ids: torch.LongTensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
attention_mask: torch.Tensor = None,
batch_index: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
@@ -106,8 +107,14 @@ def forward(
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
+ cache_kwargs = {
+ "position_ids": position_ids,
+ "batch_index": batch_index,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
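
The recurring attention-side change is to truncate the causal mask, and (via the CCL cache kwarg) the cache read, to the first comp_ctx_lengths.shape[-1] key positions. A self-contained sketch of the slice:

    import torch

    bsz, q_len, ctx_len, ccl = 1, 1, 4096, 2048
    attention_mask = torch.zeros(bsz, 1, q_len, ctx_len, dtype=torch.bool)  # toy mask
    comp_ctx_lengths = torch.zeros(ccl, dtype=torch.long)                   # only its length is used

    # Same slice as in the forward above: keep only the first CCL key positions
    attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
    assert attention_mask.shape[-1] == ccl  # attention now runs over 2048 instead of 4096 positions
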
@@ -157,6 +164,7 @@ def forward(
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
past_key_values,
+ comp_ctx_lengths,
causal_mask,
batch_index: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -168,6 +176,7 @@ def forward(
hidden_states=hidden_states,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
attention_mask=causal_mask,
batch_index=batch_index,
)
@@ -203,12 +212,18 @@ def __init__(self, config: QEffLlamaSwiftKVConfig):
self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def _run_swiftkv_layers(
- self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask, batch_index
+ self,
+ hidden_states: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values,
+ comp_ctx_lengths,
+ causal_mask,
+ batch_index,
) -> torch.Tensor:
for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers):
layer = self.layers[layer_idx]
hidden_states, past_key_values = layer(
- hidden_states, position_ids, past_key_values, causal_mask, batch_index
+ hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index
)
hidden_states = self.norm(hidden_states)
@@ -293,6 +308,7 @@ def forward(
input_ids: Optional[torch.Tensor],
position_ids: torch.Tensor,
past_key_values: List[torch.Tensor],
+ comp_ctx_lengths: Optional[torch.LongTensor],
batch_index: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.embed_tokens(input_ids)
@@ -332,6 +348,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=False,
use_cache=True,
@@ -378,7 +395,7 @@ def forward(
causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :]
hidden_states, next_decoder_cache = self._run_swiftkv_layers(
- hidden_states, position_ids, past_key_values, causal_mask, batch_index
+ hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index
)
# We can fill the orig_hidden_states with the processed hidden_states here but it's not needed as for next token prediction
# we only need the last valid pos_indices hidden_states.
@@ -410,9 +427,12 @@ def forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
past_key_values: Optional[Union[List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
):
- hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index)
+ hidden_states, output_past_key_values = self.model(
+ input_ids, position_ids, past_key_values, comp_ctx_lengths, batch_index
+ )
logits = self.lm_head(hidden_states)
return CausalLMOutputWithPast(
loss=None,
diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py
index 99384cb55..26e65821c 100644
--- a/QEfficient/transformers/models/llava/modeling_llava.py
+++ b/QEfficient/transformers/models/llava/modeling_llava.py
@@ -5,6 +5,8 @@
#
# -----------------------------------------------------------------------------
+from typing import List
+
import torch
import torch.nn as nn
import torch.utils.checkpoint
@@ -50,7 +52,7 @@ def __init__(self, model):
self.config = self.model.config
self.language_model = self.model.language_model
- def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
mask = input_ids == self.model.config.image_token_index
@@ -64,6 +66,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
inputs_embeds=inputs_embeds,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
@@ -77,7 +80,7 @@ def get_qeff_vision_encoder(self):
def get_qeff_language_decoder(self):
return QEFFLlavaDecoderWrapper(self)
- def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values):
+ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths):
inputs_embeds = self.get_input_embeddings()(input_ids)
# Image features
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -103,12 +106,13 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val
inputs_embeds=inputs_embeds,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
)
next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx)
return outputs.logits, pixel_values, image_idx, outputs.past_key_values
- def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
+ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs):
num_layers = self.config.text_config.num_hidden_layers
num_key_value_heads = self.config.text_config.num_key_value_heads
head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
@@ -138,6 +142,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
)
)
lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1)
+
+ if comp_ctx_lengths is not None:
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+
inputs = {}
if kv_offload:
@@ -154,6 +162,8 @@ def get_specializations(
prefill_seq_len: int,
ctx_len: int,
img_size: int,
+ comp_ctx_lengths_prefill: List[int] = None,
+ comp_ctx_lengths_decode: List[int] = None,
kv_offload: bool = False,
**compiler_options,
):
@@ -175,24 +185,55 @@ def get_specializations(
"img_size": img_size,
}
]
- lang = [
- {
- "batch_size": batch_size,
- "seq_len": prefill_seq_len,
- "ctx_len": ctx_len,
- "max_num_images": max_num_images,
- "img_size": img_size,
- "vision_size": vision_size,
- },
- {
- "batch_size": batch_size,
- "seq_len": "1",
- "ctx_len": ctx_len,
- "max_num_images": max_num_images,
- "img_size": img_size,
- "vision_size": vision_size,
- },
- ]
+
+ if comp_ctx_lengths_prefill is not None:
+ lang = []
+
+ for i in range(0, len(comp_ctx_lengths_prefill)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ }
+ )
+
+ for i in range(0, len(comp_ctx_lengths_decode)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_decode[i],
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ }
+ )
+ else:
+ lang = [
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ },
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ },
+ ]
+
specializations = {}
if kv_offload:
@@ -202,7 +243,7 @@ def get_specializations(
else:
return lang, compiler_options
- def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+ def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
# Define dynamic axes
num_layers = self.config.text_config.num_hidden_layers
@@ -218,6 +259,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
+ if comp_ctx_lengths is not None:
+ lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
+
dynamic_axes = {}
if kv_offload:
dynamic_axes["vision"] = vision_dynamic_axes
diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py
index 23434fc18..4fa63d7ca 100755
--- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py
+++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py
@@ -6,6 +6,8 @@
# -----------------------------------------------------------------------------
+from typing import List
+
import numpy as np
import torch
import torch.nn as nn
@@ -120,7 +122,7 @@ def __init__(self, model):
self.config = self.model.config
self.language_model = self.model.language_model
- def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
mask = input_ids == self.config.image_token_index
@@ -135,6 +137,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va
inputs_embeds=inputs_embeds,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
)
image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
return outputs.logits, vision_embeds, image_idx, outputs.past_key_values
@@ -147,7 +150,7 @@ def get_qeff_vision_encoder(self):
def get_qeff_language_decoder(self):
return QEffLlavaNextDecoderWrapper(self)
- def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
+ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs):
num_layers = self.config.text_config.num_hidden_layers
num_key_value_heads = self.config.text_config.num_key_value_heads
head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
@@ -210,6 +213,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
)
)
lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, constants.GRANITEVISION_CTX_LEN - 1)
+
+ if comp_ctx_lengths is not None:
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+
inputs = {}
if kv_offload:
inputs["vision"] = vision_inputs
@@ -225,6 +232,8 @@ def get_specializations(
prefill_seq_len: int,
ctx_len: int,
img_size: int,
+ comp_ctx_lengths_prefill: List[int] = None,
+ comp_ctx_lengths_decode: List[int] = None,
kv_offload: bool = False,
**compiler_options,
):
@@ -278,30 +287,67 @@ def get_specializations(
"img_size": img_size,
}
]
- lang = [
- {
- "batch_size": batch_size,
- "seq_len": prefill_seq_len,
- "ctx_len": ctx_len,
- "image_size_height": image_size_height,
- "image_size_width": image_size_width,
- "num_patches": num_patches,
- "max_num_images": max_num_images,
- "img_size": img_size,
- "vision_size": vision_size,
- },
- {
- "batch_size": batch_size,
- "seq_len": "1",
- "ctx_len": ctx_len,
- "image_size_height": image_size_height,
- "image_size_width": image_size_width,
- "num_patches": num_patches,
- "max_num_images": max_num_images,
- "img_size": img_size,
- "vision_size": vision_size,
- },
- ]
+ if comp_ctx_lengths_prefill is not None:
+ lang = []
+
+ for i in range(0, len(comp_ctx_lengths_prefill)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
+ "image_size_height": image_size_height,
+ "image_size_width": image_size_width,
+ "num_patches": num_patches,
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ }
+ )
+
+            # Decode specializations: one entry per element of comp_ctx_lengths_decode
+ for i in range(0, len(comp_ctx_lengths_decode)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_decode[i],
+ "image_size_height": image_size_height,
+ "image_size_width": image_size_width,
+ "num_patches": num_patches,
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ }
+ )
+ else:
+ lang = [
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "image_size_height": image_size_height,
+ "image_size_width": image_size_width,
+ "num_patches": num_patches,
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ },
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "image_size_height": image_size_height,
+ "image_size_width": image_size_width,
+ "num_patches": num_patches,
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ "vision_size": vision_size,
+ },
+ ]
+
specializations = {}
if kv_offload:
specializations["vision"] = vision
@@ -310,7 +356,7 @@ def get_specializations(
else:
return lang, compiler_options
- def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+ def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
# Define dynamic axes
num_layers = self.config.text_config.num_hidden_layers
vision_dynamic_axes = {
@@ -325,6 +371,10 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
for i in range(num_layers):
lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
+
+ if comp_ctx_lengths is not None:
+ lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
+
dynamic_axes = {}
if kv_offload:
dynamic_axes["vision"] = vision_dynamic_axes
diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py
index 60b1c929d..7a330a9cf 100644
--- a/QEfficient/transformers/models/mistral/modeling_mistral.py
+++ b/QEfficient/transformers/models/mistral/modeling_mistral.py
@@ -142,6 +142,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
@@ -166,8 +167,16 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
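
The CCL entry added to cache_kwargs bounds how much of the preallocated KV buffer the cache hands back. The cache classes themselves are not part of this diff, so the following is only a rough sketch, under that assumption, of a ctx_len-sized buffer whose read is limited to the first CCL positions:

    import torch

    def toy_ccl_cache_update(k_buf, v_buf, key_states, value_states, position_ids, ccl):
        # k_buf/v_buf: (bs, heads, ctx_len, head_dim) preallocated buffers.
        # Hypothetical sketch, not the QEfficient cache API; assumes all rows share position_ids[0].
        idx = position_ids[0]
        k_buf[:, :, idx] = key_states
        v_buf[:, :, idx] = value_states
        # Return only the first ccl positions; attention then runs over this shorter window
        return k_buf[:, :, :ccl], v_buf[:, :, :ccl]
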
@@ -200,6 +209,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -230,6 +240,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -268,6 +279,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -332,6 +344,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -375,6 +388,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -398,6 +412,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
index ef51c3421..ed30c50b2 100644
--- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
+++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
@@ -138,6 +138,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -161,12 +162,15 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"batch_index": batch_index,
"position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -254,6 +258,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
@@ -295,6 +300,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -339,6 +345,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -404,6 +411,7 @@ def forward(
position_ids=position_ids,
batch_index=batch_index,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
use_cache=use_cache,
@@ -454,6 +462,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -507,6 +516,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py
index 8a98c4c96..f62c2d4d0 100644
--- a/QEfficient/transformers/models/mllama/modeling_mllama.py
+++ b/QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -130,6 +130,7 @@ def forward(
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
@@ -208,6 +209,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
position_embeddings: torch.Tensor = None,
output_attentions: bool = False,
@@ -239,12 +241,15 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"batch_index": batch_index,
"position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -289,6 +294,7 @@ def forward(
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -327,6 +333,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -365,6 +372,7 @@ def forward(
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
@@ -383,13 +391,15 @@ def forward(
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# if we have a new image + new tokens, we only computed key_states on that new image
# we still update the cross key states, past_image, new_image. And use it!
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx,
- {"batch_index": batch_index, "position_ids": position_ids},
+ {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]},
)
elif past_key_value is not None:
key_states, value_states = (
@@ -440,6 +450,7 @@ def forward(
full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -454,6 +465,7 @@ def forward(
attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
cache_position=cache_position,
@@ -628,6 +640,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.FloatTensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
@@ -706,6 +719,7 @@ def forward(
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
@@ -746,6 +760,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.LongTensor] = None,
cross_attention_mask: Optional[torch.LongTensor] = None,
@@ -774,6 +789,7 @@ def forward(
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -850,6 +866,7 @@ def forward(
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -916,6 +933,7 @@ def forward(
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
@@ -928,7 +946,7 @@ def forward(
return outputs.logits, image_idx, outputs.past_key_values, pixel_values
- def get_dummy_inputs(self, kv_offload: bool = False):
+ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
CTX_LEN = constants.ONNX_EXPORT_CTX_LEN
@@ -993,6 +1011,10 @@ def get_dummy_inputs(self, kv_offload: bool = False):
lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache()
lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1)
+
+ if comp_ctx_lengths is not None:
+ lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+
inputs = {}
if kv_offload:
@@ -1009,6 +1031,8 @@ def get_specializations(
prefill_seq_len: int,
ctx_len: int,
img_size: int,
+ comp_ctx_lengths_prefill: List[int] = None,
+ comp_ctx_lengths_decode: List[int] = None,
kv_offload: bool = False,
**compiler_options,
):
@@ -1023,22 +1047,53 @@ def get_specializations(
logger.warning("Setting `img_size=448` as it was neither passed nor found in vision_config")
vision = [{"batch_size": batch_size, "max_num_images": max_num_images, "img_size": img_size}]
- lang = [
- {
- "batch_size": batch_size,
- "seq_len": prefill_seq_len,
- "ctx_len": ctx_len,
- "max_num_images": max_num_images,
- "img_size": img_size,
- },
- {
- "batch_size": batch_size,
- "seq_len": "1",
- "ctx_len": ctx_len,
- "max_num_images": max_num_images,
- "img_size": img_size,
- },
- ]
+
+ if comp_ctx_lengths_prefill is not None:
+ lang = []
+
+ for i in range(0, len(comp_ctx_lengths_prefill)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ }
+ )
+
+            # Decode specializations: one entry per element of comp_ctx_lengths_decode
+ for i in range(0, len(comp_ctx_lengths_decode)):
+ lang.append(
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "comp_ctx_lengths": comp_ctx_lengths_decode[i],
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ }
+ )
+
+ else:
+ lang = [
+ {
+ "batch_size": batch_size,
+ "seq_len": prefill_seq_len,
+ "ctx_len": ctx_len,
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ },
+ {
+ "batch_size": batch_size,
+ "seq_len": "1",
+ "ctx_len": ctx_len,
+ "max_num_images": max_num_images,
+ "img_size": img_size,
+ },
+ ]
+
specializations = {}
if kv_offload:
@@ -1048,7 +1103,7 @@ def get_specializations(
else:
return lang, compiler_options
- def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+ def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False):
txt_cfg = self.config.get_text_config()
num_hidden_layers = txt_cfg.num_hidden_layers
cross_attention_layers = txt_cfg.cross_attention_layers
@@ -1073,6 +1128,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
+ if comp_ctx_lengths is not None:
+ lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
+
dynamic_axes = {}
if kv_offload:
dynamic_axes["vision"] = vision_dynamic_axes
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index cd1c13a00..424549b5e 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -55,6 +55,7 @@
constants,
get_padding_shape_from_config,
)
+from QEfficient.utils.check_ccl_specializations import process_ccl_specializations
from QEfficient.utils.logging_utils import logger
@@ -872,6 +873,15 @@ def __init__(
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
self.model = model
self.config = model.config
+
+ self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
+ self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
+ ctx_len = kwargs.pop("ctx_len", None)
+ if self.comp_ctx_lengths_prefill:
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len
+ )
+
self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
self.input_shapes, self.output_names = None, None
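
process_ccl_specializations is imported above but its body is not in this diff; a plausible reading, sketched purely as an assumption, is that it sanitizes both CCL lists against ctx_len (sort, deduplicate, clamp, and ensure the last decode bucket covers the full context):

    def process_ccl_specializations_sketch(ccl_prefill, ccl_decode, ctx_len):
        # Hypothetical stand-in for QEfficient.utils.check_ccl_specializations.process_ccl_specializations
        if ctx_len is None:
            raise TypeError("ctx_len is required when comp_ctx_lengths are used.")
        ccl_prefill = sorted({min(c, ctx_len) for c in ccl_prefill})
        ccl_decode = sorted({min(c, ctx_len) for c in (ccl_decode or [])})
        if not ccl_decode or ccl_decode[-1] != ctx_len:
            ccl_decode.append(ctx_len)  # decode must eventually see the whole context
        return ccl_prefill, ccl_decode
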
@@ -917,8 +927,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
logger.warning("Updating low_cpu_mem_usage=False")
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+ comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
+ comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
+ ctx_len = kwargs.pop("ctx_len", None)
+
+ if comp_ctx_lengths_prefill:
+ comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
+ comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
+ )
+
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
- return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+ return cls(
+ model,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ **kwargs,
+ )
@property
def onnx_path(self):
@@ -973,8 +999,8 @@ def export(
List[str]
A list containing the paths to the generated ONNX graph files for both components.
"""
- inputs = self.model.get_dummy_inputs(kv_offload=True)
- dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True)
+ inputs = self.model.get_dummy_inputs(self.comp_ctx_lengths_decode, kv_offload=True)
+ dynamic_axes = self.model.get_onnx_dynamic_axes(self.comp_ctx_lengths_decode, kv_offload=True)
output_names = self.model.get_output_names(kv_offload=True)
self.vision_model.export(
@@ -1078,6 +1104,8 @@ def compile(
batch_size=batch_size,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
+ comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
img_size=img_size,
kv_offload=True,
**compiler_options,
@@ -1314,13 +1342,26 @@ def kv_offload_generate(
lang_session.set_buffers(vision_outputs)
+ if self.comp_ctx_lengths_prefill is not None:
+ list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+ prefill_ccl_id = 0
+ lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
# Prepare inputs for prefill
chunk_inputs = lang_inputs.copy()
prefill_start = perf_counter()
# Run prefill
+
chunk_inputs = lang_inputs.copy()
for i in range(num_chunks):
+ if (
+ self.comp_ctx_lengths_prefill is not None
+ and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]
+ ):
+ prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+ chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
chunk_inputs["position_ids"] = lang_inputs["position_ids"][
:, i * prefill_seq_len : (i + 1) * prefill_seq_len
@@ -1350,8 +1391,26 @@ def kv_offload_generate(
streamer.put(lang_inputs["input_ids"][0])
# Decode loop
+
+ if self.comp_ctx_lengths_decode is not None:
+ max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
+ list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+ max_position_id = np.max(lang_inputs["position_ids"])
+ ccl_id_initial = 0
+ ccl_id = ccl_id_initial
+ for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+ if max_position_id < self.comp_ctx_lengths_decode[i]:
+ ccl_id = i
+ break
+ lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id]
+
decode_start = perf_counter()
for num_token in range(1, generation_len):
+ if self.comp_ctx_lengths_decode is not None:
+ if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1:
+ ccl_id = min(ccl_id + 1, max_ccl_id)
+ lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id]
+
outputs = lang_session.run(lang_inputs)
# Prepare inputs for next iteration
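
At run time the CCL bucket only ever grows: prefill moves to the next bucket once the processed prefix exceeds the current one, and decode starts from the smallest bucket that covers the prompt and is bumped as generation crosses each boundary. A compact sketch of that selection logic, using plain integers instead of the np.zeros(length) placeholder inputs used above:

    def initial_decode_ccl(max_position_id, ccl_decode):
        # Smallest decode bucket that still covers the already-filled cache; fall back to the largest
        for i, bound in enumerate(ccl_decode):
            if max_position_id < bound:
                return i
        return len(ccl_decode) - 1

    def next_prefill_ccl(chunk_idx, prefill_seq_len, ccl_prefill, ccl_id):
        # Advance once the prefix processed so far outgrows the current prefill bucket
        if (chunk_idx + 1) * prefill_seq_len > ccl_prefill[ccl_id]:
            ccl_id = min(ccl_id + 1, len(ccl_prefill) - 1)
        return ccl_id

    def next_decode_ccl(max_position_id, ccl_decode, ccl_id):
        # Bump when the cache is about to outgrow the current decode bucket
        if max_position_id >= ccl_decode[ccl_id] - 1:
            ccl_id = min(ccl_id + 1, len(ccl_decode) - 1)
        return ccl_id

    # Example (assumed values): prefill_seq_len=128, 6 prompt chunks, ccl_prefill=[512, 1024]
    ccl_id = 0
    for i in range(6):
        ccl_id = next_prefill_ccl(i, 128, [512, 1024], ccl_id)  # chunks 0-3 -> 512, chunks 4-5 -> 1024
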
@@ -1422,6 +1481,15 @@ def __init__(
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
super().__init__(model, **kwargs)
+ self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
+ self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
+ ctx_len = kwargs.pop("ctx_len", None)
+
+ if self.comp_ctx_lengths_prefill:
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len
+ )
+
# to handle internvl models
if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"):
self.model.config.llm_config.use_cache = True
@@ -1465,6 +1533,16 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
+ comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
+ comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
+ ctx_len = kwargs.pop("ctx_len", None)
+
+ if comp_ctx_lengths_prefill:
+ comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
+ comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
+ )
+
from transformers import AutoConfig
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
@@ -1472,7 +1550,14 @@ def from_pretrained(
config.vision_config.use_flash_attn = "false"
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs)
- return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+ return cls(
+ model,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ **kwargs,
+ )
def export(
self,
@@ -1494,8 +1579,8 @@ def export(
str
Path to the generated ONNX graph file.
"""
- inputs = self.model.get_dummy_inputs()
- dynamic_axes = self.model.get_onnx_dynamic_axes()
+ inputs = self.model.get_dummy_inputs(self.comp_ctx_lengths_decode)
+ dynamic_axes = self.model.get_onnx_dynamic_axes(self.comp_ctx_lengths_decode)
output_names = self.model.get_output_names()
return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
@@ -1577,6 +1662,8 @@ def compile(
batch_size=batch_size,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
+ comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
img_size=img_size,
**compiler_options,
)
@@ -1761,12 +1848,24 @@ def cloud_ai_100_generate(
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
inputs["image_idx"] = np.array([[0]])
+ if self.comp_ctx_lengths_prefill is not None:
+ list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+ prefill_ccl_id = 0
+ inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
qpc_session.activate()
chunk_inputs = inputs.copy()
prefill_start = perf_counter()
# Run prefill
for i in range(num_chunks):
+ if (
+ self.comp_ctx_lengths_prefill is not None
+ and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]
+ ):
+ prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+ chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
outputs = qpc_session.run(chunk_inputs)
@@ -1790,8 +1889,26 @@ def cloud_ai_100_generate(
inputs.pop("pixel_values")
# Decode loop
+
+ if self.comp_ctx_lengths_decode is not None:
+ list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+ max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
+ max_position_id = np.max(inputs["position_ids"])
+ ccl_id_initial = 0
+ ccl_id = ccl_id_initial
+ for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+ if max_position_id < self.comp_ctx_lengths_decode[i]:
+ ccl_id = i
+ break
+ inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id]
+
decode_start = perf_counter()
for num_token in range(1, generation_len):
+ if self.comp_ctx_lengths_decode is not None:
+ if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1:
+ ccl_id = min(ccl_id + 1, max_ccl_id)
+ inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id]
+
outputs = qpc_session.run(inputs)
# Prepare inputs for next iteration
inputs["input_ids"] = outputs["logits"].argmax(2)
@@ -1929,6 +2046,9 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs)
Union[_QEffAutoModelForImageTextToTextDualQPC, _QEFFAutoModelForImageTextToTextSingleQPC]
The wrapped model instance, configured for either dual or single QPC.
"""
+ self.comp_ctx_lengths_prefill = kwargs.get("comp_ctx_lengths_prefill", None)
+ self.comp_ctx_lengths_decode = kwargs.get("comp_ctx_lengths_decode", None)
+
if kv_offload:
return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs)
else:
@@ -1975,8 +2095,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona
NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
+ comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
+ comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
+ ctx_len = kwargs.pop("ctx_len", None)
+
+ if comp_ctx_lengths_prefill:
+ comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
+ comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
+ )
+
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
- return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+ return cls(
+ model,
+ kv_offload=kv_offload,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ **kwargs,
+ )
MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText}
@@ -2072,6 +2210,15 @@ def __init__(
self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
self.is_tlm = transformed
+ self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
+ self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
+ ctx_len = kwargs.pop("ctx_len", None)
+
+ if self.comp_ctx_lengths_prefill:
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
+ self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len
+ )
+
self.hash_params["qeff_auto_class"] = self.__class__.__name__
# ---Sampling---
@@ -2166,6 +2313,14 @@ def from_pretrained(
kv_offload = kwargs.pop("kv_offload", None)
+ comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None)
+ comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None)
+ ctx_len = kwargs.pop("ctx_len", None)
+ if comp_ctx_lengths_prefill:
+ comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations(
+ comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len
+ )
+
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
if qaic_config is not None:
@@ -2175,13 +2330,22 @@ def from_pretrained(
if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
- model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+ model,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ kv_offload=kv_offload,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ **kwargs,
)
return cls(
model,
continuous_batching=continuous_batching,
qaic_config=qaic_config,
pretrained_model_name_or_path=pretrained_model_name_or_path,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
**kwargs,
)
@@ -2231,6 +2395,10 @@ def export(self, export_dir: Optional[str] = None) -> str:
"input_ids": {0: "batch_size", 1: "seq_len"},
"position_ids": {0: "batch_size", 1: "seq_len"},
}
+ if self.comp_ctx_lengths_prefill is not None:
+ example_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+ dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
+
if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d
pkv_dynamic_axes = {
0: "full_batch_size" if self.continuous_batching else "batch_size",
@@ -2376,6 +2544,7 @@ def build_prefill_specialization(
self,
prefill_seq_len: int = 32,
ctx_len: int = 128,
+ comp_ctx_lengths: Optional[int] = None,
batch_size: int = 1,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
@@ -2407,6 +2576,9 @@ def build_prefill_specialization(
"ctx_len": ctx_len,
"num_logits_to_keep": 1 if self.is_tlm else None,
}
+ if comp_ctx_lengths is not None:
+ spec["comp_ctx_lengths"] = comp_ctx_lengths
+
if self.continuous_batching:
spec["full_batch_size"] = kv_cache_batch_size
else:
@@ -2419,6 +2591,7 @@ def build_decode_specialization(
self,
prefill_seq_len: int = 32,
ctx_len: int = 128,
+ comp_ctx_lengths: Optional[int] = None,
batch_size: int = 1,
kv_cache_batch_size: Optional[int] = None,
full_batch_size: Optional[int] = None,
@@ -2448,7 +2621,7 @@ def build_decode_specialization(
A dictionary defining the decode specialization, or None if it would be a duplicate
of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
"""
- if prefill_seq_len == 1 and not self.continuous_batching:
+ if prefill_seq_len == 1 and not self.continuous_batching and comp_ctx_lengths is None:
return None # Avoid duplication with prefill
spec = {
"batch_size": full_batch_size if self.continuous_batching else batch_size,
@@ -2456,6 +2629,8 @@ def build_decode_specialization(
"ctx_len": ctx_len,
"num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
}
+ if comp_ctx_lengths is not None:
+ spec["comp_ctx_lengths"] = comp_ctx_lengths
if self.continuous_batching:
spec["full_batch_size"] = kv_cache_batch_size
@@ -2470,6 +2645,8 @@ def compile(
*,
prefill_seq_len: int = 32,
ctx_len: int = 128,
+ comp_ctx_lengths_prefill: Optional[List[int]] = None,
+ comp_ctx_lengths_decode: Optional[List[int]] = None,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
kv_cache_batch_size: Optional[int] = None,
@@ -2557,6 +2734,23 @@ def compile(
If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models.
"""
+        # Accept comp_ctx_lengths passed directly to compile() (disaggregated applications)
+ if self.comp_ctx_lengths_prefill is None:
+ if comp_ctx_lengths_prefill is not None:
+ import ast
+
+ if isinstance(comp_ctx_lengths_prefill, str):
+ try:
+ # Safely evaluate the string to a Python list for disaggregated input
+ self.comp_ctx_lengths_prefill = ast.literal_eval(comp_ctx_lengths_prefill)
+ self.comp_ctx_lengths_decode = ast.literal_eval(comp_ctx_lengths_decode)
+
+ except (ValueError, SyntaxError):
+ raise ValueError("Invalid format for comp_ctx_lengths. Expected a list-like string.")
+ else:
+ self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+ self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
+
# --- Validation ---
if prefill_only is not None and not isinstance(prefill_only, bool):
raise TypeError("`prefill_only` must be a boolean.")
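
When the CCL lists are handed to compile() from a CLI or a disaggregated serving config they may arrive as strings such as "[512, 1024, 2048]"; ast.literal_eval parses them into Python lists without evaluating arbitrary code. A small usage sketch with assumed values:

    import ast

    raw_prefill = "[512, 1024]"            # e.g. forwarded from a command-line flag
    raw_decode = "[1024, 2048, 4096]"

    comp_ctx_lengths_prefill = ast.literal_eval(raw_prefill)
    comp_ctx_lengths_decode = ast.literal_eval(raw_decode)
    assert comp_ctx_lengths_prefill == [512, 1024]
    assert isinstance(comp_ctx_lengths_decode, list)   # literal_eval never executes code, unlike eval()
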
@@ -2587,26 +2781,58 @@ def compile(
# --- Specializations ---
specializations = []
if prefill_only is None or prefill_only or prefill_seq_len == 1:
- specializations.append(
- self.build_prefill_specialization(
+ if self.comp_ctx_lengths_prefill is not None:
+ # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization
+ for i in range(0, len(self.comp_ctx_lengths_prefill)):
+ specializations.append(
+ self.build_prefill_specialization(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ comp_ctx_lengths=self.comp_ctx_lengths_prefill[i],
+ batch_size=batch_size,
+ kv_cache_batch_size=kv_cache_batch_size,
+ full_batch_size=full_batch_size,
+ )
+ )
+
+ else:
+ specializations.append(
+ self.build_prefill_specialization(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ batch_size=batch_size,
+ kv_cache_batch_size=kv_cache_batch_size,
+ full_batch_size=full_batch_size,
+ )
+ )
+
+ if prefill_only is None or not prefill_only:
+ if self.comp_ctx_lengths_decode is not None:
+ # Adding elements from self.comp_ctx_lengths_decode to decode_specialization
+ for i in range(0, len(self.comp_ctx_lengths_decode)):
+ decode_spec = self.build_decode_specialization(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ comp_ctx_lengths=self.comp_ctx_lengths_decode[i],
+ batch_size=batch_size,
+ kv_cache_batch_size=kv_cache_batch_size,
+ full_batch_size=full_batch_size,
+ num_speculative_tokens=num_speculative_tokens,
+ )
+ if decode_spec:
+ specializations.append(decode_spec)
+
+ else:
+ decode_spec = self.build_decode_specialization(
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
batch_size=batch_size,
kv_cache_batch_size=kv_cache_batch_size,
full_batch_size=full_batch_size,
+ num_speculative_tokens=num_speculative_tokens,
)
- )
- if prefill_only is None or not prefill_only:
- decode_spec = self.build_decode_specialization(
- prefill_seq_len=prefill_seq_len,
- ctx_len=ctx_len,
- batch_size=batch_size,
- kv_cache_batch_size=kv_cache_batch_size,
- full_batch_size=full_batch_size,
- num_speculative_tokens=num_speculative_tokens,
- )
- if decode_spec:
- specializations.append(decode_spec)
+ if decode_spec:
+ specializations.append(decode_spec)
# --- Compilation ---
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
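
With CCL enabled the former single prefill/decode pair becomes one specialization per bucket. For example, assuming prefill_seq_len=128, ctx_len=4096, comp_ctx_lengths_prefill=[2048, 4096] and comp_ctx_lengths_decode=[2048, 3072, 4096], the loop above requests five specializations (batch/continuous-batching and num_logits_to_keep keys elided in this sketch):

    specializations = [
        {"seq_len": 128, "ctx_len": 4096, "comp_ctx_lengths": 2048},
        {"seq_len": 128, "ctx_len": 4096, "comp_ctx_lengths": 4096},
        {"seq_len": 1,   "ctx_len": 4096, "comp_ctx_lengths": 2048},
        {"seq_len": 1,   "ctx_len": 4096, "comp_ctx_lengths": 3072},
        {"seq_len": 1,   "ctx_len": 4096, "comp_ctx_lengths": 4096},
    ]
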
@@ -2684,6 +2910,8 @@ def generate(
tokenizer,
self.qpc_path,
prompt=prompts,
+ comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=self.comp_ctx_lengths_decode,
device_id=device_id,
generation_len=generation_len,
is_tlm=self.is_tlm,
diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py
index 18557f1ca..97dc015c0 100644
--- a/QEfficient/transformers/models/phi/modeling_phi.py
+++ b/QEfficient/transformers/models/phi/modeling_phi.py
@@ -67,6 +67,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
@@ -104,8 +105,16 @@ def forward(
key_states = torch.cat((key_rot, key_pass), dim=-1)
if past_key_value is not None:
- # Update the cache_kwargs with position_ids for Cloud AI 100
- cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+            # sin and cos are specific to RoPE models; position_ids and the CCL bound are needed for the Cloud AI 100 static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
@@ -140,6 +149,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
@@ -181,6 +191,7 @@ def forward(
position_ids=position_ids,
batch_index=batch_index,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
@@ -213,6 +224,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -274,6 +286,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -316,6 +329,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -370,6 +384,7 @@ def forward(
position_ids=position_ids,
batch_index=batch_index,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py
index 602a73c84..2e2a728cf 100644
--- a/QEfficient/transformers/models/phi3/modeling_phi3.py
+++ b/QEfficient/transformers/models/phi3/modeling_phi3.py
@@ -140,6 +140,7 @@ def forward(
batch_index: Optional[torch.LongTensor] = None,
position_ids=Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -162,6 +163,8 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
@@ -169,6 +172,7 @@ def forward(
"cache_position": cache_position,
"batch_index": batch_index,
"position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -202,6 +206,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -243,6 +248,7 @@ def forward(
position_ids=position_ids,
batch_index=batch_index,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
@@ -277,6 +283,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -332,6 +339,7 @@ def forward(
position_ids=position_ids,
batch_index=batch_index,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
@@ -376,6 +384,7 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -423,6 +432,7 @@ def forward(
batch_index=batch_index,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py
index 00a3989d8..057cfe173 100644
--- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py
+++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py
@@ -152,6 +152,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
@@ -170,8 +171,16 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
@@ -205,6 +214,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
@@ -240,6 +250,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -279,6 +290,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
@@ -337,6 +349,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -380,6 +393,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -403,6 +417,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index bf3defc1a..e77bf3707 100644
--- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -202,6 +202,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
@@ -220,8 +221,16 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
+ if comp_ctx_lengths is not None:
+ attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
# sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "batch_index": batch_index,
+ "position_ids": position_ids,
+ "CCL": attention_mask.shape[-1],
+ }
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
@@ -248,6 +257,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
batch_index: Optional[torch.LongTensor] = None,
@@ -291,6 +301,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -328,6 +339,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -378,6 +390,7 @@ def forward(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
@@ -417,6 +430,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
+ comp_ctx_lengths: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -475,6 +489,7 @@ def forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
+ comp_ctx_lengths=comp_ctx_lengths,
inputs_embeds=inputs_embeds,
batch_index=batch_index,
use_cache=use_cache,
diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py
new file mode 100644
index 000000000..dbfb08926
--- /dev/null
+++ b/QEfficient/utils/check_ccl_specializations.py
@@ -0,0 +1,43 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from typing import List, Optional
+
+
+def process_ccl_specializations(
+ ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None
+):
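+ """
+ Normalize the prefill and decode compute-context-length (CCL) lists against ctx_len:
+ cap every value at ctx_len, drop duplicates, shift prefill values down until they no
+ longer collide with the decode list (or each other), and return both lists sorted.
+ """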
+ if ctx_len is None:
+ raise TypeError("`ctx_len` is required when loading the model.")
+ if ccl_prefill is None:
+ ccl_prefill = [ctx_len]
+ if ccl_decode is None:
+ ccl_decode = [ctx_len]
+
+ # Step 1: Cap values to ctx_len
+ ccl_prefill = [min(x, ctx_len) for x in ccl_prefill]
+ ccl_decode = [min(x, ctx_len) for x in ccl_decode]
+
+ # Step 2: Remove duplicates within each list
+ ccl_prefill = list(set(ccl_prefill))
+ ccl_decode = list(set(ccl_decode))
+
+ # Step 3: Ensure no overlap between ccl_prefill and ccl_decode
+ updated_prefill = []
+ for val in ccl_prefill:
+ while val in ccl_decode or val in updated_prefill:
+ val -= 1
+ if val < 0:
+ break # Prevent negative values
+ if val >= 0:
+ updated_prefill.append(val)
+
+ # Step 4: Sort both lists
+ updated_prefill.sort()
+ ccl_decode.sort()
+
+ return updated_prefill, ccl_decode
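+
+
+# Illustrative example (not part of the library API): with ctx_len=2048,
+# process_ccl_specializations([512], [512, 1024, 2048], 2048) caps and de-duplicates both
+# lists, then shifts the overlapping prefill value down by one, returning
+# ([511], [512, 1024, 2048]).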
diff --git a/examples/ccl_image_text_to_text_inference.py b/examples/ccl_image_text_to_text_inference.py
new file mode 100644
index 000000000..cbfe20e7a
--- /dev/null
+++ b/examples/ccl_image_text_to_text_inference.py
@@ -0,0 +1,135 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import requests
+from PIL import Image
+from transformers import AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+# Add HuggingFace Token to access the model
+HF_TOKEN = ""
+
+
+def run_model(
+ model_name,
+ token,
+ query,
+ image_url,
+ kv_offload=False,
+ prefill_seq_len=32,
+ ctx_len=512,
+ comp_ctx_lengths_prefill=None,
+ comp_ctx_lengths_decode=None,
+ generation_len=128,
+ img_size=560,
+ num_cores=16,
+ num_devices=1,
+):
+ ## STEP - 1 Load the Processor and Model
+
+ processor = AutoProcessor.from_pretrained(model_name, token=token)
+
+ # `kv_offload` is used to compile the model into a single QPC or 2 QPCs.
+ # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs.
+ # The outputs of the Vision Encoder are then passed to the Language model via host in this case.
+
+ model = QEFFAutoModelForImageTextToText.from_pretrained(
+ model_name,
+ token=token,
+ attn_implementation="eager",
+ kv_offload=kv_offload,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ )
+
+ ## STEP - 2 Export & Compile the Model
+
+ model.compile(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ img_size=img_size,
+ num_cores=num_cores,
+ num_devices=num_devices,
+ mxfp6_matmul=False,
+ )
+
+ ## STEP - 3 Load and process the inputs for Inference
+
+ image = Image.open(requests.get(image_url, stream=True).raw)
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": query},
+ ],
+ }
+ ]
+ input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)]
+
+ inputs = processor(
+ text=input_text,
+ images=image,
+ return_tensors="pt",
+ add_special_tokens=False,
+ padding="max_length",
+ max_length=prefill_seq_len,
+ )
+
+ ## STEP - 4 Run Inference on the compiled model
+
+ streamer = TextStreamer(processor.tokenizer)
+ output_statistics = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len)
+ print(output_statistics)
+
+
+if __name__ == "__main__":
+ # Model name and Input parameters
+ # model_name = "llava-hf/llava-1.5-7b-hf"
+ model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+ query = "Describe this image."
+ image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
+
+ # Compilation parameters for the model
+ kv_offload = True
+ prefill_seq_len = 32
+ ctx_len = 8192
+ generation_len = 128
+ # img_size = 336
+ img_size = 560
+ num_cores = 16
+ num_devices = 4
+ comp_ctx_lengths_prefill = [4096]
+ comp_ctx_lengths_decode = [6144, ctx_len]
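+ # With ctx_len=8192 here, prefill runs with a 4096-token compute window, while decode starts
+ # with a 6144 window and moves to the full 8192 context once the cache index passes 6144.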
+
+ run_model(
+ model_name=model_name,
+ token=HF_TOKEN,
+ query=query,
+ kv_offload=kv_offload,
+ image_url=image_url,
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ generation_len=generation_len,
+ img_size=img_size,
+ num_cores=num_cores,
+ num_devices=num_devices,
+ )
+
+
+"""
+Expected Response:
+
+This image depicts a charming anthropomorphic rabbit standing on a dirt path in front of a picturesque stone cottage, surrounded by a serene landscape.
+
+The rabbit, with its light brown fur and distinctive long ears, is attired in a stylish blue coat, brown vest, and tan pants, exuding a sense of sophistication. The dirt path, flanked by vibrant flowers and lush greenery, leads to the cottage, which features a thatched roof and a chimney, adding to the rustic charm of the scene. In the background, rolling hills and trees create a breathtaking panorama, while the sky above is a brilliant blue with white clouds, completing the
+
+"""
diff --git a/examples/ccl_llama4_example.py b/examples/ccl_llama4_example.py
new file mode 100644
index 000000000..06097558e
--- /dev/null
+++ b/examples/ccl_llama4_example.py
@@ -0,0 +1,126 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import transformers
+from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+config = AutoConfig.from_pretrained(model_id)
+# For Testing Purpose Only
+config.text_config.num_hidden_layers = 4
+config.vision_config.num_hidden_layers = 2
+
+model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config)
+model.eval()
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(model_id)
+
+### For running the model in the single QPC approach use kv_offload=False. For the dual QPC approach use kv_offload=True ###
+ctx_len = 8192
+comp_ctx_lengths_prefill = [3072]
+comp_ctx_lengths_decode = [4096, 6144, ctx_len]
+
+qeff_model = QEFFAutoModelForImageTextToText(
+ model,
+ kv_offload=False,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+)
+
+### Use skip_vision=True to run text-only inference; otherwise set it to False. ###
+skip_vision = False
+
+if skip_vision:
+ ## Only Text ##
+ qeff_model.compile(
+ prefill_seq_len=128,
+ ctx_len=ctx_len,
+ img_size=336,
+ num_cores=16,
+ num_devices=4,
+ max_num_tiles=17,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ aic_enable_depth_first=True,
+ skip_vision=True,
+ mos=1,
+ )
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "text",
+ "text": "Can you describe the image in detail.",
+ },
+ ],
+ },
+ ]
+
+ inputs = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+
+ streamer = TextStreamer(tokenizer)
+ output = qeff_model.generate(inputs=inputs, generation_len=700)
+ print(output.generated_ids)
+ print(tokenizer.batch_decode(output.generated_ids))
+ print(output)
+
+else:
+ ## Vision + Text ##
+ qeff_model.compile(
+ prefill_seq_len=128,
+ ctx_len=ctx_len,
+ img_size=336,
+ num_cores=16,
+ num_devices=4,
+ max_num_tiles=17,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ aic_enable_depth_first=True,
+ mos=1,
+ )
+
+ ### IMAGE + TEXT ###
+ image_url = (
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
+ )
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": image_url},
+ {"type": "text", "text": "Can you describe the image in detail."},
+ ],
+ },
+ ]
+
+ inputs = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+ inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
+ streamer = TextStreamer(tokenizer)
+ output = qeff_model.generate(inputs=inputs, generation_len=1024)
+ print(output.generated_ids)
+ print(tokenizer.batch_decode(output.generated_ids))
+ print(output)
+ print()
diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py
new file mode 100644
index 000000000..82bec9ced
--- /dev/null
+++ b/examples/compute_context_length.py
@@ -0,0 +1,50 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+## In this example, you can run a model for static and continuous batching with different Compute-Context-Length (CCL) inputs. ##
+
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+## The optional comp_ctx_lengths_prefill and comp_ctx_lengths_decode arguments accept lists of compute-context-lengths (CCLs). If they are None, the model runs with the default context length. ##
+## - The values in comp_ctx_lengths_prefill are the compute-context-lengths used during prefill. ##
+## - During decode, the runtime picks a CCL from comp_ctx_lengths_decode based on the position_id/cache index: it starts from the smallest CCL that covers the input prompt length and moves to the next larger CCL whenever the cache index passes the current one (see the sketch below). ##
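+##
+## A rough sketch (hypothetical helper, not the QEfficient runtime API) of how a decode-time
+## CCL could be chosen for a given cache index:
+##
+##     def pick_ccl(cache_index: int, ccl_decode: list) -> int:
+##         # smallest CCL in the sorted decode list that still covers the cache index
+##         return next((ccl for ccl in sorted(ccl_decode) if cache_index < ccl), max(ccl_decode))
+##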
+
+
+ctx_len = 2048
+comp_ctx_lengths_prefill = [256]
+comp_ctx_lengths_decode = [512, 1024, ctx_len]
+
+model_name = "Qwen/Qwen2.5-7B"
+model = QEFFAutoModelForCausalLM.from_pretrained(
+ model_name,
+ continuous_batching=True,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+)
+
+# Compile the model for either continuous or static batching. For continuous batching, full_batch_size is required.
+model.compile(
+ prefill_seq_len=128,
+ ctx_len=ctx_len,
+ num_cores=16,
+ num_devices=4,
+ full_batch_size=1,
+ mxint8_kv_cache=True,
+ mxfp6_matmul=True,
+)
+
+# Create the tokenizer and run model.generate, passing the input prompts to it.
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model.generate(
+ prompts=[
+ "My name is ",
+ ],
+ tokenizer=tokenizer,
+)
diff --git a/examples/gemma3_example/ccl_gemma3_mm.py b/examples/gemma3_example/ccl_gemma3_mm.py
new file mode 100644
index 000000000..c7e2b8e83
--- /dev/null
+++ b/examples/gemma3_example/ccl_gemma3_mm.py
@@ -0,0 +1,119 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import transformers
+from transformers import AutoConfig, AutoProcessor
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+# Change model_id to "google/gemma-3-27b-it" for 27B model
+model_id = "google/gemma-3-4b-it"
+config = AutoConfig.from_pretrained(model_id)
+# For Testing Purpose Only
+# config.text_config.num_hidden_layers = 1
+# config.vision_config.num_hidden_layers = 2
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(model_id)
+
+# pass HF_TOKEN if gated model
+# For running the model in the single QPC approach use kv_offload=False; for the dual QPC approach use kv_offload=True.
+ctx_len = 8192
+comp_ctx_lengths_prefill = [3072]
+comp_ctx_lengths_decode = [4096, 6144, ctx_len]
+
+qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+ model_id,
+ config=config,
+ attn_implementation="eager",
+ kv_offload=True,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+)
+
+### Use skip_vision=True to run text-only inference; otherwise set it to False. ###
+skip_vision = True
+
+if skip_vision:
+ ## Only Text ##
+ qeff_model.compile(
+ prefill_seq_len=128,
+ ctx_len=ctx_len,
+ img_size=896,
+ num_cores=16,
+ num_devices=4,
+ mxfp6_matmul=False,
+ mxint8_kv_cache=False,
+ aic_enable_depth_first=True,
+ skip_vision=True,
+ mos=1,
+ node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_27b.yaml",
+ )
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Describe the transformers architecture in LLMs."},
+ ],
+ },
+ ]
+
+ inputs = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+
+ output = qeff_model.generate(inputs=inputs, generation_len=100)
+ print(tokenizer.batch_decode(output.generated_ids))
+ print(output)
+
+else:
+ ## Vision + Text ##
+ qeff_model.compile(
+ prefill_seq_len=128,
+ ctx_len=ctx_len,
+ img_size=896,
+ num_cores=16,
+ num_devices=4,
+ mxfp6_matmul=False,
+ mxint8_kv_cache=False,
+ aic_enable_depth_first=True,
+ mos=1,
+ node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_27b.yaml",
+ )
+
+ ### IMAGE + TEXT ###
+ image_url = (
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
+ )
+
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": image_url},
+ {"type": "text", "text": "Can you describe the image in detail."},
+ ],
+ },
+ ]
+
+ inputs = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ )
+ inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
+ output = qeff_model.generate(inputs=inputs, generation_len=100)
+ print(tokenizer.batch_decode(output.generated_ids))
+ print(output)
diff --git a/examples/granite_example/ccl_granite_vision_inference.py b/examples/granite_example/ccl_granite_vision_inference.py
new file mode 100644
index 000000000..e03b94a5e
--- /dev/null
+++ b/examples/granite_example/ccl_granite_vision_inference.py
@@ -0,0 +1,127 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import requests
+from PIL import Image
+from transformers import AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+# Add HuggingFace Token to access the model
+HF_TOKEN = ""
+
+
+def run_model(
+ model_name,
+ token,
+ query,
+ image_url,
+ kv_offload=False,
+ prefill_seq_len=5500,
+ ctx_len=6000,
+ comp_ctx_lengths_prefill=None,
+ comp_ctx_lengths_decode=None,
+ generation_len=128,
+ img_size=384,
+ num_cores=16,
+ num_devices=1,
+):
+ ## STEP - 1 Load the Processor and Model
+
+ processor = AutoProcessor.from_pretrained(model_name, token=token)
+
+ # `kv_offload` is used to compile the model into 2 QPCs. A single QPC is currently not supported, so setting the flag to False is not allowed.
+ # The `kv_offload` flag should always be set to True.
+ # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs.
+ # The outputs of the Vision Encoder are then passed to the Language model via host in this case.
+
+ model = QEFFAutoModelForImageTextToText.from_pretrained(
+ model_name,
+ token=token,
+ kv_offload=kv_offload,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ )
+
+ ## STEP - 2 Export & Compile the Model
+
+ model.compile(
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ img_size=img_size,
+ num_cores=num_cores,
+ num_devices=num_devices,
+ mxfp6_matmul=False,
+ )
+
+ ## STEP - 3 Load and process the inputs for Inference
+
+ # We resize the image to a fixed (width x height) of (1610 x 1109) so that any input image
+ # works with the model irrespective of its original dimensions.
+
+ image = Image.open(requests.get(image_url, stream=True).raw)
+ image = image.resize((1610, 1109))
+
+ messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}]
+ input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt")
+
+ ## STEP - 4 Run Inference on the compiled model
+
+ streamer = TextStreamer(processor.tokenizer)
+ output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len)
+ print(output)
+
+
+if __name__ == "__main__":
+ # Model name and Input parameters
+ model_name = "ibm-granite/granite-vision-3.2-2b"
+
+ # Please add prompt here
+ query = "Describe the image"
+
+ # Pass an image URL or an image path. The image should be in JPG format.
+ image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+
+ # Compilation parameters for the model
+ kv_offload = True
+ prefill_seq_len = 5500
+ ctx_len = 8192
+ generation_len = 128
+ img_size = 384
+ num_cores = 16
+ num_devices = 4
+ comp_ctx_lengths_prefill = [5500]
+ comp_ctx_lengths_decode = [6144, ctx_len]
+
+ run_model(
+ model_name=model_name,
+ token=HF_TOKEN,
+ query=query,
+ kv_offload=kv_offload,
+ image_url=image_url,
+ prefill_seq_len=prefill_seq_len,
+ ctx_len=ctx_len,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ generation_len=generation_len,
+ img_size=img_size,
+ num_cores=num_cores,
+ num_devices=num_devices,
+ )
+
+
+"""
+Expected Response:
+
+The image depicts two cats lying on a pink blanket that is spread out on a red couch. The cats are positioned in a relaxed manner, with their bodies stretched out and their heads resting on the blanket.
+The cat on the left is a smaller, tabby cat with a mix of black, gray, and white fur. It has a long, slender body and a distinctive tail that is curled up near its tail end. The cat on the right is a larger,
+tabby cat with a mix of gray, black, and brown fur. It has
+
+"""
diff --git a/examples/granite_example/ccl_granitemoe_inference.py b/examples/granite_example/ccl_granitemoe_inference.py
new file mode 100644
index 000000000..57668ca24
--- /dev/null
+++ b/examples/granite_example/ccl_granitemoe_inference.py
@@ -0,0 +1,40 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.utils.constants import Constants
+
+model_name = "ibm-research/PowerMoE-3b"
+"""
+# For CB inference, set continuous_batching to True and pass the full_batch_size, mxfp6_matmul, and mxint8_kv_cache arguments to the compile function.
+# We use prefill_seq_len=1 for compilation for both CB and non-CB inference.
+"""
+
+ctx_len = 2048
+comp_ctx_lengths_prefill = [256]
+comp_ctx_lengths_decode = [512, 1024, ctx_len]
+
+model = QEFFAutoModelForCausalLM.from_pretrained(
+ model_name,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ continuous_batching=False,
+)
+model.compile(
+ prefill_seq_len=1,
+ ctx_len=ctx_len,
+ batch_size=1,
+ num_cores=16,
+ num_devices=4,
+ mxfp6_matmul=False,
+ mxint8_kv_cache=False,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer)
diff --git a/examples/intern_example/ccl_internvl_inference.py b/examples/intern_example/ccl_internvl_inference.py
new file mode 100644
index 000000000..5595d26cd
--- /dev/null
+++ b/examples/intern_example/ccl_internvl_inference.py
@@ -0,0 +1,286 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from io import BytesIO
+from typing import List
+
+import requests
+import torch
+import torch.nn as nn
+import torchvision.transforms as T
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+from transformers import AutoTokenizer, TextStreamer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.utils.logging_utils import logger
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+# Process the input messages to generate prompt for the model.
+def get_prompt(messages) -> str:
+ """Get the prompt for generation."""
+ ## Chat template used for InternVL
+ system_prompt = "<|im_start|>system\n你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+ sep = "<|im_end|>\n"
+
+ ret = system_prompt + sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + sep
+ else:
+ ret += role
+ return ret
+
+
+# Processor class for InternVL models
+class InternProcessor:
+ """
+ The InternVL model only has an AutoTokenizer, so this class performs processing tasks similar to an AutoProcessor.
+ The methods used here are borrowed from the original InternVL modelling files.
+ "https://huggingface.co/OpenGVLab/InternVL2_5-1B/"
+ """
+
+ def __init__(self, model: nn.Module, tokenizer):
+ self.model = model
+ image_size = self.model.config.force_image_size or self.model.config.vision_config.image_size
+ patch_size = self.model.config.vision_config.patch_size
+ self.template = model.config.template
+ self.num_image_token = int((image_size // patch_size) ** 2 * (self.model.config.downsample_ratio**2))
+ self.tokenizer = tokenizer
+
+ def build_transform(self, input_size):
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+ transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD),
+ ]
+ )
+ return transform
+
+ def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float("inf")
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ return best_ratio
+
+ def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j)
+ for n in range(min_num, max_num + 1)
+ for i in range(1, n + 1)
+ for j in range(1, n + 1)
+ if i * j <= max_num and i * j >= min_num
+ )
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = self.find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
+ )
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size,
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ return processed_images
+
+ def load_image(self, image, input_size=448, max_num=12):
+ transform = self.build_transform(input_size=input_size)
+ images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+ pixel_values = [transform(image) for image in images]
+ pixel_values = torch.stack(pixel_values)
+ return pixel_values
+
+ def __call__(
+ self,
+ pixel_values,
+ question,
+ messages,
+ roles,
+ history=None,
+ num_patches_list=None,
+ IMG_START_TOKEN="<img>",
+ IMG_END_TOKEN="</img>",
+ IMG_CONTEXT_TOKEN="<IMG_CONTEXT>",
+ verbose=False,
+ ) -> str:
+ if history is None and pixel_values is not None and "<image>" not in question:
+ question = "<image>\n" + question
+ if num_patches_list is None:
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+ img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+ self.model.img_context_token_id = img_context_token_id
+
+ messages.append([roles[0], question])
+ messages.append([roles[1], None])
+ query = get_prompt(messages)
+ if verbose and pixel_values is not None:
+ image_bs = pixel_values.shape[0]
+ logger.info(f"dynamic ViT batch size: {image_bs}")
+
+ for num_patches in num_patches_list:
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+ query = query.replace("<image>", image_tokens, 1)
+ return query
+
+
+def run_intern_on_aic(
+ model_name,
+ prompt,
+ image_url,
+ messages,
+ roles,
+ kv_offload=False,
+ prefill_seq_len=3840,
+ num_devices=1,
+ num_cores=16,
+):
+ ## STEP 1 -- LOAD THE MODEL
+
+ # The original Intern-VL model, despite being multimodal, is loaded using `AutoModelForCausalLM` in Huggingface.
+ # To maintain compatibility, we load this model using `QEFFAutoModelForCausalLM`.
+
+ ctx_len = 8192
+ comp_ctx_lengths_prefill = [4096]
+ comp_ctx_lengths_decode = [6144, ctx_len]
+
+ # model = QEFFAutoModelForCausalLM.from_pretrained(model_name, kv_offload=kv_offload, trust_remote_code=True)
+
+ model = QEFFAutoModelForCausalLM.from_pretrained(
+ model_name,
+ kv_offload=kv_offload,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ trust_remote_code=True,
+ )
+
+ ## STEP 2 -- EXPORT & COMPILE THE MODEL
+
+ model.compile(
+ num_cores=num_cores,
+ num_devices=num_devices,
+ ctx_len=ctx_len,
+ prefill_seq_len=prefill_seq_len,
+ mxfp6_matmul=False,
+ )
+
+ ## STEP 3 -- SETUP THE PROCESSOR
+
+ # InternVL doesn't have an AutoProcessor yet, so we will use our own processor class "InternProcessor"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
+ internProcessor = InternProcessor(model.model, tokenizer)
+
+ ## STEP 4 -- PREPROCESS THE INPUTS
+
+ img = requests.get(image_url, stream=True)
+ image = Image.open(BytesIO(img.content)).convert("RGB")
+
+ # Images are resized to (1000, 747) for inference
+ image = image.resize((1000, 747))
+
+ # preprocess the resized image
+ pixel_values = internProcessor.load_image(image, max_num=12)
+ question = "<image>\n" + prompt
+ query = internProcessor(pixel_values, question, messages, roles)
+ inputs = tokenizer(
+ query, return_tensors="pt", padding="max_length", max_length=prefill_seq_len, padding_side="right"
+ )
+
+ inputs["pixel_values"] = pixel_values
+
+ ## STEP 5 -- RUN INFERENCE VIA GENERATE FUNCTION
+ streamer = TextStreamer(tokenizer)
+ model.generate(inputs=inputs, streamer=streamer, generation_len=128)
+
+
+if __name__ == "__main__":
+ model_name = "OpenGVLab/InternVL2_5-1B"
+
+ # Chat Template information for prompt preprocessing
+ messages: List[List[str]] = []
+ roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
+
+ # Inputs for the model
+ prompt = "Please describe the image in detail."
+ image_url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg"
+
+ ## Compilation parameters
+
+ # `kv_offload` is used to compile the model into a single QPC or 2 QPCs.
+ # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs.
+ # The outputs of the Vision Encoder are then passed to the Language model via host in this case.
+
+ kv_offload = False
+
+ # InternVL is an early-fusion model that uses placeholder tokens within the input_ids to interleave text embeddings with
+ # image embeddings and generate the final input_embeds for output generation. Hence we need a very large prefill_seq_len
+ # (3840 in this case) to accommodate the merged embeddings.
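+ # Roughly, assuming the InternVL2_5-1B config (image_size=448, patch_size=14, downsample_ratio=0.5),
+ # each image tile contributes int((448/14)**2 * 0.5**2) = 256 tokens; up to 12 tiles plus a thumbnail
+ # gives 13 * 256 = 3328 image tokens, leaving roughly 512 positions for text within prefill_seq_len=3840.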
+
+ prefill_seq_len = 3840
+ num_devices = 4
+ num_cores = 16
+
+ run_intern_on_aic(
+ model_name=model_name,
+ prompt=prompt,
+ image_url=image_url,
+ messages=messages,
+ roles=roles,
+ kv_offload=kv_offload,
+ prefill_seq_len=prefill_seq_len,
+ num_devices=num_devices,
+ num_cores=num_cores,
+ )
+
+
+"""
+Expected Response:
+
+The image is a promotional graphic for Microsoft Azure. It features a blue background with a hexagonal pattern on the left side. The hexagons are white and are arranged in a way that suggests a network or connectivity theme.
+
+On the right side of the image, the Microsoft Azure logo is prominently displayed. The logo consists of the Azure name in white, with the Microsoft logo above it, which includes four colored squares (blue, green, yellow, and red). Below the logo, the word "Azure" is written in large white letters.
+
+Below the logo, there is text that reads:
+- "By Dinesh Kumar Wick
+"""
diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py
new file mode 100644
index 000000000..ba0e21c35
--- /dev/null
+++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py
@@ -0,0 +1,42 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+from QEfficient.utils.constants import Constants
+
+model_name = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+"""
+# For CB inference, set continuous_batching to True and pass the full_batch_size, mxfp6_matmul, and mxint8_kv_cache arguments to the compile function.
+# We use prefill_seq_len=1 for compilation for both CB and non-CB inference.
+"""
+
+ctx_len = 2048
+
+comp_ctx_lengths_prefill = [256]
+comp_ctx_lengths_decode = [512, 1024, ctx_len]
+
+model = QEFFAutoModelForCausalLM.from_pretrained(
+ model_name,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ continuous_batching=True,
+)
+model.compile(
+ prefill_seq_len=1,
+ ctx_len=ctx_len,
+ full_batch_size=1,
+ num_cores=16,
+ num_devices=4,
+ mxfp6_matmul=True,
+ mxint8_kv_cache=True,
+ mos=1,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer)
diff --git a/tests/transformers/test_comp_ctx_length.py b/tests/transformers/test_comp_ctx_length.py
new file mode 100644
index 000000000..e145ad698
--- /dev/null
+++ b/tests/transformers/test_comp_ctx_length.py
@@ -0,0 +1,193 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import copy
+import os
+from time import perf_counter
+
+import onnx
+import pytest
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+
+configs = [
+ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params
+ ("gpt2", 256, 2, 4, 128, 512, 127, {}),
+ ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
+ ("falcon", 256, 2, 4, 128, 512, 127, {}),
+ ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
+ ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+ ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+ ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+ ("mpt", 256, 2, 4, 128, 512, 127, {}),
+ ("phi", 256, 2, 4, 128, 512, 127, {}),
+ ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}),
+ ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+ ("starcoder2", 256, 2, 4, 128, 512, 127, {}),
+ ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+]
+
+configs = [
+ AutoConfig.for_model(
+ model_name,
+ max_position_embeddings=max_position_embeddings,
+ num_hidden_layers=num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ hidden_size=hidden_size,
+ intermediate_size=intermediate_size,
+ vocab_size=vocab_size,
+ **additional_params,
+ )
+ for (
+ model_name,
+ max_position_embeddings,
+ num_hidden_layers,
+ num_attention_heads,
+ hidden_size,
+ intermediate_size,
+ vocab_size,
+ additional_params,
+ ) in configs
+]
+config_ids = [x.model_type for x in configs]
+
+model_kwargs = {"attn_implementation": "eager"}
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+def test_causal_lm_unsupported(cb):
+ model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt"))
+ with pytest.warns():
+ QEFFAutoModelForCausalLM(model, cb)
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_init(config, cb):
+ model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+ qeff_model = QEFFAutoModelForCausalLM(model, cb)
+ with pytest.raises(TypeError):
+ QEFFAutoModelForCausalLM(AutoModel.from_config(config, **model_kwargs), cb)
+ assert qeff_model.model.__class__.__name__.startswith("QEff")
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_pretrained(config, cb, tmp_path):
+ model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+ model.save_pretrained(tmp_path)
+
+ qeff_model = QEFFAutoModelForCausalLM.from_pretrained(tmp_path, cb)
+ assert qeff_model.model.__class__.__name__.startswith("QEff")
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_hash(config, cb):
+ hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash
+ hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash
+
+ assert hash_0_0 == hash_0_1
+
+ cfg1 = copy.deepcopy(config)
+ cfg1.num_hidden_layers -= 1
+ hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), cb).model_hash
+ cfg2 = copy.deepcopy(config)
+ cfg2.num_hidden_layers -= 1
+ hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash
+ assert hash_1_0 == hash_1_1
+
+ assert hash_0_0 != hash_1_0
+
+ if cb:
+ hash_0_no_cb = QEFFAutoModelForCausalLM(
+ AutoModelForCausalLM.from_config(config, **model_kwargs), False
+ ).model_hash
+ assert hash_0_0 != hash_0_no_cb
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_export(config, cb, tmp_path):
+ model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+ ctx_len = 2048
+ comp_ctx_lengths_prefill = [256]
+ comp_ctx_lengths_decode = [512, 1024, ctx_len]
+
+ qeff_model = QEFFAutoModelForCausalLM(
+ model,
+ cb,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ )
+ qeff_model.export(tmp_path)
+ model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash)
+ assert model_path.is_dir()
+ assert qeff_model.onnx_path.is_file()
+ assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",)
+
+ # Check if the KV-cache inputs and outputs are created
+ onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False)
+ retained_output_names = {
+ x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState")
+ }
+ assert retained_output_names.issubset({x.name for x in onnx_model.graph.input})
+
+ # Check if there is no re-export
+ start = perf_counter()
+ qeff_model.export(tmp_path)
+ end = perf_counter()
+ export_time = end - start
+ assert export_time < 2.0
+
+
+@pytest.fixture
+def tmp_cache(tmp_path, monkeypatch):
+ monkeypatch.setattr("QEfficient.base.modeling_qeff.QEFF_HOME", tmp_path)
+ yield tmp_path
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_compile(config, cb, tmp_cache):
+ model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+ ctx_len = 2048
+ comp_ctx_lengths_prefill = [256]
+ comp_ctx_lengths_decode = [512, 1024, ctx_len]
+ qeff_model = QEFFAutoModelForCausalLM(
+ model,
+ cb,
+ comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+ comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+ ctx_len=ctx_len,
+ )
+ compile_params = {"prefill_seq_len": 8, "ctx_len": ctx_len}
+ if cb:
+ compile_params["full_batch_size"] = 32
+ compile_params["batch_size"] = 8
+ qeff_model.compile(**compile_params)
+ model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash)
+
+ # Check if ONNX is exported properly
+ assert model_path.is_dir()
+ assert qeff_model.onnx_path.is_file()
+ assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",)
+
+ # Check if QPC is compiled properly
+ assert qeff_model.qpc_path.is_dir()
+ assert (qeff_model.qpc_path / "programqpc.bin").is_file()
+ assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash
+
+ # Check if there is no re-compilation
+ start = perf_counter()
+ qeff_model.compile(**compile_params)
+ end = perf_counter()
+ compile_time = end - start
+ assert compile_time < 2.0
+ assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))