Closed · Changes from all commits · 24 commits
8 changes: 8 additions & 0 deletions QEfficient/__init__.py
@@ -48,6 +48,12 @@ def check_qaic_sdk():
QEFFCommonLoader,
)
from QEfficient.compile.compile_helper import compile

# Imports for the diffusers
from QEfficient.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import QEFFStableDiffusionPipeline
from QEfficient.diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion3 import (
QEFFStableDiffusion3Pipeline,
)
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -67,6 +73,8 @@ def check_qaic_sdk():
"QEFFAutoModelForImageTextToText",
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"QEFFStableDiffusionPipeline",
"QEFFStableDiffusion3Pipeline",
]

else:
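For orientation, a minimal usage sketch of the newly exported diffusers pipelines. The from_pretrained/compile/call pattern is assumed from the existing QEFF auto classes; the checkpoint name and compile arguments are illustrative and not confirmed by this diff.

from QEfficient import QEFFStableDiffusion3Pipeline

pipeline = QEFFStableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")  # illustrative checkpoint
pipeline.compile(num_devices=4)  # hypothetical compile arguments
image = pipeline("a photo of an astronaut riding a horse").images[0]  # diffusers-style call, assumed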
165 changes: 43 additions & 122 deletions QEfficient/base/modeling_qeff.py
@@ -5,33 +5,25 @@
#
# ----------------------------------------------------------------------------

import gc
import hashlib
import inspect
import logging
import shutil
import subprocess
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional

import onnx
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import (
constants,
create_json,
create_model_params,
dump_qconfig,
export_wrapper,
generate_mdp_partition_config,
hash_dict_params,
load_json,
)
from QEfficient.utils import constants, create_json, generate_mdp_partition_config, load_json  # TODO: debug and re-enable dump_qconfig
from QEfficient.utils.cache import QEFF_HOME, to_hashable

Check failure on line 26 in QEfficient/base/modeling_qeff.py (GitHub Actions / lint): Ruff (I001): QEfficient/base/modeling_qeff.py:8:1: Import block is un-sorted or un-formatted

logger = logging.getLogger(__name__)

@@ -53,19 +45,12 @@
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]

def __init__(self, model: torch.nn.Module, **kwargs) -> None:
def __init__(self, model: torch.nn.Module) -> None:
super().__init__()
self.model = model
self.hash_params = create_model_params(self, **kwargs)
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
self.model_architecture = (
(arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0]
) or None

# Flag for checking if weights are offloaded
self._is_weights_offloaded: bool = False

# Apply the transformations
any_transformed = False
@@ -78,48 +63,14 @@
else:
logger.info(f"Pytorch transforms applied to model: {self.model_name}")

def _offload_model_weights(self, offload_pt_weights) -> bool:
"""
Clear PyTorch weights after export if offload_pt_weights is set to True

Returns:
bool: True if weights were successfully offloaded, False otherwise
"""
# Check if offloading is enabled and weights are not already offloaded
if offload_pt_weights and not self._is_weights_offloaded:
try:
self.model = self.model.to_empty(device="meta")
self._is_weights_offloaded = True
logger.info("Model weights offloaded to meta device")

gc.collect()
logger.info("PyTorch weights cleared after export")
return True

except Exception as e:
logger.error(f"Failed to offload model weights: {e}")
return False
return False

def _model_offloaded_check(self) -> None:
"""
Check if the model is in meta state or weights are offloaded.

Raises:
RuntimeError: If model is in meta state or if weights are offloaded
"""
if self._is_weights_offloaded or any(param.is_meta for param in self.model.parameters()):
error_msg = (
"Cannot re-export model: weights have been offloaded to save memory. "
"To re-export, please create a new model instance using from_pretrained() method."
)
logger.error(error_msg)
raise RuntimeError(error_msg)

@property
@abstractmethod
def model_name(self) -> str: ...

@property
@abstractmethod
def model_hash(self) -> str: ...

@abstractmethod
def export(self, export_dir: Optional[str] = None) -> Path:
"""
@@ -146,20 +97,14 @@
:mxfp6_matmul (bool): Use MXFP6 to compress weights for MatMul nodes to run faster on device. ``Defaults to False``.
:mxint8_kv_cache (bool): Use MXINT8 to compress KV-cache on device to access and update KV-cache faster. ``Defaults to False``.
:compiler_options: Pass any compiler option as input.

The following flags can be passed in compiler_options to enable the QNN compilation path.
:enable_qnn (bool): Enables QNN compilation. ``Defaults to False if not passed.``
:qnn_config (str): Path of the QNN config parameters file. ``Defaults to None if not passed.``

For the QAIC compilation path, any flag that is supported by ``qaic-exec`` can be passed. Params are converted to flags as below:

- aic_num_cores=16 -> -aic-num-cores=16
- convert_to_fp16=True -> -convert-to-fp16
- aic_hw_version=ai100 -> -aic-hw-version=ai100
- aic_hw_version=ai200 -> -aic-hw-version=ai200
The following flags can be passed in compiler_options to enable the QNN compilation path.
:enable_qnn (bool): Enables QNN compilation. ``Defaults to False if not passed.``
:qnn_config (str): Path of the QNN config parameters file. ``Defaults to None if not passed.``
For the QAIC compilation path, any flag that is supported by ``qaic-exec`` can be passed. Params are converted to flags as below:
- aic_num_cores=16 -> -aic-num-cores=16
- convert_to_fp16=True -> -convert-to-fp16

``QEFFAutoModelForCausalLM`` Args:

:full_batch_size (int): Full batch size to allocate cache lines.
:batch_size (int): Batch size to compile for. ``Defaults to 1``.
:prefill_seq_len (int): Prefill sequence length to compile for. Prompt will be chunked according to this length.
@@ -169,7 +114,6 @@
:str: Path of the compiled ``qpc`` package.
"""

@export_wrapper
def _export(
self,
example_inputs: Dict[str, torch.Tensor],
@@ -178,15 +122,9 @@
export_kwargs: Optional[Dict[str, any]] = None,
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms

This method:
1. Exports PyTorch model to ONNX using torch.onnx.export
2. Clears PyTorch weights after export
3. Applies ONNX transforms with reduced memory footprint
Export the PyTorch model to ONNX.

Args:
:example_inputs (dict): Sample inputs to trace the model.
@@ -195,30 +133,20 @@
:export_kwargs (dict): Additional arguments to be passed to `torch.onnx.export`.
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
:offload_pt_weights (bool): If True, offload PyTorch model weights to meta device
after successful export to reduce memory usage. Set to False if you need to
keep weights for further operations. Defaults to True.
Note:
Once weights are offloaded, the model cannot be re-exported. Create a new
instance using from_pretrained() for re-export.

"""
export_dir = Path(export_dir or (QEFF_HOME / self.model_name))
export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash)
onnx_path = export_dir / f"{self.model_name}.onnx"

# Return early if ONNX already exists
if onnx_path.is_file():
self.onnx_path = onnx_path
return onnx_path

# check if the model is in meta state or weights are offloaded
self._model_offloaded_check()

# Setup temporary paths
tmp_onnx_dir = export_dir / "onnx_tmp"
tmp_onnx_path = tmp_onnx_dir / f"{self.model_name}.onnx"
tmp_onnx_dir.mkdir(parents=True, exist_ok=True)

# Create input_names from example_inputs

input_names = []
for param in inspect.signature(self.model.forward).parameters:
if param in example_inputs:
@@ -251,12 +179,11 @@
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
opset_version=17,
# verbose=True,
**export_kwargs,
)
logger.info("PyTorch export successful")

_ = self._offload_model_weights(offload_pt_weights)
logger.info("Pytorch export successful")

model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
@@ -268,17 +195,17 @@

for transform in self._onnx_transforms:
model, transformed = transform.apply(model, **transform_kwargs)

model.metadata_props.append(
onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names()))
)
logger.info("ONNX transforms applied")

onnx.save(model, onnx_path)
logger.info("Transformed ONNX saved")
logger.info("Transformed onnx saved")

except Exception as e:
logger.error(f"ONNX export or transforms failed: {e}")
logger.error(f"ONNX export (or) ONNXTransforms failed: {e}")

raise e

finally:
Expand All @@ -287,7 +214,7 @@
self.onnx_path = onnx_path
return onnx_path

@dump_qconfig
# @dump_qconfig
def _compile(
self,
onnx_path: Optional[str] = None,
@@ -317,12 +244,8 @@
:qnn_config (str): Path of QNN Config parameters file. Any extra parameters for QNN compilation can be passed via this file. ``Defaults to None.``
:compiler_options: Pass any compiler option as input.
Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:

- aic_num_cores=16 -> -aic-num-cores=16
- convert_to_fp16=True -> -convert-to-fp16
- aic_hw_version=ai100 -> -aic-hw-version=ai100
- aic_hw_version=ai200 -> -aic-hw-version=ai200

For the QNN compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""
if onnx_path is None and self.onnx_path is None:
@@ -354,13 +277,7 @@

return self.qpc_path

command = (
constants.COMPILER
+ [
f"-aic-hw-version={compiler_options.pop('aic_hw_version', compiler_options.pop('aic-hw-version', constants.DEFAULT_AIC_HW_VERSION))}"
]
+ [f"-m={onnx_path}"]
)
command = constants.COMPILER + [f"-m={onnx_path}"]

if mdp_ts_json_path := compiler_options.pop("mdp_load_partition_config", None):
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
@@ -383,16 +300,23 @@
else:
mdp_ts_json = None

compile_hash_params = {
"command": command,
"specializations": specializations,
"custom_io": custom_io,
"mdp_ts_num_devices": mdp_ts_num_devices,
"mdp_ts_json": mdp_ts_json,
"num_speculative_tokens": num_speculative_tokens,
}
compile_hash = hash_dict_params(compile_hash_params)
compile_hash = hashlib.sha256(to_hashable(command))

if specializations is not None:
compile_hash.update(to_hashable(specializations))

if custom_io is not None:
compile_hash.update(to_hashable(custom_io))

if num_speculative_tokens:
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))

# Hash the MDP partition config and the number of devices.
compile_hash.update(to_hashable(mdp_ts_json))
compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))

# Check if already compiled
compile_hash = compile_hash.hexdigest()[:16]
compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
qpc_path = compile_dir / "qpc"
qpc_path.mkdir(parents=True, exist_ok=True)
@@ -429,6 +353,7 @@

command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
print(command)
try:
subprocess.run(command, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
@@ -443,10 +368,6 @@
]
)
)
# Dump JSON file with hashed parameters
hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
create_json(hashed_compile_params_path, compile_hash_params)
logger.info("Hashed parameters exported successfully.")

self.qpc_path = qpc_path

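A standalone sketch of the inline compile-hash flow shown above: the QPC directory suffix is the first 16 hex characters of a SHA-256 digest fed, in order, with the compiler command, the optional specializations, custom IO, speculative-token count, and the MDP partition settings. The _to_hashable helper below is a JSON-based stand-in for QEfficient's to_hashable (whose exact serialization is not shown in this diff), and the command and values are illustrative.

import hashlib
import json

def _to_hashable(obj) -> bytes:
    # Stand-in: stable byte serialization of a JSON-serializable object.
    return json.dumps(obj, sort_keys=True, default=str).encode()

compile_hash = hashlib.sha256(_to_hashable(["qaic-exec", "-m=model.onnx", "-aic-num-cores=16"]))
compile_hash.update(_to_hashable([{"batch_size": 1, "seq_len": 128, "ctx_len": 256}]))  # specializations
compile_hash.update(_to_hashable({"past_key.0": "mxint8"}))  # custom_io
# (num_speculative_tokens would be hashed here when set)
compile_hash.update(_to_hashable(None))  # mdp_ts_json
compile_hash.update(_to_hashable({"mdp_ts_num_devices": 1}))
suffix = compile_hash.hexdigest()[:16]  # appended to the qpc directory name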
28 changes: 28 additions & 0 deletions QEfficient/base/onnx_transforms.py
@@ -8,6 +8,7 @@
from typing import Optional, Tuple

import numpy as np
import onnxslim
from onnx import ModelProto, external_data_helper, numpy_helper


@@ -99,3 +100,30 @@ def apply(
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed


class OnnxSlimTransform:
"""
Applies onnx-slim transformations on the given ONNX graph.
"""

@classmethod
def apply(
cls,
model: ModelProto,
*,
onnx_base_dir: Optional[str] = None,
**kwargs,
) -> Tuple[ModelProto, bool]:
"""
:param onnx_base_dir: Base directory to load tensors from.
"""
transformed = False
onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False)
if onnx_slim_transform:
transformed = True
slimmed_model = onnxslim.slim(model)
# Don't save here - let the caller handle saving
return slimmed_model, transformed
return model, transformed
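As a usage note, the new transform runs through the same loop as the other ONNX transforms in modeling_qeff.py (model, transformed = transform.apply(model, **transform_kwargs)). A minimal standalone sketch of the same operation, with an illustrative file path:

import onnx
import onnxslim

# Load, slim, and save an ONNX model; this mirrors what OnnxSlimTransform.apply does in memory.
model = onnx.load("model.onnx")  # illustrative path
slimmed_model = onnxslim.slim(model)
onnx.save(slimmed_model, "model.slim.onnx")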