
Commit bfa1cd1

break out function
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8995d84

File tree

1 file changed: +27 −13

  • src/compressed_tensors/quantization/lifecycle/apply.py

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 27 additions & 13 deletions
@@ -28,6 +28,7 @@
     initialize_module_for_quantization,
     is_attention_module,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -128,21 +129,11 @@ def apply_quantization_config(
     # force zero points during initialization
     force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED

-    # apply kv cache quantization before any attention quantization
-    # because attention quantization is a superset of kv cache quantization
+    # apply and initialize kv cache quantization
     if config.kv_cache_scheme is not None:
-        scheme = QuantizationScheme(
-            targets=[".*self_attn$"], input_activations=config.kv_cache_scheme
+        _apply_kv_cache_scheme(
+            model, config.kv_cache_scheme, config.quantization_status, force_zero_point
         )
-        for submodule in model.modules():
-            if is_attention_module(submodule):
-                submodule.quantization_scheme = scheme
-                initialize_hooked_kv_cache(model, submodule)
-                initialize_module_for_quantization(
-                    submodule,
-                    force_zero_point=force_zero_point,
-                )
-                submodule.quantization_status = config.quantization_status

     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
@@ -191,6 +182,29 @@ def apply_quantization_config(
         submodule.quantization_status = config.quantization_status


+def _apply_kv_cache_scheme(
+    model: torch.nn.Module,
+    kv_cache_scheme: QuantizationArgs,
+    status: QuantizationStatus,
+    force_zero_point: bool,
+):
+    # applies and initializes kv cache quantization
+    # this step cannot come after attention apply/initialize
+    # otherwise it will override the attention qparams
+    scheme = QuantizationScheme(
+        targets=[".*self_attn$"], input_activations=kv_cache_scheme
+    )
+    for submodule in model.modules():
+        if is_attention_module(submodule):
+            submodule.quantization_scheme = scheme
+            initialize_hooked_kv_cache(model, submodule)
+            initialize_module_for_quantization(
+                submodule,
+                force_zero_point=force_zero_point,
+            )
+            submodule.quantization_status = status
+
+
 def _load_quant_args_from_mapping(
     base_name: str, module_name: str, module: Module, mapping: Dict
 ):
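For context, the broken-out _apply_kv_cache_scheme helper is only reached through apply_quantization_config when a kv_cache_scheme is set on the config. Below is a minimal sketch of that call path, not part of this commit: it assumes the public compressed-tensors entry points (QuantizationConfig, QuantizationArgs, apply_quantization_config) are importable from compressed_tensors.quantization, and the model name and FP8 scheme values are illustrative placeholders.

# sketch only: exercises the new kv cache path in apply_quantization_config
from transformers import AutoModelForCausalLM

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    apply_quantization_config,
)

# placeholder checkpoint; any decoder-style model with *.self_attn modules works
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

# kv_cache_scheme is a QuantizationArgs; apply_quantization_config hands it to
# _apply_kv_cache_scheme, which targets ".*self_attn$" and, per the commit,
# must run before attention apply/initialize so it does not override the
# attention qparams
config = QuantizationConfig(
    config_groups={},
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", symmetric=True),
)
apply_quantization_config(model, config)

Keeping the kv cache pass in its own helper makes that ordering constraint explicit in one place instead of leaving it inlined in apply_quantization_config.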

0 commit comments
