 import inspect
 from typing import Callable, Optional, Tuple

-import torch
-import transformers
 from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
 from compressed_tensors.quantization.lifecycle.initialize import (
     _initialize_scale_zero_point,
 )
 from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.internal import InternalModule
-from packaging import version
 from torch import Tensor
+from torch.nn import Module
 from torch.utils.hooks import RemovableHandle
 from transformers import Cache, PreTrainedModel


-__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"]
+__all__ = [
+    "QuantizedKVCache",
+    "initialize_hooked_kv_cache",
+    "register_key_hook",
+    "register_value_hook",
+]


 KV_CACHE_ATTR = "kv_cache"


 class QuantizedKVCache(InternalModule):
-    def __init__(self, attn_module: torch.nn.Module):
+    """
+    QuantizedKVCache module which wraps the functionality of any existing kv cache.
+    Unlike transformers `Cache` instances, this cache is a `torch.nn.Module` which can
+    be hooked to trigger transforms and calibration hooks.
+
+    This module works by being registered as a submodule of attention modules via
+    `initialize_hooked_kv_cache`, then adding a hook which replaces the
+    `past_key_values` kwarg with this module. This module adopts the functionality of
+    the replaced cache, preserving behavior such as sliding window attention, etc.
+
+    :param attn_module: parent attention module
+    """
+
+    def __init__(self, attn_module: Module):
         super().__init__()
-        self.attn_module_container = [attn_module]  # avoid nn.Module circular reference
+        self.attn_module_container = [attn_module]  # avoid circular reference
         self.past_key_values: Optional[Cache] = None
         self._qparams_initialized = False

@@ -70,13 +86,19 @@ def forward(
         self.past_key_values = None
         return ret

-    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
+    def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
+        """
+        Initialize kv cache quantization parameters if they have not already been
+        initialized
+
+        :param model: parent model of attention module
+        :param module: attention module to initialize with
+        """
         assert module is self.attn_module_container[0]
         scheme = getattr(module, "quantization_scheme", None)
         quant_args = getattr(scheme, "input_activations", None)

         if not self._qparams_initialized and quant_args is not None:
-            # TODO: use model.config.num_key_value_heads to find key_size, value_size
             assert quant_args.strategy == QuantizationStrategy.TENSOR
             _initialize_scale_zero_point(module, "k", quant_args)
             _initialize_scale_zero_point(module, "v", quant_args)
@@ -86,19 +108,7 @@ def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Modul
 # ----- initialize ----- #


-def initialize_hooked_kv_cache(
-    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False
-):
-    if not hasattr(module, KV_CACHE_ATTR):
-        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
-        module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True)
-
-    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
-    if quantize:
-        kv_cache.initialize_qparams_once(model, module)
-
-
-def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
+def _kv_cache_attention_hook(module: Module, args, kwargs):
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
     _past_kv_name = (
         "past_key_values"  # transformers#39956
@@ -111,10 +121,38 @@ def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
     return args, kwargs


+def initialize_hooked_kv_cache(
+    model: PreTrainedModel, module: Module, quantize: bool = False
+):
+    """
+    Initialize a `QuantizedKVCache` instance attached to the given attention module
+
+    :param model: parent model of attention module
+    :param module: attention module to initialize with
+    :param quantize: whether to initialize kv cache quantization parameters
+    """
+    if not hasattr(module, KV_CACHE_ATTR):
+        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
+        module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
+
+    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+    if quantize:
+        kv_cache.initialize_qparams_once(model, module)
+
+
 # ----- hooks ----- #


-def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
+def register_key_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes post-rope key states as an argument and
+    returns the modified key states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: key hook function
+    """
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

     def _hook(cache: QuantizedKVCache, args, kwargs):
@@ -128,7 +166,16 @@ def _hook(cache: QuantizedKVCache, args, kwargs):
     return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


-def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
+def register_value_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes value states as an argument and
+    returns the modified value states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: value hook function
+    """
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

     def _hook(cache: QuantizedKVCache, args, kwargs):
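
Usage sketch (not part of the diff above): a minimal example of how the new `initialize_hooked_kv_cache` / `register_key_hook` API might be wired up during calibration. The checkpoint path, the `"self_attn"` module-name suffix, and the `observe_keys` hook are illustrative assumptions, not code from this change.

from typing import Optional

from torch import Tensor
from torch.nn import Module
from transformers import AutoModelForCausalLM

# initialize_hooked_kv_cache and register_key_hook come from the module in this diff

model = AutoModelForCausalLM.from_pretrained("path/to/model")  # placeholder checkpoint


def observe_keys(attn_module: Module, key_states: Tensor) -> Optional[Tensor]:
    # illustrative key hook: log the max absolute post-rope key value per module
    print(type(attn_module).__name__, key_states.abs().max().item())
    return None  # returning None leaves the key states unchanged


handles = []
for name, submodule in model.named_modules():
    if name.endswith("self_attn"):  # assumption: attention modules end with "self_attn"
        # attach the QuantizedKVCache and its forward pre-hook to this attention module
        initialize_hooked_kv_cache(model, submodule, quantize=False)
        handles.append(register_key_hook(submodule, observe_keys))

# ... run calibration forward passes here ...

for handle in handles:
    handle.remove()  # detach the key hooks once calibration is done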