
Commit b85337f

Commit message: squash

Signed-off-by: Kyle Sayers <[email protected]>

1 parent: a0b83b4

7 files changed: +77 additions, -112 deletions


setup.py

Lines changed: 1 addition & 1 deletion

@@ -160,7 +160,7 @@ def localversion_func(version: ScmVersion) -> str:
         "torchvision",
         "librosa==0.11.0",
         "soundfile",
-        "torchcodec",
+        #"torchcodec",
         # linting, formatting, and type checking
         "mypy~=1.10.0",
         "ruff~=0.4.8",

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 # ruff: noqa

-from .cache import *
 from .gptq import *
 from .quantization import *

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 14 additions & 67 deletions

@@ -1,22 +1,17 @@
-import inspect
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional

 import torch
 from compressed_tensors.quantization import (
     DynamicType,
-    KVCacheScaleType,
     QuantizationArgs,
-    QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
-from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from torch.nn import Module

-from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain

@@ -25,13 +20,13 @@
     "update_weight_zp_scale",
     "calibrate_input_hook",
     "calibrate_output_hook",
-    "calibrate_kv_cache_input_hook",
-    "calibrate_kv_cache_output_hook",
-    "initialize_quantized_kv_cache",
     "freeze_module_quantization",
     "apply_calibration_status",
     "reset_quantization_status",
     "update_weight_global_scale",
+    "calibrate_query_hook",
+    "calibrate_key_hook",
+    "calibrate_value_hook",
 ]


@@ -151,8 +146,9 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     if value.numel() == 0:
         return

-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    quantization_args = getattr(quantization_scheme, f"{base_name}_activations", None)
+    field_name = "input" if base_name != "output" else "output"  # input,q,k,v,output
+    args_attr = f"quantization_scheme.{field_name}_activations"
+    quantization_args = getattr_chain(module, args_attr, None)

     calculate_qparams = True
     calculate_gparam = False
@@ -202,60 +198,16 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
     return output


-def calibrate_kv_cache_input_hook(
-    module: Module, args: Any, kwargs: Dict[str, Any]
-) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
-    """
-    Hook to update inputs to attention layers when running
-    kv_cache quantization. Will update the passed in
-    kv_cache to singleton QuantizedKVParameterCache.
-    """
-    kv_cache = getattr(module, "kv_cache")
-    if not hasattr(module, "_past_kv_name"):
-        # Determine which past KV parameter name to use once and cache it
-        # TODO: Find a better place to cache this
-        module._past_kv_name = (
-            "past_key_value"  # transformers#39956
-            if "past_key_value" in inspect.signature(module.forward).parameters
-            else "past_key_values"
-        )
-
-    kwargs[module._past_kv_name] = kv_cache
-    kwargs["use_cache"] = False
-    return args, kwargs
+def calibrate_query_hook(module: Module, query_states: torch.Tensor):
+    calibrate_activations(module, query_states, base_name="q")


-def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
-    """
-    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
-    """
-    kv_cache = getattr(module, "kv_cache")
-    k_scale = kv_cache.k_scales[module.layer_idx]
-    v_scale = kv_cache.v_scales[module.layer_idx]
-    update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
-    update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)
+def calibrate_key_hook(module: Module, key_states: torch.Tensor):
+    calibrate_activations(module, key_states, base_name="k")


-def initialize_quantized_kv_cache(module: Module):
-    """
-    Initialize a quantized kv_cache on a module (analogous to initializing an observer)
-    When a config specifying kv_cache quantization is applied to a model, the kv_cache
-    args are redefined as the output_activations targeting attention modules.
-
-    This function should be called on attention modules with output_activations
-    """
-    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
-    existing_kv_cache = getattr(module, "kv_cache", None)
-
-    if (
-        scheme is None
-        or not is_kv_cache_quant_scheme(scheme)
-        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
-    ):
-        return
-
-    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
-    setattr(module, "kv_cache", quantized_kv_cache)
+def calibrate_value_hook(module: Module, value_states: torch.Tensor):
+    calibrate_activations(module, value_states, base_name="v")


 def apply_calibration_status(module: Module):
@@ -284,16 +236,11 @@ def freeze_module_quantization(module: Module):
         return

     # remove observers
-    for name in ("input", "weight", "output"):
+    for name in ("input", "weight", "output", "q", "k", "v"):
         obs_name = f"{name}_observer"
         if hasattr(module, obs_name):
             delattr(module, obs_name)

-    # remove quantized kv_cache
-    kv_cache = getattr(module, "kv_cache", None)
-    if isinstance(kv_cache, QuantizedKVParameterCache):
-        delattr(module, "kv_cache")
-
     module.quantization_status = QuantizationStatus.FROZEN
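Note on the new q/k/v hooks: they all funnel into calibrate_activations, which now resolves its quantization args through a dotted attribute path, so any base_name other than "output" (including "q", "k", and "v") maps onto the scheme's input_activations. A minimal, self-contained sketch of that resolution follows; chained_getattr is a hypothetical stand-in for llmcompressor's getattr_chain helper, and the SimpleNamespace objects stand in for the real module and scheme.

# Sketch only: chained_getattr mimics what getattr_chain is assumed to do
# (walk a dotted path, fall back to a default); it is not the real helper.
from functools import reduce
from types import SimpleNamespace


def chained_getattr(obj, attr_path, default=None):
    # Walk "a.b.c" one attribute at a time, returning `default` on any miss.
    try:
        return reduce(getattr, attr_path.split("."), obj)
    except AttributeError:
        return default


# Dummy module: only input activations are configured, as for q/k/v calibration.
module = SimpleNamespace(
    quantization_scheme=SimpleNamespace(
        input_activations="<activation args>",
        output_activations=None,
    )
)

for base_name in ("input", "q", "k", "v", "output"):
    field_name = "input" if base_name != "output" else "output"  # input,q,k,v,output
    args_attr = f"quantization_scheme.{field_name}_activations"
    print(base_name, "->", chained_getattr(module, args_attr, None))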

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 29 additions & 30 deletions

@@ -1,6 +1,13 @@
 from typing import Any, Dict, List, Optional, Set, Union

 import torch
+from compressed_tensors.modeling import (
+    IMPL_ATTR,
+    KV_CACHE_ATTR,
+    register_key_hook,
+    register_query_hook,
+    register_value_hook,
+)
 from compressed_tensors.quantization import (
     DynamicType,
     QuantizationArgs,
@@ -21,12 +28,12 @@
 from llmcompressor.modifiers.quantization.calibration import (
     apply_calibration_status,
     calibrate_input_hook,
-    calibrate_kv_cache_input_hook,
-    calibrate_kv_cache_output_hook,
+    calibrate_key_hook,
     calibrate_output_hook,
+    calibrate_query_hook,
+    calibrate_value_hook,
     freeze_module_quantization,
     initialize_observer,
-    initialize_quantized_kv_cache,
     reset_quantization_status,
 )
 from llmcompressor.modifiers.utils.hooks import HooksMixin
@@ -253,19 +260,21 @@ def _initialize_observers(self, module: torch.nn.Module):

        # input activations
        if input:
-            initialize_observer(module, base_name="input")
+            if not is_attention:
+                initialize_observer(module, base_name="input")
+            else:
+                if hasattr(module, IMPL_ATTR):
+                    initialize_observer(module, base_name="q")
+                if hasattr(module, KV_CACHE_ATTR):
+                    initialize_observer(module, base_name="k")
+                    initialize_observer(module, base_name="v")

        # weight observers (used by `update_weight_zp_scale` or child modifier)
        if weight:
            initialize_observer(module, base_name="weight")

-        # kv_cache activations. Within `apply_quantization_config`, the config is
-        # modified to use attention output quantization if a kv_cache_scheme exists
-        if is_attention and output:
-            initialize_quantized_kv_cache(module)
-
        # output activations
-        elif output:
+        if output:
            initialize_observer(module, base_name="output")

    def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
@@ -284,29 +293,19 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:

            # input activations
            if input:
-                hooks.add(
-                    self.register_hook(module, calibrate_input_hook, "forward_pre")
-                )
-
-            # kv_cache activations. Within `apply_quantization_config`, the config is
-            # modified to use attention output quantization if a kv_cache_scheme exists
-            if is_attention and output:
-                hooks.add(
-                    self.register_hook(
-                        module,
-                        calibrate_kv_cache_input_hook,
-                        "forward_pre",
-                        with_kwargs=True,
+                if not is_attention:
+                    hooks.add(
+                        self.register_hook(module, calibrate_input_hook, "forward_pre")
                    )
-                )
-                hooks.add(
-                    self.register_hook(
-                        module, calibrate_kv_cache_output_hook, "forward"
-                    )
-                )
+                else:
+                    if hasattr(module, IMPL_ATTR):
+                        hooks.add(register_query_hook(module, calibrate_query_hook))
+                    if hasattr(module, KV_CACHE_ATTR):
+                        hooks.add(register_key_hook(module, calibrate_key_hook))
+                        hooks.add(register_value_hook(module, calibrate_value_hook))

            # output activations
-            elif output:
+            if output:
                hooks.add(self.register_hook(module, calibrate_output_hook, "forward"))

        return hooks
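For readers skimming the mixin change: input-side calibration now splits on module type. Non-attention modules keep the forward_pre input hook, while attention modules only get query/key/value hooks when the corresponding compressed-tensors markers (IMPL_ATTR, KV_CACHE_ATTR) are present. The sketch below mirrors that dispatch with plain stand-ins; the marker strings and the plan_calibration_hooks helper are illustrative, not part of either library.

# Illustrative dispatch only; real registration goes through
# HooksMixin.register_hook and the compressed_tensors.modeling register_*_hook helpers.
import torch

IMPL_ATTR = "impl"          # assumed marker attribute names; the actual
KV_CACHE_ATTR = "kv_cache"  # constants live in compressed_tensors.modeling


def plan_calibration_hooks(module, is_attention, quantize_input, quantize_output):
    planned = []
    if quantize_input:
        if not is_attention:
            planned.append(("forward_pre", "calibrate_input_hook"))
        else:
            if hasattr(module, IMPL_ATTR):
                planned.append(("query", "calibrate_query_hook"))
            if hasattr(module, KV_CACHE_ATTR):
                planned.append(("key", "calibrate_key_hook"))
                planned.append(("value", "calibrate_value_hook"))
    if quantize_output:
        planned.append(("forward", "calibrate_output_hook"))
    return planned


attn = torch.nn.Module()
setattr(attn, KV_CACHE_ATTR, object())  # pretend the kv-cache marker is set
print(plan_calibration_hooks(attn, True, quantize_input=True, quantize_output=False))
# [('key', 'calibrate_key_hook'), ('value', 'calibrate_value_hook')]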

src/llmcompressor/modifiers/utils/hooks.py

Lines changed: 14 additions & 1 deletion

@@ -1,6 +1,7 @@
 import contextlib
+from copy import deepcopy
 from functools import wraps
-from typing import Any, Callable, ClassVar, Optional, Set, Union
+from typing import Any, Callable, ClassVar, Dict, Optional, Set, Union

 import torch
 from loguru import logger
@@ -39,6 +40,7 @@ class HooksMixin(BaseModel):
     # attached to global HooksMixin class
     _HOOKS_DISABLED: ClassVar[bool] = False
     _HOOKS_KEEP_ENABLED: ClassVar[Set[RemovableHandle]] = set()
+    _HOOKS_TO_MODIFIER: ClassVar[Dict[RemovableHandle, "HooksMixin"]] = dict()

     # attached to local subclasses
     _hooks: Set[RemovableHandle] = set()
@@ -95,6 +97,7 @@ def wrapped_hook(*args, **kwargs):
         register_function = getattr(target, f"register_{hook_type}_hook")
         handle = register_function(wrapped_hook, **kwargs)
         self._hooks.add(handle)
+        self._HOOKS_TO_MODIFIER[handle] = self
         logger.debug(f"{self} added {handle}")

         return handle
@@ -113,3 +116,13 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
             hook.remove()

         self._hooks -= handles
+        for handle in handles:
+            self._HOOKS_TO_MODIFIER.pop(handle, None)
+
+    @classmethod
+    def remove_hooks_by_id(cls, ids: Set[int]):
+        handles = deepcopy(cls._HOOKS_TO_MODIFIER)
+        for handle in handles:
+            if handle.id in ids:
+                modifier = cls._HOOKS_TO_MODIFIER[handle]
+                modifier.remove_hooks(set(handle))
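The new _HOOKS_TO_MODIFIER registry lets callers remove hooks by torch handle id without keeping a reference to the modifier that registered them. A simplified, standalone sketch of that bookkeeping follows (TinyHooksOwner is hypothetical and much smaller than the real HooksMixin, which is a pydantic BaseModel with more state).

# Minimal sketch of the handle-id registry pattern; not the real class.
import torch


class TinyHooksOwner:
    # class-level registry, analogous to HooksMixin._HOOKS_TO_MODIFIER
    _HANDLE_TO_OWNER = {}

    def __init__(self):
        self._hooks = set()

    def register_hook(self, module, hook):
        handle = module.register_forward_hook(hook)
        self._hooks.add(handle)
        self._HANDLE_TO_OWNER[handle] = self
        return handle

    def remove_hooks(self, handles=None):
        handles = set(self._hooks) if handles is None else set(handles)
        for handle in handles:
            handle.remove()
            self._HANDLE_TO_OWNER.pop(handle, None)
        self._hooks -= handles

    @classmethod
    def remove_hooks_by_id(cls, ids):
        # snapshot the keys so removal during iteration is safe
        for handle in list(cls._HANDLE_TO_OWNER):
            if handle.id in ids:
                cls._HANDLE_TO_OWNER[handle].remove_hooks({handle})


owner = TinyHooksOwner()
linear = torch.nn.Linear(4, 4)
handle = owner.register_hook(linear, lambda mod, args, out: out)
TinyHooksOwner.remove_hooks_by_id({handle.id})
assert not owner._hooks and handle not in TinyHooksOwner._HANDLE_TO_OWNER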

src/llmcompressor/observers/helpers.py

Lines changed: 7 additions & 2 deletions

@@ -52,6 +52,8 @@ def flatten_for_calibration(
 def _flatten_weight(
     value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None
 ):
+    # value.shape = (num_rows, num_cols)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (1, 1, num_weight_elems)
         return value.reshape((1, 1, -1))
@@ -87,6 +89,8 @@ def _flatten_weight(


 def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, seq_len, hidden_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (batch_size * seq_len, 1, hidden_dim)
         return value.reshape((-1, 1, value.size(-1)))
@@ -111,10 +115,11 @@ def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):


 def _flatten_attention(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, num_heads, seq_len, head_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
-        # (batch_size, seq_len, num_heads, head_dim)
         # (batch_size * seq_len, 1, num_heads * head_dim)
-        return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
+        return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)

     if args.strategy == QuantizationStrategy.TOKEN:
         raise ValueError("Token quantization cannot be applied to attention")
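The _flatten_attention fix swaps the head and sequence dimensions before flattening, so each flattened row holds every head of a single token rather than mixing tokens from different positions. A quick shape walk-through, assuming the (batch_size, num_heads, seq_len, head_dim) layout documented in the new comment:

# Shape check for the TENSOR-strategy attention flatten (illustrative sizes).
import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16
value = torch.randn(batch_size, num_heads, seq_len, head_dim)

x = value.transpose(1, 2)  # (batch_size, seq_len, num_heads, head_dim)
x = x.flatten(0, 1)        # (batch_size * seq_len, num_heads, head_dim)
x = x.flatten(-2, -1)      # (batch_size * seq_len, num_heads * head_dim)
x = x.unsqueeze(-2)        # (batch_size * seq_len, 1, num_heads * head_dim)

assert x.shape == (batch_size * seq_len, 1, num_heads * head_dim)  # (16, 1, 64)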

tests/llmcompressor/transformers/kv_cache/test_kv_cache.py

Lines changed: 12 additions & 10 deletions

@@ -3,7 +3,7 @@

 import pytest
 from accelerate import init_empty_weights
-from compressed_tensors.quantization import KVCacheScaleType, is_attention_module
+from compressed_tensors.quantization import is_attention_module
 from datasets import load_dataset
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils.quantization_config import CompressedTensorsConfig
@@ -14,7 +14,7 @@
 NUM_CALIBRATION_SAMPLES = 16
 MAX_SEQUENCE_LENGTH = 512
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
-DATASET_SPLIT = "train_sft"
+DATASET_SPLIT = f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"

 MODEL_IDS = [
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -49,9 +49,11 @@ def _oneshot_fixture(tmp_path: Path):
             symmetric=symmetric,
         )
     oneshot_args = dict(
-        dataset="open_platypus",
         recipe=recipe,
-        num_calibration_samples=16,
+        dataset="open_platypus",
+        splits={"calibration": f"train[:{NUM_CALIBRATION_SAMPLES}]"},
+        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+        max_seq_length=MAX_SEQUENCE_LENGTH,
     )
     for model_id in MODEL_IDS:
         oneshot_args["output_dir"] = os.path.join(tmp_path, model_id)
@@ -161,8 +163,8 @@ def test_kv_cache_model_state_dict_attr(oneshot_fixture, tmp_path):
     for name, submodule in model.named_modules():
         if is_attention_module(submodule):
             counts += 1
-            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-            assert hasattr(submodule, KVCacheScaleType.KEY.value)
+            assert hasattr(submodule, "v_scale")
+            assert hasattr(submodule, "k_scale")
     assert counts > 0


@@ -200,8 +202,8 @@ def test_kv_cache_gptq_config_format(kv_cache_fixture, tmp_path):
     for name, submodule in model.named_modules():
         if is_attention_module(submodule):
             counts += 1
-            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-            assert hasattr(submodule, KVCacheScaleType.KEY.value)
+            assert hasattr(submodule, "v_scale")
+            assert hasattr(submodule, "k_scale")

     assert counts > 0

@@ -240,7 +242,7 @@ def test_kv_cache_gptq_model_state_dict_attr(kv_cache_fixture, tmp_path):
     for name, submodule in model.named_modules():
         if is_attention_module(submodule):
             counts += 1
-            assert hasattr(submodule, KVCacheScaleType.VALUE.value)
-            assert hasattr(submodule, KVCacheScaleType.KEY.value)
+            assert hasattr(submodule, "v_scale")
+            assert hasattr(submodule, "k_scale")

     assert counts > 0
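The split strings now use the Hugging Face datasets slice syntax, so the fixtures only pull the calibration samples instead of a full split. For reference, a usage example of that syntax with the dataset id taken from this test (network access required):

# Split slicing as used by the updated DATASET_SPLIT / splits arguments.
from datasets import load_dataset

NUM_CALIBRATION_SAMPLES = 16
ds = load_dataset(
    "HuggingFaceH4/ultrachat_200k",
    split=f"train_sft[:{NUM_CALIBRATION_SAMPLES}]",
)
print(len(ds))  # 16 rows, rather than the full train_sft split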
