@@ -1,22 +1,17 @@
-import inspect
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional
 
 import torch
 from compressed_tensors.quantization import (
     DynamicType,
-    KVCacheScaleType,
     QuantizationArgs,
-    QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
-from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from torch.nn import Module
 
-from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
@@ -25,13 +20,13 @@
     "update_weight_zp_scale",
     "calibrate_input_hook",
     "calibrate_output_hook",
-    "calibrate_kv_cache_input_hook",
-    "calibrate_kv_cache_output_hook",
-    "initialize_quantized_kv_cache",
     "freeze_module_quantization",
     "apply_calibration_status",
    "reset_quantization_status",
     "update_weight_global_scale",
+    "calibrate_query_hook",
+    "calibrate_key_hook",
+    "calibrate_value_hook",
 ]
 
 
@@ -151,8 +146,9 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     if value.numel() == 0:
         return
 
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    quantization_args = getattr(quantization_scheme, f"{base_name}_activations", None)
+    field_name = "input" if base_name != "output" else "output"  # input,q,k,v,output
+    args_attr = f"quantization_scheme.{field_name}_activations"
+    quantization_args = getattr_chain(module, args_attr, None)
 
     calculate_qparams = True
     calculate_gparam = False
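
The new lines in this hunk route the `q`, `k`, `v`, and plain `input` base names to the same `input_activations` quantization args, while only `output` resolves to `output_activations`. Below is a minimal, self-contained sketch of that resolution; the `getattr_chain` helper here is a simplified stand-in for `llmcompressor.utils.helpers.getattr_chain`, written only for illustration:

```python
from types import SimpleNamespace
from typing import Any, Optional


def getattr_chain(obj: Any, dotted_name: str, default: Any = None) -> Any:
    # simplified stand-in: walk a dotted attribute path, falling back to default
    for attr in dotted_name.split("."):
        obj = getattr(obj, attr, None)
        if obj is None:
            return default
    return obj


def resolve_activation_args(module: Any, base_name: str) -> Optional[Any]:
    # "q", "k", "v", and "input" share input_activations; only "output"
    # resolves to output_activations
    field_name = "input" if base_name != "output" else "output"
    return getattr_chain(module, f"quantization_scheme.{field_name}_activations", None)


# toy module carrying a quantization scheme with input/output activation args
module = SimpleNamespace(
    quantization_scheme=SimpleNamespace(
        input_activations="int8-args", output_activations=None
    )
)
assert resolve_activation_args(module, "q") == "int8-args"
assert resolve_activation_args(module, "output") is None
```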
@@ -202,60 +198,16 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
     return output
 
 
-def calibrate_kv_cache_input_hook(
-    module: Module, args: Any, kwargs: Dict[str, Any]
-) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
-    """
-    Hook to update inputs to attention layers when running
-    kv_cache quantization. Will update the passed in
-    kv_cache to singleton QuantizedKVParameterCache.
-    """
-    kv_cache = getattr(module, "kv_cache")
-    if not hasattr(module, "_past_kv_name"):
-        # Determine which past KV parameter name to use once and cache it
-        # TODO: Find a better place to cache this
-        module._past_kv_name = (
-            "past_key_value"  # transformers#39956
-            if "past_key_value" in inspect.signature(module.forward).parameters
-            else "past_key_values"
-        )
-
-    kwargs[module._past_kv_name] = kv_cache
-    kwargs["use_cache"] = False
-    return args, kwargs
+def calibrate_query_hook(module: Module, query_states: torch.Tensor):
+    calibrate_activations(module, query_states, base_name="q")
 
 
-def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
-    """
-    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
-    """
-    kv_cache = getattr(module, "kv_cache")
-    k_scale = kv_cache.k_scales[module.layer_idx]
-    v_scale = kv_cache.v_scales[module.layer_idx]
-    update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
-    update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)
+def calibrate_key_hook(module: Module, key_states: torch.Tensor):
+    calibrate_activations(module, key_states, base_name="k")
 
 
-def initialize_quantized_kv_cache(module: Module):
-    """
-    Initialize a quantized kv_cache on a module (analogous to initializing an observer)
-    When a config specifying kv_cache quantization is applied to a model, the kv_cache
-    args are redefined as the output_activations targeting attention modules.
-
-    This function should be called on attention modules with output_activations
-    """
-    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
-    existing_kv_cache = getattr(module, "kv_cache", None)
-
-    if (
-        scheme is None
-        or not is_kv_cache_quant_scheme(scheme)
-        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
-    ):
-        return
-
-    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
-    setattr(module, "kv_cache", quantized_kv_cache)
+def calibrate_value_hook(module: Module, value_states: torch.Tensor):
+    calibrate_activations(module, value_states, base_name="v")
 
 
 def apply_calibration_status(module: Module):
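
The three hooks added above replace the KV-cache input/output hooks with per-tensor calibration of query, key, and value states; how they get attached is handled elsewhere in the modifier. The sketch below is only a hypothetical wiring, assuming an attention block that exposes `q_proj`/`k_proj`/`v_proj` submodules and already carries a `quantization_scheme`, and assuming this file imports as `llmcompressor.modifiers.quantization.calibration`; the real integration point (e.g. post-RoPE states) may differ:

```python
# Hypothetical wiring sketch only: submodule names and the import path below
# are assumptions, not the library's actual registration mechanism.
from torch import nn
from torch.utils.hooks import RemovableHandle

from llmcompressor.modifiers.quantization.calibration import (
    calibrate_key_hook,
    calibrate_query_hook,
    calibrate_value_hook,
)


def attach_qkv_calibration(attn_module: nn.Module) -> list[RemovableHandle]:
    """Observe the outputs of an attention block's q/k/v projections during a
    calibration pass, assuming ``attn_module`` carries a quantization_scheme."""
    handles = []
    for proj_name, calib_fn in (
        ("q_proj", calibrate_query_hook),
        ("k_proj", calibrate_key_hook),
        ("v_proj", calibrate_value_hook),
    ):
        proj = getattr(attn_module, proj_name)

        def _hook(_mod, _inputs, output, fn=calib_fn):
            # observe only; return None so the projection output is unchanged
            fn(attn_module, output)

        handles.append(proj.register_forward_hook(_hook))
    return handles
```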
@@ -284,16 +236,11 @@ def freeze_module_quantization(module: Module):
         return
 
     # remove observers
-    for name in ("input", "weight", "output"):
+    for name in ("input", "weight", "output", "q", "k", "v"):
         obs_name = f"{name}_observer"
         if hasattr(module, obs_name):
             delattr(module, obs_name)
 
-    # remove quantized kv_cache
-    kv_cache = getattr(module, "kv_cache", None)
-    if isinstance(kv_cache, QuantizedKVParameterCache):
-        delattr(module, "kv_cache")
-
     module.quantization_status = QuantizationStatus.FROZEN
 
 