
Commit 05ec17e

WIP
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 493b92d

5 files changed: +92 -87 lines changed

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: off
+from .kvcache import *
+from .attention import *
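
The two wildcard imports above re-export the kv cache and attention helpers at the package level. A minimal sketch of what becomes importable, assuming this new file is the compressed_tensors.modeling package __init__ (the filename is not shown in this view, though the import added to initialize.py below is consistent with that assumption):

# Sketch only; assumes the new file above is compressed_tensors/modeling/__init__.py
from compressed_tensors.modeling import (
    IMPL_ATTR,              # attribute name the attention impl is registered under
    KV_CACHE_ATTR,          # attribute name the kv cache is registered under
    QuantizedAttentionImpl,
    QuantizedKVCache,
    initialize_hooked_attention,
    initialize_hooked_kv_cache,
)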

src/compressed_tensors/modeling/attention.py

Lines changed: 2 additions & 36 deletions
@@ -23,9 +23,6 @@
     QuantizationStrategy,
     forward_quantize,
 )
-from compressed_tensors.quantization.lifecycle.initialize import (
-    _initialize_scale_zero_point,
-)
 from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
@@ -39,6 +36,7 @@
     "QuantizedAttentionImpl",
     "initialize_hooked_attention",
     "register_query_hook",
+    "IMPL_ATTR",
 ]
 
 
@@ -94,33 +92,6 @@ def forward(
             **kwargs,
         )
 
-    def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
-        """
-        Initialize attention quantization parameters if they have not already been
-        initialized. KV cache quantization parameters are initialized by the
-        `QuantizedKVCache`
-
-        :param model: parent model of attention module
-        :param module: attention module to initialize with
-        """
-        # TODO: move to initialize.py
-        assert module is self.attn_module()
-        scheme: Optional[QuantizationScheme] = getattr(
-            module, "quantization_scheme", None
-        )
-        quant_args: Optional[QuantizationArgs] = getattr(
-            scheme, "input_activations", None
-        )
-
-        if (
-            not self._qparams_initialized
-            and quant_args is not None
-            and not scheme.kv_cache_only
-        ):
-            assert quant_args.strategy == QuantizationStrategy.TENSOR
-            _initialize_scale_zero_point(module, "q", quant_args)
-            self._qparams_initialized = True
-
 
 # ----- initialize ----- #
 

@@ -141,7 +112,6 @@ def initialize_hooked_attention(
 
     :param model: parent model of attention module
    :param module: attention module to initialize with
-    :param quantize: initialize attention quantization parameters
     """
     if not hasattr(module, IMPL_ATTR):
         module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config, module))
@@ -153,11 +123,7 @@ def initialize_hooked_attention(
     AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
     model.config._attn_implementation = HOOKED_ATTENTION_NAME
 
-    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
-    if quantize:
-        impl.initialize_qparams_once(model, module)
-
-    initialize_hooked_kv_cache(model, module, quantize=quantize)
+    initialize_hooked_kv_cache(model, module)
 
 
 # ----- hooks ----- #
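
With the quantize flag removed, attaching the hooked attention implementation is now decoupled from creating quantization parameters (which moves to initialize.py below). A minimal usage sketch under that assumption; the loop and the model identifier are illustrative and not part of this commit:

from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.attention import initialize_hooked_attention

model = AutoModelForCausalLM.from_pretrained("<model-id>")  # placeholder identifier

for submodule in model.modules():
    # simplified version of the is_attention_module() check in initialize.py
    if "attention" in submodule.__class__.__name__.lower():
        # attaches QuantizedAttentionImpl (IMPL_ATTR) and a QuantizedKVCache (KV_CACHE_ATTR);
        # q/k/v scales are created later by initialize_module_for_quantization
        initialize_hooked_attention(model, submodule)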

src/compressed_tensors/modeling/kvcache.py

Lines changed: 6 additions & 31 deletions
@@ -16,10 +16,10 @@
 from typing import Callable, Optional, Tuple
 from weakref import ref
 
-from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
-from compressed_tensors.quantization.lifecycle.initialize import (
-    _initialize_scale_zero_point,
-)
+# from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
+# from compressed_tensors.quantization.lifecycle.initialize import (
+#     _initialize_scale_zero_point,
+# )
 from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
@@ -33,6 +33,7 @@
     "initialize_hooked_kv_cache",
     "register_key_hook",
     "register_value_hook",
+    "KV_CACHE_ATTR",
 ]
 
 

@@ -88,25 +89,6 @@ def forward(
         self.past_key_values = None
         return ret
 
-    def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
-        """
-        Initialize kv cache quantization parameters if they have not already been
-        initialized
-
-        :param model: parent model of attention module
-        :param module: attention module to initialize with
-        """
-        # TODO: move to initialize.py
-        assert module is self.attn_module()
-        scheme = getattr(module, "quantization_scheme", None)
-        quant_args = getattr(scheme, "input_activations", None)
-
-        if not self._qparams_initialized and quant_args is not None:
-            assert quant_args.strategy == QuantizationStrategy.TENSOR
-            _initialize_scale_zero_point(module, "k", quant_args)
-            _initialize_scale_zero_point(module, "v", quant_args)
-            self._qparams_initialized = True
-
 
 # ----- initialize ----- #
 

@@ -124,24 +106,17 @@ def _kv_cache_attention_hook(module: Module, args, kwargs):
     return args, kwargs
 
 
-def initialize_hooked_kv_cache(
-    model: PreTrainedModel, module: Module, quantize: bool = False
-):
+def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
     """
     Initialize a `QuantizedKVCache` instance attached to attention
 
     :param model: parent model of attention module
     :param module: attention module to initialize with
-    :param quantize: initialize kv cache quantization parameters
     """
     if not hasattr(module, KV_CACHE_ATTR):
         module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
         module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
 
-    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
-    if quantize:
-        kv_cache.initialize_qparams_once(model, module)
-
 
 # ----- hooks ----- #
 
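
The kv cache side mirrors the attention change: the cache is attached unconditionally and no longer creates its own scales. A short sketch of the resulting state, continuing the example above and assuming model and attn_module are a Hugging Face model and one of its attention submodules:

from compressed_tensors.modeling.kvcache import (
    KV_CACHE_ATTR,
    QuantizedKVCache,
    initialize_hooked_kv_cache,
)

# idempotent: guarded internally by hasattr(module, KV_CACHE_ATTR)
initialize_hooked_kv_cache(model, attn_module)

# the cache is now a registered child module of the attention layer
kv_cache = getattr(attn_module, KV_CACHE_ATTR)
assert isinstance(kv_cache, QuantizedKVCache)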

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 53 additions & 19 deletions
@@ -17,6 +17,12 @@
 from typing import Optional, Tuple, Union
 
 import torch
+from compressed_tensors.modeling import (
+    IMPL_ATTR,
+    KV_CACHE_ATTR,
+    QuantizedAttentionImpl,
+    QuantizedKVCache,
+)
 from compressed_tensors.quantization import (
     FP8_E4M3_DATA,
     ActivationOrdering,
@@ -39,15 +45,18 @@
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
+    get_head_dim,
     register_offload_parameter,
 )
 from torch.nn import Module, Parameter
+from transformers import PretrainedConfig
 
 
 __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
     "initialize_qparams",
+    "initialize_attn_qparams",
 ]
 
 

@@ -81,7 +90,7 @@ def initialize_module_for_quantization(
 
     if is_attention_module(module):
         # quantized actions based on calltime status
-        _initialize_attn_scales(module)
+        initialize_attn_qparams(module, scheme, force_zero_point)
 
     else:
         if not isinstance(module, torch.nn.Linear):
@@ -131,14 +140,14 @@ def initialize_module_for_quantization(
             force_zero_point=force_zero_point,
         )
 
-    module.quantization_scheme = scheme
-    module.quantization_status = QuantizationStatus.INITIALIZED
-
     with disable_hf_hook(module):
         # wrap forward call of module to perform
         # quantized actions based on calltime status
         wrap_module_forward_quantized(module, scheme)
 
+    module.quantization_scheme = scheme
+    module.quantization_status = QuantizationStatus.INITIALIZED
+
 
 def is_attention_module(module: Module):
     return "attention" in module.__class__.__name__.lower() and (
@@ -276,23 +285,48 @@ def initialize_qparams(
     register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
 
-def _initialize_attn_scales(module: Module) -> None:
-    """Initlaize k_scale, v_scale for self_attn"""
+def initialize_attn_qparams(
+    module: Module, scheme: QuantizationScheme, force_zero_point: bool
+):
+    """Initialize q_scale, k_scale, v_scale for self_attn"""
 
-    expected_shape = 1  # per tensor
+    impl: Optional[QuantizedAttentionImpl] = getattr(module, IMPL_ATTR, None)
+    kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None)
 
-    param = next(module.parameters())
-    scale_dtype = param.dtype
-    device = param.device
+    if impl is None and kv_cache is None:
+        raise ValueError("Attention module has quantization scheme but no attached impl or kv cache")
 
-    init_scale = Parameter(
-        torch.empty(expected_shape, dtype=scale_dtype, device=device),
-        requires_grad=False,
+    config: PretrainedConfig = getattr(impl, "config", None) or getattr(
+        kv_cache, "config", None
     )
-    register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale)
+    head_dim = get_head_dim(config)
+    observed_shape = (head_dim,)  # (batch_size, num_attention_heads, slen, head_dim)
+    observed_dtype = next(module.parameters()).dtype
+
+    if impl is not None:
+        initialize_qparams(
+            module,
+            "q",
+            scheme.input_activations,
+            observed_shape=observed_shape,
+            observed_dtype=observed_dtype,
+            force_zero_point=force_zero_point,
+        )
 
-    init_scale = Parameter(
-        torch.empty(expected_shape, dtype=scale_dtype, device=device),
-        requires_grad=False,
-    )
-    register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
+    if kv_cache is not None:
+        initialize_qparams(
+            module,
+            "k",
+            scheme.input_activations,
+            observed_shape=observed_shape,
+            observed_dtype=observed_dtype,
+            force_zero_point=force_zero_point,
+        )
+        initialize_qparams(
+            module,
+            "v",
+            scheme.input_activations,
+            observed_shape=observed_shape,
+            observed_dtype=observed_dtype,
+            force_zero_point=force_zero_point,
+        )
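
The net effect is that attention scales now flow through the generic initialize_qparams path with observed_shape=(head_dim,) rather than being hard-coded scalar k/v parameters. A rough ordering sketch, continuing the example above and assuming attn_module already carries a quantization_scheme with input_activations set; the q_scale/k_scale/v_scale names are inferred from the f"{base_name}_zero_point" naming visible in this file and are not shown verbatim in this diff:

from compressed_tensors.modeling import initialize_hooked_attention
from compressed_tensors.quantization.lifecycle.initialize import initialize_attn_qparams

# 1) attach the hooked impl and kv cache first; otherwise initialize_attn_qparams
#    finds neither IMPL_ATTR nor KV_CACHE_ATTR on the module and raises ValueError
initialize_hooked_attention(model, attn_module)

# 2) create attention qparams from the module's scheme
initialize_attn_qparams(
    attn_module, attn_module.quantization_scheme, force_zero_point=True
)

# expected (inferred) result: q_scale, k_scale and v_scale parameters registered
# on attn_module, shaped according to the scheme's input_activations strategy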

src/compressed_tensors/utils/helpers.py

Lines changed: 13 additions & 1 deletion
@@ -20,7 +20,7 @@
 
 import numpy
 import torch
-from transformers import AutoConfig
+from transformers import AutoConfig, PretrainedConfig
 
 
 T = TypeVar("T", bound="Callable")  # used by `deprecated`
@@ -45,6 +45,7 @@
     "unpack_bitmasks",
     "patch_attr",
     "ParameterizedDefaultDict",
+    "get_head_dim",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -396,3 +397,14 @@ def get(self, *args, factory_kwargs: Mapping = MappingProxyType({})) -> Any:
         """
         with patch_attr(self, "_factory_kwargs", factory_kwargs):
             return self[args]
+
+
+def get_head_dim(config: PretrainedConfig) -> int:
+    if hasattr(config, "head_dim"):
+        return config.head_dim
+
+    elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+        return config.hidden_size // config.num_attention_heads
+
+    else:
+        raise ValueError("Cannot infer head_dim from config")
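
A quick worked example of get_head_dim covering both branches (the config values are illustrative, not taken from this commit):

from transformers import PretrainedConfig

from compressed_tensors.utils import get_head_dim

# Llama-7B-like shape: no explicit head_dim attribute, so 4096 hidden units
# across 32 heads fall back to 4096 // 32 = 128
cfg = PretrainedConfig(hidden_size=4096, num_attention_heads=32)
assert get_head_dim(cfg) == 128

# configs that define head_dim directly return it as-is
cfg = PretrainedConfig(head_dim=64, hidden_size=4096, num_attention_heads=32)
assert get_head_dim(cfg) == 64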
