11import  contextlib 
2- from  functools  import  wraps 
2+ from  functools  import  partial ,  wraps 
33from  typing  import  Any , Callable , ClassVar , Optional , Set , Union 
44
55import  torch 
6+ from  compressed_tensors .modeling  import  (
7+     register_key_hook ,
8+     register_query_hook ,
9+     register_value_hook ,
10+ )
611from  loguru  import  logger 
712from  pydantic  import  BaseModel 
813from  torch .utils .hooks  import  RemovableHandle 
@@ -92,7 +97,7 @@ def wrapped_hook(*args, **kwargs):
9297
9398            return  hook (* args , ** kwargs )
9499
95-         register_function  =  getattr (target , f"register_ { hook_type } _hook" 
100+         register_function  =  self . _get_register_function (target , hook_type )
96101        handle  =  register_function (wrapped_hook , ** kwargs )
97102        self ._hooks .add (handle )
98103        logger .debug (f"{ self } { handle }  )
@@ -113,3 +118,15 @@ def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
113118            hook .remove ()
114119
115120        self ._hooks  -=  handles 
121+ 
122+     def  _get_register_function (
123+         self , target : torch .nn .Module , hook_type : str 
124+     ) ->  Callable :
125+         if  hook_type  ==  "query" :
126+             return  partial (register_query_hook , target )
127+         elif  hook_type  ==  "key" :
128+             return  partial (register_key_hook , target )
129+         elif  hook_type  ==  "value" :
130+             return  partial (register_value_hook , target )
131+         else :
132+             return  getattr (target , f"register_{ hook_type }  )
0 commit comments