@@ -34,9 +34,9 @@ def _is_auto_round_available():
 _is_auto_round_available()
 
 from auto_round import AutoRound, AutoRoundMLLM  # pylint: disable=E0401
+from auto_round.compressors.mllm.eval import lmms_eval, mllm_eval
+from auto_round.compressors.mllm.template import Template, get_template
 from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
-from auto_round.mllm import lmms_eval, mllm_eval
-from auto_round.mllm.template import Template, get_template
 from auto_round.schemes import QuantizationScheme
 
 from neural_compressor.torch.algorithms import Quantizer
@@ -50,11 +50,24 @@ class AutoRoundQuantizer(Quantizer):
 
     def __init__(
         self,
-        quant_config: dict = {},
+        bits: int = None,
+        group_size: int = None,
+        sym: bool = None,
+        data_type: str = None,
+        act_bits: int = None,
+        act_group_size: int = None,
+        act_sym: bool = None,
+        act_data_type: str = None,
+        act_dynamic: bool = None,
+        super_bits: int = None,
+        super_group_size: int = None,
+        quant_config: dict = {},  # for INC
+        layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         enable_full_range: bool = False,  ##for symmetric, TODO support later
         batch_size: int = 8,
         amp: bool = True,
         device_map: str = None,
+        quant_lm_head: bool = False,
         lr_scheduler=None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
         enable_quanted_input: bool = True,
@@ -71,21 +84,14 @@ def __init__(
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
-        data_type: str = "int",
         scale_dtype: str = "fp16",
         to_quant_block_names: list = None,
-        act_bits: int = 32,
-        act_group_size: int = None,
-        act_sym: bool = None,
-        act_dynamic: bool = True,
-        act_data_type: Optional[str] = None,
         low_cpu_mem_usage: bool = False,
         export_format: str = "itrex",
         # v0.4
         enable_norm_bias_tuning: bool = False,
         enable_torch_compile: bool = None,
         # mllm
-        is_mllm: bool = False,
         quant_nontext_module: bool = False,
         extra_data_dir: str = None,
         image_processor=None,
@@ -119,13 +125,15 @@ def __init__(
             bits (int): Number of bits for quantization (default is 4).
             group_size (int): Size of the quantization group (default is 128).
             sym (bool): Whether to use symmetric quantization. (default is None).
+            layer_config (dict, optional): Layer-wise quantization config. Defaults to None.
             bits (int): Number of bits for quantization (default is 4).
             group_size (int): Size of the quantization group (default is 128).
             sym (bool): Whether symmetric quantization is to be used (default is False).
             enable_full_range (bool): Whether to enable full range quantization (default is False).
             batch_size (int): Batch size for training (default is 8).
             amp (bool): Whether to use automatic mixed precision (default is True).
             device_map: The device to be used for tuning (default is None).
+            quant_lm_head (bool): Whether to quantize the lm_head layer in transformers (default is False).
             lr_scheduler: The learning rate scheduler to be used.
             dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
             enable_quanted_input (bool): Whether to use the output of the previous quantized block as
@@ -155,7 +163,6 @@ def __init__(
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
             quant_nontext_module (bool): Whether to quantize nontext module.
-            is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
                 decode the data that groups several modalities (among text, vision and audio).
@@ -170,12 +177,26 @@ def __init__(
             The quantized model.
         """
         super().__init__(quant_config)
-        self.tokenizer = "Placeholder"  # for AutoRound initialization
+        self.layer_config = layer_config
+        self.output_dir = kwargs.pop("output_dir", "temp_auto_round")
+        self.tokenizer = kwargs.pop("tokenizer", "Placeholder")  # for AutoRound initialization
         self.enable_full_range = enable_full_range
+        self.bits = bits
+        self.group_size = group_size
+        self.sym = sym
+        self.data_type = data_type
+        self.act_bits = act_bits
+        self.act_group_size = act_group_size
+        self.act_sym = act_sym
+        self.act_data_type = act_data_type
+        self.act_dynamic = act_dynamic
+        self.super_bits = super_bits
+        self.super_group_size = super_group_size
         self.batch_size = batch_size
         self.amp = amp
         self.device = get_accelerator(kwargs.pop("device", "auto")).name()
         self.lr_scheduler = lr_scheduler
+        self.dataset = dataset
         self.enable_quanted_input = enable_quanted_input
         self.enable_minmax_tuning = enable_minmax_tuning
         self.lr = lr
@@ -190,19 +211,12 @@ def __init__(
         self.gradient_accumulate_steps = gradient_accumulate_steps
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
-        self.data_type = data_type
         self.scale_dtype = scale_dtype
         self.to_quant_block_names = to_quant_block_names
-        self.act_bits = act_bits
-        self.act_group_size = act_group_size
-        self.act_sym = act_sym
-        self.act_dynamic = act_dynamic
-        self.act_data_type = act_data_type
         self.low_cpu_mem_usage = low_cpu_mem_usage
         self.export_format = export_format
         self.enable_norm_bias_tuning = enable_norm_bias_tuning
         self.enable_torch_compile = enable_torch_compile
-        self.is_mllm = is_mllm
         self.quant_nontext_module = quant_nontext_module
         self.extra_data_dir = extra_data_dir
         self.processor = processor
@@ -211,9 +225,10 @@ def __init__(
         self.truncation = truncation
         self.scheme = scheme
         self.device_map = device_map
+        self.quant_lm_head = quant_lm_head
         self.enable_w4afp8 = self._is_w4afp8()
 
-    def _is_w4afp8(self):
+    def _is_w4afp8(self) -> bool:
         return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
 
     def prepare(self, model: torch.nn.Module, *args, **kwargs):
@@ -237,96 +252,75 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         Returns:
             The quantized model.
         """
-        dataloader = CapturedDataloader(model.args_list, model.kwargs_list)
-        model = model.orig_model
-        if self.is_mllm:
-            rounder = AutoRoundMLLM(
-                model,
-                tokenizer=self.tokenizer,
-                scheme=self.scheme,
-                processor=self.processor,
-                image_processor=self.image_processor,
-                layer_config=self.quant_config,
-                batch_size=self.batch_size,
-                amp=self.amp,
-                device_map=self.device_map,
-                lr_scheduler=self.lr_scheduler,
-                dataset=dataloader,
-                extra_data_dir=self.extra_data_dir,
-                template=self.template,
-                quant_nontext_module=self.quant_nontext_module,
-                enable_quanted_input=self.enable_quanted_input,
-                enable_minmax_tuning=self.enable_minmax_tuning,
-                lr=self.lr,
-                minmax_lr=self.minmax_lr,
-                low_gpu_mem_usage=self.low_gpu_mem_usage,
-                low_cpu_mem_usage=self.low_gpu_mem_usage,
-                iters=self.iters,
-                seqlen=self.seqlen,
-                nsamples=self.nsamples,
-                sampler=self.sampler,
-                seed=self.seed,
-                nblocks=self.nblocks,
-                gradient_accumulate_steps=self.gradient_accumulate_steps,
-                not_use_best_mse=self.not_use_best_mse,
-                dynamic_max_gap=self.dynamic_max_gap,
-                data_type=self.data_type,
-                scale_dtype=self.scale_dtype,
-                act_bits=self.act_bits,
-                act_group_size=self.act_group_size,
-                act_sym=self.act_sym,
-                act_dynamic=self.act_dynamic,
-                to_quant_block_names=self.to_quant_block_names,
-                enable_norm_bias_tuning=self.enable_norm_bias_tuning,
-                truncation=self.truncation,
-                enable_torch_compile=self.enable_torch_compile,
-            )
+        tokenizer = getattr(model.orig_model, "tokenizer", None)
+        if tokenizer is not None:
+            delattr(model.orig_model, "tokenizer")
         else:
-            rounder = AutoRound(
-                model=model,
-                tokenizer=self.tokenizer,
-                scheme=self.scheme,
-                dataset=dataloader,
-                layer_config=self.quant_config or {},
-                enable_full_range=self.enable_full_range,
-                batch_size=self.batch_size,
-                amp=self.amp,
-                device_map=self.device_map,
-                lr_scheduler=self.lr_scheduler,
-                enable_quanted_input=self.enable_quanted_input,
-                enable_minmax_tuning=self.enable_minmax_tuning,
-                lr=self.lr,
-                minmax_lr=self.minmax_lr,
-                low_gpu_mem_usage=self.low_gpu_mem_usage,
-                iters=self.iters,
-                seqlen=self.seqlen,
-                nsamples=self.nsamples,
-                sampler=self.sampler,
-                seed=self.seed,
-                nblocks=self.nblocks,
-                gradient_accumulate_steps=self.gradient_accumulate_steps,
-                not_use_best_mse=self.not_use_best_mse,
-                dynamic_max_gap=self.dynamic_max_gap,
-                data_type=self.data_type,
-                scale_dtype=self.scale_dtype,
-                to_quant_block_names=self.to_quant_block_names,
-                act_bits=self.act_bits,
-                act_group_size=self.act_group_size,
-                act_sym=self.act_sym,
-                act_dynamic=self.act_dynamic,
-                low_cpu_mem_usage=self.low_cpu_mem_usage,
-                enable_norm_bias_tuning=self.enable_norm_bias_tuning,
-                enable_torch_compile=self.enable_torch_compile,
-            )
-        model, weight_config = rounder.quantize()
-        model.autoround_config = weight_config
+            tokenizer = "Placeholder"
+        self.dataset = CapturedDataloader(model.args_list, model.kwargs_list)
+        model = model.orig_model
+        rounder = AutoRound(
+            model,
+            layer_config=self.layer_config,
+            bits=self.bits,
+            data_type=self.data_type,
+            group_size=self.group_size,
+            sym=self.sym,
+            act_bits=self.act_bits,
+            act_group_size=self.act_group_size,
+            act_sym=self.act_sym,
+            act_data_type=self.act_data_type,
+            act_dynamic=self.act_dynamic,
+            super_bits=self.super_bits,
+            super_group_size=self.super_group_size,
+            tokenizer=tokenizer,
+            scheme=self.scheme,
+            processor=self.processor,
+            image_processor=self.image_processor,
+            enable_full_range=self.enable_full_range,
+            batch_size=self.batch_size,
+            amp=self.amp,
+            device_map=self.device_map,
+            lr_scheduler=self.lr_scheduler,
+            dataset=self.dataset,
+            extra_data_dir=self.extra_data_dir,
+            template=self.template,
+            quant_nontext_module=self.quant_nontext_module,
+            enable_quanted_input=self.enable_quanted_input,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            lr=self.lr,
+            minmax_lr=self.minmax_lr,
+            low_gpu_mem_usage=self.low_gpu_mem_usage,
+            low_cpu_mem_usage=self.low_gpu_mem_usage,
+            iters=self.iters,
+            seqlen=self.seqlen,
+            nsamples=self.nsamples,
+            sampler=self.sampler,
+            seed=self.seed,
+            nblocks=self.nblocks,
+            gradient_accumulate_steps=self.gradient_accumulate_steps,
+            not_use_best_mse=self.not_use_best_mse,
+            dynamic_max_gap=self.dynamic_max_gap,
+            scale_dtype=self.scale_dtype,
+            to_quant_block_names=self.to_quant_block_names,
+            enable_norm_bias_tuning=self.enable_norm_bias_tuning,
+            truncation=self.truncation,
+            enable_torch_compile=self.enable_torch_compile,
+            quant_lm_head=self.quant_lm_head,
+        )
+
         if self.enable_w4afp8:
-            return rounder.save_quantized(output_dir="temp_auto_round", inplace=True)
+            model, weight_config = rounder.quantize()
+            model.autoround_config = weight_config
+            return rounder.save_quantized(output_dir=self.output_dir, inplace=True)
         elif "itrex" in self.export_format:
+            model, weight_config = rounder.quantize()
+            model.autoround_config = weight_config
             model = pack_model(model, weight_config, device=self.device, inplace=True)
         else:  # pragma: no cover
-            model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=True)
-
+            rounder.quantize_and_save(output_dir=self.output_dir, format=self.export_format, inplace=True)
+            model = rounder.model
+            model.autoround_config = rounder.layer_config
         return model
 
 
@@ -389,8 +383,8 @@ def get_mllm_dataloader(
         DataLoader: The DataLoader for the calibrated datasets.
     """
     from auto_round.calib_dataset import CALIB_DATASETS
-    from auto_round.mllm.autoround_mllm import _only_text_test
-    from auto_round.mllm.mllm_dataset import get_mllm_dataloader  # pylint: disable=E0401
+    from auto_round.compressors.mllm.compressor import _only_text_test
+    from auto_round.compressors.mllm.dataset import get_mllm_dataloader  # pylint: disable=E0401
 
     template = template if template is not None else model.config.model_type
     template = get_template(
@@ -424,7 +418,7 @@ def get_mllm_dataloader(
         nsamples = (nsamples // batch_size + 1) * batch_size
         logger.warning(f"'nsamples' is not divisible by 'batch_size', will adjusted to {nsamples}")
 
-    dataloader, batch_size, gradient_accumulate_steps = get_mllm_dataloader(
+    dataloader, batch_size, seqlen, gradient_accumulate_steps = get_mllm_dataloader(
         template=template,
         processor=processor,
         model=model,
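
For context, a minimal usage sketch of the quantizer after this refactor (not part of the commit). The import path is an assumption based on this repository's layout, "facebook/opt-125m" is only a placeholder model, and the calibration step mirrors the prepare/convert flow and the getattr(model.orig_model, "tokenizer", None) lookup shown above:

# Sketch only: assumes AutoRoundQuantizer is importable from the path below.
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer

model_name = "facebook/opt-125m"  # placeholder model for illustration
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# convert() now fetches the tokenizer from model.orig_model, so attach it
# to the model before prepare() wraps it.
model.tokenizer = tokenizer

# Quantization knobs are now explicit keyword arguments instead of a packed
# quant_config dict; quant_lm_head and output_dir are new in this change.
quantizer = AutoRoundQuantizer(
    bits=4,
    group_size=128,
    sym=True,
    data_type="int",
    iters=200,
    quant_lm_head=False,
    export_format="itrex",
    output_dir="temp_auto_round",  # consumed via kwargs.pop() in __init__
)

# prepare() wraps the model so calibration batches are captured; convert()
# replays them through AutoRound via the CapturedDataloader shown above.
model = quantizer.prepare(model)
calib = tokenizer("AutoRound calibration sample.", return_tensors="pt")
model(calib["input_ids"])
q_model = quantizer.convert(model)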