Commit 1317187

Adapt AutoRound v0.8.0 [For MLLM] (#2291)
Signed-off-by: Kaihui-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1277d2d commit 1317187

File tree

10 files changed (+297, -182 lines)


examples/pytorch/multimodal-modeling/quantization/auto_round/mllm.py

Lines changed: 0 additions & 1 deletion
@@ -223,7 +223,6 @@ def tune(args):
         use_auto_mapping = True
 
     woq_config = AutoRoundConfig(
-        is_vlm=True,
         bits=args.bits,
         sym=not args.asym,
         group_size=args.group_size,
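
With AutoRound v0.8.0 the example script simply drops the is_vlm flag; a minimal sketch of the resulting call is shown below (the import path is the usual INC one and the literal values stand in for the script's parsed args -- both are assumptions, not part of this commit):

# Sketch only: literal values replace args.bits / (not args.asym) / args.group_size,
# and the example's remaining keyword arguments are elided.
from neural_compressor.torch.quantization import AutoRoundConfig  # assumed import path

woq_config = AutoRoundConfig(
    bits=4,
    sym=True,
    group_size=128,
)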

neural_compressor/torch/algorithms/weight_only/autoround.py

Lines changed: 102 additions & 108 deletions
@@ -34,9 +34,9 @@ def _is_auto_round_available():
 _is_auto_round_available()
 
 from auto_round import AutoRound, AutoRoundMLLM  # pylint: disable=E0401
+from auto_round.compressors.mllm.eval import lmms_eval, mllm_eval
+from auto_round.compressors.mllm.template import Template, get_template
 from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
-from auto_round.mllm import lmms_eval, mllm_eval
-from auto_round.mllm.template import Template, get_template
 from auto_round.schemes import QuantizationScheme
 
 from neural_compressor.torch.algorithms import Quantizer
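
For quick reference, the MLLM helpers this file relies on have moved under auto_round.compressors; the imports below summarize the old-to-new mapping used in this hunk and in the get_mllm_dataloader hunk further down:

# New module locations in auto_round 0.8.0, as used in this commit:
from auto_round.compressors.mllm.eval import lmms_eval, mllm_eval        # was auto_round.mllm
from auto_round.compressors.mllm.template import Template, get_template  # was auto_round.mllm.template
from auto_round.compressors.mllm.compressor import _only_text_test       # was auto_round.mllm.autoround_mllm
from auto_round.compressors.mllm.dataset import get_mllm_dataloader      # was auto_round.mllm.mllm_dataset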
@@ -50,11 +50,24 @@ class AutoRoundQuantizer(Quantizer):
 
     def __init__(
         self,
-        quant_config: dict = {},
+        bits: int = None,
+        group_size: int = None,
+        sym: bool = None,
+        data_type: str = None,
+        act_bits: int = None,
+        act_group_size: int = None,
+        act_sym: bool = None,
+        act_data_type: str = None,
+        act_dynamic: bool = None,
+        super_bits: int = None,
+        super_group_size: int = None,
+        quant_config: dict = {},  # for INC
+        layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         enable_full_range: bool = False,  ##for symmetric, TODO support later
         batch_size: int = 8,
         amp: bool = True,
         device_map: str = None,
+        quant_lm_head: bool = False,
         lr_scheduler=None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
         enable_quanted_input: bool = True,
@@ -71,21 +84,14 @@ def __init__(
         gradient_accumulate_steps: int = 1,
         not_use_best_mse: bool = False,
         dynamic_max_gap: int = -1,
-        data_type: str = "int",
         scale_dtype: str = "fp16",
         to_quant_block_names: list = None,
-        act_bits: int = 32,
-        act_group_size: int = None,
-        act_sym: bool = None,
-        act_dynamic: bool = True,
-        act_data_type: Optional[str] = None,
         low_cpu_mem_usage: bool = False,
         export_format: str = "itrex",
         # v0.4
         enable_norm_bias_tuning: bool = False,
         enable_torch_compile: bool = None,
         # mllm
-        is_mllm: bool = False,
         quant_nontext_module: bool = False,
         extra_data_dir: str = None,
         image_processor=None,
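
The quantization scheme is now passed as explicit keyword arguments rather than only through a quant_config dict, and is_mllm is gone entirely. A hedged construction sketch (module path taken from the file header above; values are illustrative, not defaults from this commit):

from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer

quantizer = AutoRoundQuantizer(
    bits=4,
    group_size=128,
    sym=True,
    data_type="int",
    act_bits=32,           # weight-only: leave activations unquantized
    quant_lm_head=False,   # new flag: optionally quantize the lm_head layer
    quant_config={},       # still accepted; carries INC's layer-wise settings
)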
@@ -119,13 +125,15 @@ def __init__(
             bits (int): Number of bits for quantization (default is 4).
             group_size (int): Size of the quantization group (default is 128).
             sym (bool): Whether to use symmetric quantization. (default is None).
+            layer_config (dict, optional): Layer-wise quantization config. Defaults to None.
             bits (int): Number of bits for quantization (default is 4).
             group_size (int): Size of the quantization group (default is 128).
             sym (bool): Whether symmetric quantization is to be used (default is False).
             enable_full_range (bool): Whether to enable full range quantization (default is False).
             batch_size (int): Batch size for training (default is 8).
             amp (bool): Whether to use automatic mixed precision (default is True).
             device_map: The device to be used for tuning (default is None).
+            quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers. (default is False).
             lr_scheduler: The learning rate scheduler to be used.
             dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
             enable_quanted_input (bool): Whether to use the output of the previous quantized block as
@@ -155,7 +163,6 @@ def __init__(
             enable_norm_bias_tuning (bool): Whether to enable fast norm/layer_bias tuning.
             enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True.
             quant_nontext_module (bool): Whether to quantize nontext module.
-            is_mllm (bool): Indicates whether the model to be quantized is a multi-modal model (MLLM).
             extra_data_dir (str): The path for extra data such as images, audio or videos.
             processor (transformers.AutoProcessor): Any multi-modal model will require an object to encode or
                 decode the data that groups several modalities (among text, vision and audio).
@@ -170,12 +177,26 @@ def __init__(
             The quantized model.
         """
         super().__init__(quant_config)
-        self.tokenizer = "Placeholder"  # for AutoRound initialization
+        self.layer_config = layer_config
+        self.output_dir = kwargs.pop("output_dir", "temp_auto_round")
+        self.tokenizer = kwargs.pop("tokenizer", "Placeholder")  # for AutoRound initialization
         self.enable_full_range = enable_full_range
+        self.bits = bits
+        self.group_size = group_size
+        self.sym = sym
+        self.data_type = data_type
+        self.act_bits = act_bits
+        self.act_group_size = act_group_size
+        self.act_sym = act_sym
+        self.act_data_type = act_data_type
+        self.act_dynamic = act_dynamic
+        self.super_bits = super_bits
+        self.super_group_size = super_group_size
         self.batch_size = batch_size
         self.amp = amp
         self.device = get_accelerator(kwargs.pop("device", "auto")).name()
         self.lr_scheduler = lr_scheduler
+        self.dataset = dataset
         self.enable_quanted_input = enable_quanted_input
         self.enable_minmax_tuning = enable_minmax_tuning
         self.lr = lr
@@ -190,19 +211,12 @@ def __init__(
         self.gradient_accumulate_steps = gradient_accumulate_steps
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
-        self.data_type = data_type
         self.scale_dtype = scale_dtype
         self.to_quant_block_names = to_quant_block_names
-        self.act_bits = act_bits
-        self.act_group_size = act_group_size
-        self.act_sym = act_sym
-        self.act_dynamic = act_dynamic
-        self.act_data_type = act_data_type
         self.low_cpu_mem_usage = low_cpu_mem_usage
         self.export_format = export_format
         self.enable_norm_bias_tuning = enable_norm_bias_tuning
         self.enable_torch_compile = enable_torch_compile
-        self.is_mllm = is_mllm
         self.quant_nontext_module = quant_nontext_module
         self.extra_data_dir = extra_data_dir
         self.processor = processor
@@ -211,9 +225,10 @@ def __init__(
         self.truncation = truncation
         self.scheme = scheme
         self.device_map = device_map
+        self.quant_lm_head = quant_lm_head
         self.enable_w4afp8 = self._is_w4afp8()
 
-    def _is_w4afp8(self):
+    def _is_w4afp8(self) -> bool:
         return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
 
     def prepare(self, model: torch.nn.Module, *args, **kwargs):
@@ -237,96 +252,75 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
         Returns:
             The quantized model.
         """
-        dataloader = CapturedDataloader(model.args_list, model.kwargs_list)
-        model = model.orig_model
-        if self.is_mllm:
-            rounder = AutoRoundMLLM(
-                model,
-                tokenizer=self.tokenizer,
-                scheme=self.scheme,
-                processor=self.processor,
-                image_processor=self.image_processor,
-                layer_config=self.quant_config,
-                batch_size=self.batch_size,
-                amp=self.amp,
-                device_map=self.device_map,
-                lr_scheduler=self.lr_scheduler,
-                dataset=dataloader,
-                extra_data_dir=self.extra_data_dir,
-                template=self.template,
-                quant_nontext_module=self.quant_nontext_module,
-                enable_quanted_input=self.enable_quanted_input,
-                enable_minmax_tuning=self.enable_minmax_tuning,
-                lr=self.lr,
-                minmax_lr=self.minmax_lr,
-                low_gpu_mem_usage=self.low_gpu_mem_usage,
-                low_cpu_mem_usage=self.low_gpu_mem_usage,
-                iters=self.iters,
-                seqlen=self.seqlen,
-                nsamples=self.nsamples,
-                sampler=self.sampler,
-                seed=self.seed,
-                nblocks=self.nblocks,
-                gradient_accumulate_steps=self.gradient_accumulate_steps,
-                not_use_best_mse=self.not_use_best_mse,
-                dynamic_max_gap=self.dynamic_max_gap,
-                data_type=self.data_type,
-                scale_dtype=self.scale_dtype,
-                act_bits=self.act_bits,
-                act_group_size=self.act_group_size,
-                act_sym=self.act_sym,
-                act_dynamic=self.act_dynamic,
-                to_quant_block_names=self.to_quant_block_names,
-                enable_norm_bias_tuning=self.enable_norm_bias_tuning,
-                truncation=self.truncation,
-                enable_torch_compile=self.enable_torch_compile,
-            )
+        tokenizer = getattr(model.orig_model, "tokenizer", None)
+        if tokenizer is not None:
+            delattr(model.orig_model, "tokenizer")
         else:
-            rounder = AutoRound(
-                model=model,
-                tokenizer=self.tokenizer,
-                scheme=self.scheme,
-                dataset=dataloader,
-                layer_config=self.quant_config or {},
-                enable_full_range=self.enable_full_range,
-                batch_size=self.batch_size,
-                amp=self.amp,
-                device_map=self.device_map,
-                lr_scheduler=self.lr_scheduler,
-                enable_quanted_input=self.enable_quanted_input,
-                enable_minmax_tuning=self.enable_minmax_tuning,
-                lr=self.lr,
-                minmax_lr=self.minmax_lr,
-                low_gpu_mem_usage=self.low_gpu_mem_usage,
-                iters=self.iters,
-                seqlen=self.seqlen,
-                nsamples=self.nsamples,
-                sampler=self.sampler,
-                seed=self.seed,
-                nblocks=self.nblocks,
-                gradient_accumulate_steps=self.gradient_accumulate_steps,
-                not_use_best_mse=self.not_use_best_mse,
-                dynamic_max_gap=self.dynamic_max_gap,
-                data_type=self.data_type,
-                scale_dtype=self.scale_dtype,
-                to_quant_block_names=self.to_quant_block_names,
-                act_bits=self.act_bits,
-                act_group_size=self.act_group_size,
-                act_sym=self.act_sym,
-                act_dynamic=self.act_dynamic,
-                low_cpu_mem_usage=self.low_cpu_mem_usage,
-                enable_norm_bias_tuning=self.enable_norm_bias_tuning,
-                enable_torch_compile=self.enable_torch_compile,
-            )
-        model, weight_config = rounder.quantize()
-        model.autoround_config = weight_config
+            tokenizer = "Placeholder"
+        self.dataset = CapturedDataloader(model.args_list, model.kwargs_list)
+        model = model.orig_model
+        rounder = AutoRound(
+            model,
+            layer_config=self.layer_config,
+            bits=self.bits,
+            data_type=self.data_type,
+            group_size=self.group_size,
+            sym=self.sym,
+            act_bits=self.act_bits,
+            act_group_size=self.act_group_size,
+            act_sym=self.act_sym,
+            act_data_type=self.act_data_type,
+            act_dynamic=self.act_dynamic,
+            super_bits=self.super_bits,
+            super_group_size=self.super_group_size,
+            tokenizer=tokenizer,
+            scheme=self.scheme,
+            processor=self.processor,
+            image_processor=self.image_processor,
+            enable_full_range=self.enable_full_range,
+            batch_size=self.batch_size,
+            amp=self.amp,
+            device_map=self.device_map,
+            lr_scheduler=self.lr_scheduler,
+            dataset=self.dataset,
+            extra_data_dir=self.extra_data_dir,
+            template=self.template,
+            quant_nontext_module=self.quant_nontext_module,
+            enable_quanted_input=self.enable_quanted_input,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            lr=self.lr,
+            minmax_lr=self.minmax_lr,
+            low_gpu_mem_usage=self.low_gpu_mem_usage,
+            low_cpu_mem_usage=self.low_gpu_mem_usage,
+            iters=self.iters,
+            seqlen=self.seqlen,
+            nsamples=self.nsamples,
+            sampler=self.sampler,
+            seed=self.seed,
+            nblocks=self.nblocks,
+            gradient_accumulate_steps=self.gradient_accumulate_steps,
+            not_use_best_mse=self.not_use_best_mse,
+            dynamic_max_gap=self.dynamic_max_gap,
+            scale_dtype=self.scale_dtype,
+            to_quant_block_names=self.to_quant_block_names,
+            enable_norm_bias_tuning=self.enable_norm_bias_tuning,
+            truncation=self.truncation,
+            enable_torch_compile=self.enable_torch_compile,
+            quant_lm_head=self.quant_lm_head,
+        )
+
         if self.enable_w4afp8:
-            return rounder.save_quantized(output_dir="temp_auto_round", inplace=True)
+            model, weight_config = rounder.quantize()
+            model.autoround_config = weight_config
+            return rounder.save_quantized(output_dir=self.output_dir, inplace=True)
         elif "itrex" in self.export_format:
+            model, weight_config = rounder.quantize()
+            model.autoround_config = weight_config
             model = pack_model(model, weight_config, device=self.device, inplace=True)
         else:  # pragma: no cover
-            model = rounder.save_quantized(output_dir="temp_auto_round", format=self.export_format, inplace=True)
-
+            rounder.quantize_and_save(output_dir=self.output_dir, format=self.export_format, inplace=True)
+            model = rounder.model
+            model.autoround_config = rounder.layer_config
         return model
 
 
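convert() now builds a single AutoRound instance for both text-only and multimodal models, picks up a tokenizer if the caller attached one to the captured model, and branches only on the export path (w4afp8, itrex, or quantize_and_save). A hedged end-to-end usage sketch via INC's 3.x PyTorch entry points follows; the prepare/convert names and the model/calib_dataloader objects are assumptions, not part of this diff:

from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

quant_config = AutoRoundConfig(bits=4, group_size=128, sym=True)
model = prepare(model, quant_config)   # wraps the model and records calibration inputs
for batch in calib_dataloader:         # user-supplied calibration data
    model(**batch)                     # forward passes feed CapturedDataloader
model = convert(model)                 # runs the single AutoRound path shown above
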
@@ -389,8 +383,8 @@ def get_mllm_dataloader(
         DataLoader: The DataLoader for the calibrated datasets.
     """
     from auto_round.calib_dataset import CALIB_DATASETS
-    from auto_round.mllm.autoround_mllm import _only_text_test
-    from auto_round.mllm.mllm_dataset import get_mllm_dataloader  # pylint: disable=E0401
+    from auto_round.compressors.mllm.compressor import _only_text_test
+    from auto_round.compressors.mllm.dataset import get_mllm_dataloader  # pylint: disable=E0401
 
     template = template if template is not None else model.config.model_type
     template = get_template(
@@ -424,7 +418,7 @@ def get_mllm_dataloader(
         nsamples = (nsamples // batch_size + 1) * batch_size
         logger.warning(f"'nsamples' is not divisible by 'batch_size', will adjusted to {nsamples}")
 
-    dataloader, batch_size, gradient_accumulate_steps = get_mllm_dataloader(
+    dataloader, batch_size, seqlen, gradient_accumulate_steps = get_mllm_dataloader(
         template=template,
         processor=processor,
         model=model,
