55from compressed_tensors .quantization import (
66 DynamicType ,
77 KVCacheScaleType ,
8+ QuantizationArgs ,
89 QuantizationScheme ,
910 QuantizationStatus ,
1011 QuantizationStrategy ,
1920from llmcompressor .observers import Observer
2021from llmcompressor .utils .helpers import getattr_chain
2122
22- DEFAULT_MAXSHRINK = 0.20
23- DEFAULT_PATIENCE = 5
24- DEFAULT_AVERAGING_CONSTANT = 0.01
25- DEFAULT_GRID = 100.0
26- DEFAULT_NORM = 2.4
27-
2823__all__ = [
2924 "initialize_observer" ,
3025 "update_weight_zp_scale" ,
@@ -54,31 +49,19 @@ def initialize_observer(
5449 :param base_name: str used to name the observer attribute
5550
5651 """
57-
58- arg_name = "weights" if base_name == "weight" else f"{ base_name } _activations"
59- quantization_scheme = getattr (module , "quantization_scheme" , None )
60- if not quantization_scheme :
61- # no quantization scheme nothing to do
62- return
63-
64- quantization_args = getattr (quantization_scheme , arg_name , None )
65- # dont need observers for dynamic
66- if quantization_args is not None and quantization_args .dynamic in (
67- False ,
68- DynamicType .LOCAL ,
69- ):
70- observer_kwargs = quantization_args .observer_kwargs or {}
52+ if base_name == "weight" :
53+ arg_name = "weights"
54+ elif base_name == "output" :
55+ arg_name = "output_activations"
56+ else : # input, q, k, v
57+ arg_name = "input_activations"
58+
59+ args : QuantizationArgs = getattr_chain (
60+ module , f"quantization_scheme.{ arg_name } " , None
61+ )
62+ if args is not None and args .dynamic is not True :
7163 observer = Observer .load_from_registry (
72- quantization_args .observer ,
73- quantization_args = quantization_args ,
74- averaging_constant = observer_kwargs .get (
75- "averaging_constant" , DEFAULT_AVERAGING_CONSTANT
76- ),
77- # used by mse observer only, will be ignored by minmax observer
78- maxshrink = observer_kwargs .get ("maxshrink" , DEFAULT_MAXSHRINK ),
79- patience = observer_kwargs .get ("patience" , DEFAULT_PATIENCE ),
80- grid = observer_kwargs .get ("grid" , DEFAULT_GRID ),
81- norm = observer_kwargs .get ("norm" , DEFAULT_NORM ),
64+ args .observer , base_name = base_name , args = args , module = module
8265 )
8366 module .register_module (f"{ base_name } _observer" , observer )
8467
@@ -100,36 +83,17 @@ def call_observer(
10083 base_name is "weight", then the module's weight tensor will be used
10184 """
10285 with align_module_device (module ):
103- if base_name == "weight" :
104- value = module .weight
105- g_idx = getattr (module , "weight_g_idx" , None )
106- elif value is not None :
107- g_idx = None
108- else :
109- raise ValueError (
110- "Must provide a value to observe if not using weight observer"
111- )
112-
113- observer = getattr (module , f"{ base_name } _observer" )
86+ value = module .weight if base_name == "weight" else value
87+ observer : Observer = getattr (module , f"{ base_name } _observer" )
11488
11589 if should_calculate_gparam :
116- global_scale = observer (
117- value ,
118- should_calculate_gparam = True ,
119- )
90+ global_scale = observer .get_global_scale (value )
12091 update_offload_parameter (module , f"{ base_name } _global_scale" , global_scale )
121- else :
122- global_scale = getattr (module , f"{ base_name } _global_scale" , None )
12392
12493 if should_calculate_qparams :
125- updated_scale , updated_zero_point = observer (
126- value , g_idx = g_idx , global_scale = global_scale
127- )
128- # register or update scale & zero_point parameters (supports block shapes)
129- scale_name = f"{ base_name } _scale"
130- zp_name = f"{ base_name } _zero_point"
131- update_offload_parameter (module , scale_name , updated_scale )
132- update_offload_parameter (module , zp_name , updated_zero_point )
94+ scale , zero_point = observer (value )
95+ update_offload_parameter (module , f"{ base_name } _scale" , scale )
96+ update_offload_parameter (module , f"{ base_name } _zero_point" , zero_point )
13397
13498
13599def update_weight_global_scale (module : Module ):
@@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module):
148112 should_calculate_gparam = True ,
149113 should_calculate_qparams = False ,
150114 )
151- module .weight_observer .reset ()
152115
153116
154117def update_weight_zp_scale (module : Module ):
0 commit comments