1313# limitations under the License.
1414
1515import inspect
16- from typing import Callable , Optional , Tuple
17- from weakref import ref
16+ from typing import Any , Callable , Dict , List , Optional , Tuple
17+ from weakref import ReferenceType , ref
1818
1919from compressed_tensors .quantization .lifecycle .forward import forward_quantize
2020from compressed_tensors .utils import getattr_chain
@@ -55,7 +55,7 @@ def __init__(self, config: PretrainedConfig, attn_module: Module):
5555 super ().__init__ ()
5656 self .config = config
5757 self .attn_module = ref (attn_module ) # avoid circular reference
58- self .past_key_values : Optional [Cache ] = None
58+ self .past_key_values : Optional [ReferenceType [ Cache ] ] = None
5959
def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
    """
    Forward every argument to a call on this object and return its result.

    Going through ``self(...)`` rather than calling ``forward`` directly keeps
    the regular call path of the object intact.
    """
    states = self(*args, **kwargs)
    return states
@@ -78,26 +78,46 @@ def forward(
7878
7979 # original cache
8080 if self .past_key_values is not None :
81- ret = self .past_key_values .update (key_states , value_states , * args , ** kwargs )
81+ ret = self .past_key_values ().update (
82+ key_states , value_states , * args , ** kwargs
83+ )
8284 else :
8385 ret = (key_states , value_states )
84-
8586 self .past_key_values = None
87+
8688 return ret
8789
def add_past_key_values(self, past_key_values: Optional[Cache]):
    """
    Remember the given cache on this object, or clear the stored one.

    Only a weak reference to `past_key_values` is stored, so this object does
    not keep the cache alive. Passing `None` resets the attribute to `None`.

    :param past_key_values: cache to remember, or `None` to clear
    """
    self.past_key_values = (
        None if past_key_values is None else ref(past_key_values)
    )
95+
8896
8997# ----- initialize ----- #
9098
9199
def _kv_cache_attention_hook(
    module: Module, args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """
    Hook meant to run before each quantized attention forward pass.

    Replaces the cache kwarg of the attention forward call with the
    QuantizedKVCache stored on the module under `KV_CACHE_ATTR`. The cache
    originally supplied by the caller (or `None` when absent) is passed to
    that QuantizedKVCache through `add_past_key_values`.

    :param module: attention module whose forward call is being intercepted
    :param args: positional arguments of the forward call, returned unchanged
    :param kwargs: keyword arguments of the forward call, modified in place
    :return: `(args, kwargs)` with the cache kwarg swapped
    """
    # transformers#39956: the kwarg is called `past_key_values` when the
    # forward signature has that name, otherwise `past_key_value`
    forward_params = inspect.signature(module.forward).parameters
    if "past_key_values" in forward_params:
        cache_kwarg = "past_key_values"
    else:
        cache_kwarg = "past_key_value"

    quantized_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    quantized_cache.add_past_key_values(kwargs.get(cache_kwarg))
    kwargs[cache_kwarg] = quantized_cache

    return args, kwargs
103123
0 commit comments