Skip to content

Commit 0674268

Browse files
committed
address nits
Signed-off-by: Kyle Sayers <[email protected]>
1 parent dc43b64 commit 0674268

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

src/compressed_tensors/modeling/kvcache.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414

1515
import inspect
16-
from typing import Callable, Optional, Tuple
17-
from weakref import ref
16+
from typing import Any, Callable, Dict, List, Optional, Tuple
17+
from weakref import ReferenceType, ref
1818

1919
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
2020
from compressed_tensors.utils import getattr_chain
@@ -55,7 +55,7 @@ def __init__(self, config: PretrainedConfig, attn_module: Module):
5555
super().__init__()
5656
self.config = config
5757
self.attn_module = ref(attn_module) # avoid circular reference
58-
self.past_key_values: Optional[Cache] = None
58+
self.past_key_values: Optional[ReferenceType[Cache]] = None
5959

6060
def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
6161
return self(*args, **kwargs)
@@ -78,26 +78,46 @@ def forward(
7878

7979
# original cache
8080
if self.past_key_values is not None:
81-
ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)
81+
ret = self.past_key_values().update(
82+
key_states, value_states, *args, **kwargs
83+
)
8284
else:
8385
ret = (key_states, value_states)
84-
8586
self.past_key_values = None
87+
8688
return ret
8789

90+
def add_past_key_values(self, past_key_values: Optional[Cache]):
91+
if past_key_values is not None:
92+
self.past_key_values = ref(past_key_values)
93+
else:
94+
self.past_key_values = None
95+
8896

8997
# ----- initialize ----- #
9098

9199

92-
def _kv_cache_attention_hook(module: Module, args, kwargs):
93-
kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
100+
def _kv_cache_attention_hook(
101+
module: Module, args: List[Any], kwargs: Dict[str, Any]
102+
) -> Tuple[List[Any], Dict[str, Any]]:
103+
"""
104+
Hook which should be called before each quantized attention forward pass.
105+
This hook dynamically replaces the `past_key_values` kwarg to the attention
106+
forward function.
107+
108+
The original kvcache object is assigned to QuantizedKVCache().past_key_values
109+
as a weakref to maintain original cache functionality and compute savings
110+
"""
94111
_past_kv_name = (
95112
"past_key_values" # transformers#39956
96113
if "past_key_values" in inspect.signature(module.forward).parameters
97114
else "past_key_value"
98115
)
99-
kv_cache.past_key_values = kwargs.get(_past_kv_name, None)
100-
kwargs[_past_kv_name] = kv_cache
116+
past_key_values: Optional[Cache] = kwargs.get(_past_kv_name, None)
117+
118+
cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
119+
cache.add_past_key_values(past_key_values)
120+
kwargs[_past_kv_name] = cache
101121

102122
return args, kwargs
103123

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,11 @@ def initialize_attn_qparams(
289289
kv_cache: Optional[QuantizedKVCache] = getattr(module, KV_CACHE_ATTR, None)
290290

291291
if impl is None and kv_cache is None:
292-
raise ValueError("Attention module has quantization scheme but no attached")
292+
raise ValueError(
293+
f"Attention module has quantization scheme but no {IMPL_ATTR} "
294+
f"or {KV_CACHE_ATTR} attributes. Please ensure that these "
295+
"attributes are initialized using `apply_quantization_config`."
296+
)
293297

294298
_validate_attention_scheme(scheme)
295299

@@ -337,7 +341,7 @@ def _validate_attention_scheme(scheme: QuantizationScheme):
337341
if scheme.weights is not None:
338342
raise ValueError(
339343
"Cannot apply weight quantization to attention. "
340-
"Instead, target (q|k|v)_proj"
344+
"Instead, target the (q|k|v)_proj submodule layers of attention"
341345
)
342346

343347
if scheme.input_activations is None:

0 commit comments

Comments (0)