@@ -28,6 +28,7 @@
     initialize_module_for_quantization,
     is_attention_module,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -128,21 +129,11 @@ def apply_quantization_config(
     # force zero points during initialization
     force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED
 
-    # apply kv cache quantization before any attention quantization
-    # because attention quantization is a superset of kv cache quantization
+    # apply and initialize kv cache quantization
     if config.kv_cache_scheme is not None:
-        scheme = QuantizationScheme(
-            targets=[".*self_attn$"], input_activations=config.kv_cache_scheme
+        _apply_kv_cache_scheme(
+            model, config.kv_cache_scheme, config.quantization_status, force_zero_point
         )
-        for submodule in model.modules():
-            if is_attention_module(submodule):
-                submodule.quantization_scheme = scheme
-                initialize_hooked_kv_cache(model, submodule)
-                initialize_module_for_quantization(
-                    submodule,
-                    force_zero_point=force_zero_point,
-                )
-                submodule.quantization_status = config.quantization_status
 
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
@@ -191,6 +182,29 @@ def apply_quantization_config(
         submodule.quantization_status = config.quantization_status
 
 
+def _apply_kv_cache_scheme(
+    model: torch.nn.Module,
+    kv_cache_scheme: QuantizationArgs,
+    status: QuantizationStatus,
+    force_zero_point: bool,
+):
+    # applies and initializes kv cache quantization
+    # this step cannot come after attention apply/initialize
+    # otherwise it will override the attention qparams
+    scheme = QuantizationScheme(
+        targets=[".*self_attn$"], input_activations=kv_cache_scheme
+    )
+    for submodule in model.modules():
+        if is_attention_module(submodule):
+            submodule.quantization_scheme = scheme
+            initialize_hooked_kv_cache(model, submodule)
+            initialize_module_for_quantization(
+                submodule,
+                force_zero_point=force_zero_point,
+            )
+            submodule.quantization_status = status
+
+
 def _load_quant_args_from_mapping(
     base_name: str, module_name: str, module: Module, mapping: Dict
 ):
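For context, a minimal sketch of how this refactored path is exercised end to end. It is illustrative only: the checkpoint path, the config_groups contents, and the QuantizationArgs values below are assumptions, not part of this change; the point is that a config carrying kv_cache_scheme now routes through _apply_kv_cache_scheme before the per-layer schemes are attached.

# Illustrative sketch only; checkpoint path and scheme values are assumptions.
from transformers import AutoModelForCausalLM

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

# any attention-bearing model; the checkpoint path here is a placeholder
model = AutoModelForCausalLM.from_pretrained("path/to/llama-style-checkpoint")

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, type="int", symmetric=True),
        ),
    },
    # handled by _apply_kv_cache_scheme: attaches the scheme to each self_attn
    # module, hooks the kv cache, and initializes its qparams before the
    # config_groups schemes are applied
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float"),
)

apply_quantization_config(model, config)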