
Commit e88e7d4

[Transform] Attention/Cache transforms (#436)
* attention quant Signed-off-by: Kyle Sayers <[email protected]>
* reduce diff Signed-off-by: Kyle Sayers <[email protected]>
* address nits Signed-off-by: Kyle Sayers <[email protected]>
* fix kv cache serialization, add tests Signed-off-by: Kyle Sayers <[email protected]>
* fix style Signed-off-by: Kyle Sayers <[email protected]>
* do not force zp for attention Signed-off-by: Kyle Sayers <[email protected]>
* populate ALL_MASK_ATTENTION_FUNCTIONS Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 32016cb commit e88e7d4

File tree

12 files changed (+919, -210 lines)

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
# isort: off
from .kvcache import *
from .attention import *
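
Because both submodules define `__all__` and are re-exported here via star imports (`# isort: off` keeps `kvcache` ahead of `attention`, which depends on it), downstream code can pull the public API straight from the package. A minimal sketch, assuming the package path is `compressed_tensors.modeling` as implied by the imports in the attention module below:

from compressed_tensors.modeling import (
    initialize_hooked_attention,
    initialize_hooked_kv_cache,
    register_query_hook,
    register_key_hook,
    register_value_hook,
)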
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional

from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import PretrainedConfig, PreTrainedModel
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


__all__ = [
    "QuantizedAttentionImpl",
    "initialize_hooked_attention",
    "register_query_hook",
    "IMPL_ATTR",
]


IMPL_ATTR = "impl"
HOOKED_ATTENTION_NAME = "ct_hooked_attention"


class QuantizedAttentionImpl(InternalModule):
    """
    QuantizedAttentionImpl module which wraps the functionality of the original
    attention implementation. Unlike the original attention function, this
    implementation is a `torch.nn.Module` which can be hooked to trigger
    transforms and calibration hooks.

    This module works by being registered as a submodule to attention modules via
    `initialize_hooked_attention`, registering a new attention implementation
    function which calls this module, then setting the model attention
    implementation to the new function. After triggering hooks and quantization,
    this module calls the original attention implementation function.
    """

    _original_impl = "eager"

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config

    def forward(
        self,
        module: Module,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *args,
        **kwargs,
    ):
        # quantization
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled:
            query = forward_quantize(module, query, "q", quant_args)

        # original attention
        return ALL_ATTENTION_FUNCTIONS[QuantizedAttentionImpl._original_impl](
            module,
            query,
            key,
            value,
            *args,
            **kwargs,
        )


# ----- initialize ----- #


def _hooked_attention(module: Module, *args, **kwargs):
    assert hasattr(module, IMPL_ATTR), (
        f"Using {HOOKED_ATTENTION_NAME} attention implementation, "
        f"but attention module does not have {IMPL_ATTR} submodule."
    )

    return getattr(module, IMPL_ATTR)(module, *args, **kwargs)


def initialize_hooked_attention(model: PreTrainedModel, module: Module):
    """
    Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
    attached to attention. Assumes that only one model is hooked at a time.

    :param model: parent model of attention module
    :param module: attention module to initialize with
    """
    if not hasattr(module, IMPL_ATTR):
        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config))

    if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
        QuantizedAttentionImpl._original_impl = model.config._attn_implementation
        original_mask = ALL_MASK_ATTENTION_FUNCTIONS[model.config._attn_implementation]

        ALL_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, _hooked_attention)
        ALL_MASK_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, original_mask)
        model.set_attn_implementation(HOOKED_ATTENTION_NAME)
        assert model.config._attn_implementation == HOOKED_ATTENTION_NAME

    initialize_hooked_kv_cache(model, module)


# ----- hooks ----- #


def register_query_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope query states as an argument and
    returns the modified query states or `None`

    :param module: attention module to add hook to
    :param hook: query hook function
    """
    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)

    def _hook(impl: QuantizedAttentionImpl, args, kwargs):
        bound = inspect.signature(impl.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["query"])
        if value is not None:
            bound.arguments["query"] = value

        return bound.args, bound.kwargs

    return impl.register_forward_pre_hook(_hook, with_kwargs=True)
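
For context, a usage sketch of the API above. The checkpoint name, module-matching rule, and observer logic are illustrative assumptions and not part of this commit; only `initialize_hooked_attention` and `register_query_hook` come from the code above.

from transformers import AutoModelForCausalLM
from compressed_tensors.modeling.attention import (
    initialize_hooked_attention,
    register_query_hook,
)

# assumed checkpoint, used only for illustration
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

def record_query_minmax(module, query):
    # illustrative observer: track post-rope query ranges for calibration;
    # returning None leaves the query states unchanged
    module._q_min = float(query.min())
    module._q_max = float(query.max())
    return None

for name, submodule in model.named_modules():
    if name.endswith("self_attn"):  # assumes Llama-style attention module names
        initialize_hooked_attention(model, submodule)
        register_query_hook(submodule, record_query_minmax)

Because `initialize_hooked_attention` swaps the model-wide attention implementation on first use and also calls `initialize_hooked_kv_cache`, subsequent forward passes route every hooked attention module through `QuantizedAttentionImpl`.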
Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple
from weakref import ReferenceType, ref

from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PretrainedConfig, PreTrainedModel


__all__ = [
    "QuantizedKVCache",
    "initialize_hooked_kv_cache",
    "register_key_hook",
    "register_value_hook",
    "KV_CACHE_ATTR",
]


KV_CACHE_ATTR = "kv_cache"


class QuantizedKVCache(InternalModule):
    """
    QuantizedKVCache module which wraps the functionality of any existing kv cache.
    Unlike `transformers` Cache instances, this cache is a `torch.nn.Module` which
    can be hooked to trigger transforms and calibration hooks.

    This module works by being registered as a submodule to attention modules via
    `initialize_hooked_kv_cache`, then adding a hook which replaces the
    `past_key_values` kwarg with this module. This module adopts the functionality
    of the replaced cache, preserving caching behavior such as sliding window
    attention, etc.

    :param config: config of the parent model
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        self.attn_module = ref(attn_module)  # avoid circular reference
        self.past_key_values: Optional[ReferenceType[Cache]] = None

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        # quantization
        module = self.attn_module()
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled:
            key_states = forward_quantize(module, key_states, "k", quant_args)
            value_states = forward_quantize(module, value_states, "v", quant_args)

        # original cache
        if self.past_key_values is not None:
            ret = self.past_key_values().update(
                key_states, value_states, *args, **kwargs
            )
        else:
            ret = (key_states, value_states)
        self.past_key_values = None

        return ret

    def add_past_key_values(self, past_key_values: Optional[Cache]):
        if past_key_values is not None:
            self.past_key_values = ref(past_key_values)
        else:
            self.past_key_values = None


# ----- initialize ----- #


def _kv_cache_attention_hook(
    module: Module, args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """
    Hook which should be called before each quantized attention forward pass.
    This hook dynamically replaces the `past_key_values` kwarg to the attention
    forward function.

    The original kv cache object is assigned to `QuantizedKVCache().past_key_values`
    as a weakref to maintain original cache functionality and compute savings.
    """
    _past_kv_name = (
        "past_key_values"  # transformers#39956
        if "past_key_values" in inspect.signature(module.forward).parameters
        else "past_key_value"
    )
    past_key_values: Optional[Cache] = kwargs.get(_past_kv_name, None)

    cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    cache.add_past_key_values(past_key_values)
    kwargs[_past_kv_name] = cache

    return args, kwargs


def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
    """
    Initialize a `QuantizedKVCache` instance attached to attention

    :param model: parent model of attention module
    :param module: attention module to initialize with
    """
    if not hasattr(module, KV_CACHE_ATTR):
        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
        module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)


# ----- hooks ----- #


def register_key_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope key states as an argument and
    returns the modified key states or `None`

    :param module: attention module to add hook to
    :param hook: key hook function
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["key_states"])
        if value is not None:
            bound.arguments["key_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes value states as an argument and
    returns the modified value states or `None`

    :param module: attention module to add hook to
    :param hook: value hook function
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["value_states"])
        if value is not None:
            bound.arguments["value_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
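
Key and value hooks attach to the `QuantizedKVCache` submodule, so they run before quantization and before the original cache's `update`. A minimal sketch, reusing the `model` from the previous example; the layer index, clamp range, and observer attribute are illustrative assumptions rather than part of this commit:

import torch
from compressed_tensors.modeling.kvcache import (
    initialize_hooked_kv_cache,
    register_key_hook,
    register_value_hook,
)

attn = model.model.layers[0].self_attn  # assumes a Llama-style module layout

# no-op if the attention module was already hooked via initialize_hooked_attention
initialize_hooked_kv_cache(model, attn)

def clamp_keys(module, key_states):
    # illustrative transform: clip outliers before the keys reach the quantizer/cache
    return torch.clamp(key_states, min=-8.0, max=8.0)

def record_value_absmax(module, value_states):
    # illustrative observer: returning None keeps the value states unchanged
    module._v_absmax = float(value_states.abs().max())
    return None

key_handle = register_key_hook(attn, clamp_keys)
value_handle = register_value_hook(attn, record_value_absmax)

# both registrations return RemovableHandle objects and can be detached after calibration
key_handle.remove()
value_handle.remove()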
