
Commit 3939769

kylesayrs authored and dhuangnm committed
[Observers] Refactor for better FP4 support, static and memoryless observers (vllm-project#1903)
## Purpose ##

* FP4
  * Fix bug discovered [here](vllm-project#1830 (comment)) where `dynamic="local"` nvfp4 calculations would increment the observer twice as fast as normal
  * Enable the MSE observer to be used with FP4

    ```pseudocode
    mse_quant_error := mean((x - fake_quant(x))**2)
    global_scale <- min[min_vals, max_vals, global_scale](mse_quant_error(x))
    scale, zp <- min[min_vals, max_vals](mse_quant_error(x, global_scale))
    ```

* Simplification
  * Make supporting attention calibration easier by separating out weight/activation/attention reshaping
  * Improve the readability of the observer code by removing many levels of function indirection
  * Drop support for calibration with non-divisible group sizes. This is not really a loss, since [forward passes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/lifecycle/forward.py#L279) also make this assumption
* New observers
  * `memoryless_minmax` computes min and max values on the fly in a dynamic-quantization style. This observer is useful for PTQ weight quantization
  * `static_minmax` computes absolute min and max values across all observations. This observer is useful for PTQ activation quantization
  * `memoryless_mse` computes the best qparams w.r.t. MSE loss for each observation. This observer is useful for PTQ weight quantization
* Memory improvements
  * Observers no longer store copies of scales and zero points, reducing the amount of required memory
  * The newly introduced "memoryless" observers do not store any quantization parameters, which greatly reduces the memory requirements for PTQ weight quantization of very large models

| Diagrams |
| - |
| Before |
| <img width="886" height="595" alt="before" src="https://github.com/user-attachments/assets/660d94c2-3ac8-4e05-9e9b-53d21145abac" /> |
| After |
| <img width="1527" height="595" alt="after" src="https://github.com/user-attachments/assets/51a0107e-3fbd-413c-a7a6-03ddc3612169" /> |

## Changes ##

* Standardize reshaping using `flatten_for_calibration` (sketched below)
  * This function reshapes all observed values to `(num_observations, *qparams_shape, group_size)`
  * This function removes the complexity associated with passing "reduce dims" and trying to handle weights, activations, and attention states all in the same function
  * In the future, this function could be applied to the quantization forward pass, although there's probably no need to beyond standardization
* Implement `get_global_scale` on the `Observer` base class
  * This function decouples minmax calculations from regular qparam calculations (avoiding the double-increment bug)
  * This function enables the MSE observer to be used with FP4 global scales
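To make the `flatten_for_calibration` contract concrete, here is a minimal sketch of what the `(num_observations, *qparams_shape, group_size)` reshape means for a group-quantized 2-D weight. `flatten_groups` is a hypothetical standalone helper written for illustration, not the library function:

```python
import torch

def flatten_groups(value: torch.Tensor, group_size: int) -> torch.Tensor:
    # Hypothetical illustration of the reshaping contract: a weight of shape
    # (out_features, in_features) becomes (num_observations, *qparams_shape,
    # group_size), where qparams_shape = (out_features, in_features // group_size)
    out_features, in_features = value.shape
    # calibration with non-divisible group sizes is no longer supported
    assert in_features % group_size == 0
    return value.reshape(1, out_features, in_features // group_size, group_size)

weight = torch.randn(64, 512)
flat = flatten_groups(weight, group_size=128)
print(flat.shape)  # torch.Size([1, 64, 4, 128])

# reducing over the observation and group dims yields qparams-shaped tensors
min_vals = flat.amin(dim=(0, -1))  # shape (64, 4), one value per group
max_vals = flat.amax(dim=(0, -1))
```

Because every observed value arrives in this one layout, an observer can compute its statistics with a single reduction, regardless of whether the input was a weight, an activation, or an attention state.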
## Testing ##

* Added additional minmax tests which check the exact values of scales. These pass on both main and this branch, demonstrating that minmax observer behavior remains unchanged
* Added additional MSE tests which check the exact values of MSE losses. These pass on both main and this branch, demonstrating that MSE observer behavior remains unchanged
* Added an FP4 MSE test

## Evaluation ##

```
nvfp4-static-minmax
| Tasks  |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val|      0|none  |     0|mmmu_acc|↑  |0.6167|±  |   N/A|
```

```
nvfp4-minmax
| Tasks  |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val|      0|none  |     0|mmmu_acc|↑  |0.6011|±  |   N/A|
```

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Dan Huang <[email protected]>
Co-authored-by: dhuangnm <[email protected]>
Signed-off-by: ronantakizawa <[email protected]>
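As a reading aid for the pseudocode in the Purpose section, here is a minimal, self-contained sketch of an MSE-style qparam search. The range-shrinking grid search and the 4-bit fake-quant helper are illustrative stand-ins, not the observer's actual implementation; the defaults `grid=100.0` and `maxshrink=0.20` mirror the constants removed from calibration.py below.

```python
import torch

def fake_quant(x, scale, zero_point, qmin=0, qmax=15):
    # illustrative asymmetric 4-bit fake quantization (stand-in helper)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

def mse_qparams(x, grid=100.0, maxshrink=0.20):
    # Grid-search shrunken [min, max] ranges and keep the pair minimizing
    # mse_quant_error = mean((x - fake_quant(x))**2), per the pseudocode.
    x_min, x_max = x.min(), x.max()
    best_err, best = torch.tensor(float("inf")), (None, None)
    for i in range(int(grid * maxshrink) + 1):
        shrink = 1.0 - i / grid
        lo, hi = x_min * shrink, x_max * shrink
        scale = (hi - lo).clamp(min=1e-8) / 15  # 15 quantization steps for 4 bits
        zero_point = torch.round(-lo / scale).clamp(0, 15)
        err = (x - fake_quant(x, scale, zero_point)).pow(2).mean()
        if err < best_err:
            best_err, best = err, (scale, zero_point)
    return best

# For FP4, the same search would first select a global_scale candidate, then
# re-run over (min, max) with the winning global_scale held fixed.
scale, zero_point = mse_qparams(torch.randn(64, 512))
```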
1 parent f51ef7f commit 3939769

File tree

17 files changed: +1219 −820 lines changed

docs/observers.md

Lines changed: 5 additions & 1 deletion

```diff
@@ -65,7 +65,11 @@ from llmcompressor.observers import Observer
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 
 args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)
-observer = Observer.load_from_registry("minmax", quantization_args=args)
+observer = Observer.load_from_registry(
+    "minmax",
+    base_name="weight",
+    quantization_args=args,
+)
 
 x = torch.randn(64, 512)
 scale, zero_point = observer(x)
```

src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 6 additions & 4 deletions

```diff
@@ -86,13 +86,15 @@ def update(
        """
 
        if len(self.k_observers) <= layer_idx:
-           k_observer_name = self.quantization_args.observer
            k_observer = Observer.load_from_registry(
-               k_observer_name, quantization_args=self.quantization_args
+               self.quantization_args.observer,
+               base_name="k",
+               args=self.quantization_args,
            )
-           v_observer_name = self.quantization_args.observer
            v_observer = Observer.load_from_registry(
-               v_observer_name, quantization_args=self.quantization_args
+               self.quantization_args.observer,
+               base_name="v",
+               args=self.quantization_args,
            )
 
        # NOTE: User may ignore some layers in configuration,
```

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 19 additions & 56 deletions

```diff
@@ -5,6 +5,7 @@
 from compressed_tensors.quantization import (
     DynamicType,
     KVCacheScaleType,
+    QuantizationArgs,
     QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,
@@ -19,12 +20,6 @@
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
-DEFAULT_MAXSHRINK = 0.20
-DEFAULT_PATIENCE = 5
-DEFAULT_AVERAGING_CONSTANT = 0.01
-DEFAULT_GRID = 100.0
-DEFAULT_NORM = 2.4
-
 __all__ = [
     "initialize_observer",
     "update_weight_zp_scale",
@@ -54,31 +49,19 @@ def initialize_observer(
     :param base_name: str used to name the observer attribute
 
     """
-
-    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
-    quantization_scheme = getattr(module, "quantization_scheme", None)
-    if not quantization_scheme:
-        # no quantization scheme nothing to do
-        return
-
-    quantization_args = getattr(quantization_scheme, arg_name, None)
-    # dont need observers for dynamic
-    if quantization_args is not None and quantization_args.dynamic in (
-        False,
-        DynamicType.LOCAL,
-    ):
-        observer_kwargs = quantization_args.observer_kwargs or {}
+    if base_name == "weight":
+        arg_name = "weights"
+    elif base_name == "output":
+        arg_name = "output_activations"
+    else:  # input, q, k, v
+        arg_name = "input_activations"
+
+    args: QuantizationArgs = getattr_chain(
+        module, f"quantization_scheme.{arg_name}", None
+    )
+    if args is not None and args.dynamic is not True:
         observer = Observer.load_from_registry(
-            quantization_args.observer,
-            quantization_args=quantization_args,
-            averaging_constant=observer_kwargs.get(
-                "averaging_constant", DEFAULT_AVERAGING_CONSTANT
-            ),
-            # used by mse observer only, will be ignored by minmax observer
-            maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK),
-            patience=observer_kwargs.get("patience", DEFAULT_PATIENCE),
-            grid=observer_kwargs.get("grid", DEFAULT_GRID),
-            norm=observer_kwargs.get("norm", DEFAULT_NORM),
+            args.observer, base_name=base_name, args=args, module=module
        )
        module.register_module(f"{base_name}_observer", observer)
 
@@ -100,36 +83,17 @@ def call_observer(
        base_name is "weight", then the module's weight tensor will be used
    """
    with align_module_device(module):
-       if base_name == "weight":
-           value = module.weight
-           g_idx = getattr(module, "weight_g_idx", None)
-       elif value is not None:
-           g_idx = None
-       else:
-           raise ValueError(
-               "Must provide a value to observe if not using weight observer"
-           )
-
-       observer = getattr(module, f"{base_name}_observer")
+       value = module.weight if base_name == "weight" else value
+       observer: Observer = getattr(module, f"{base_name}_observer")
 
        if should_calculate_gparam:
-           global_scale = observer(
-               value,
-               should_calculate_gparam=True,
-           )
+           global_scale = observer.get_global_scale(value)
            update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
-       else:
-           global_scale = getattr(module, f"{base_name}_global_scale", None)
 
        if should_calculate_qparams:
-           updated_scale, updated_zero_point = observer(
-               value, g_idx=g_idx, global_scale=global_scale
-           )
-           # register or update scale & zero_point parameters (supports block shapes)
-           scale_name = f"{base_name}_scale"
-           zp_name = f"{base_name}_zero_point"
-           update_offload_parameter(module, scale_name, updated_scale)
-           update_offload_parameter(module, zp_name, updated_zero_point)
+           scale, zero_point = observer(value)
+           update_offload_parameter(module, f"{base_name}_scale", scale)
+           update_offload_parameter(module, f"{base_name}_zero_point", zero_point)
 
 
 def update_weight_global_scale(module: Module):
@@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module):
        should_calculate_gparam=True,
        should_calculate_qparams=False,
    )
-   module.weight_observer.reset()
 
 
 def update_weight_zp_scale(module: Module):
```
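One detail worth unpacking: the new `initialize_observer` leans on `getattr_chain` to tolerate modules without a quantization scheme. Roughly, the call walks a dotted attribute path and falls back to the default when any link is missing. A simplified sketch of that behavior, not the actual helper in `llmcompressor.utils.helpers`:

```python
def getattr_chain_sketch(obj, chain, default=None):
    # Simplified approximation: walk "a.b.c" one attribute at a time,
    # returning `default` if any intermediate attribute is missing or None
    for name in chain.split("."):
        obj = getattr(obj, name, None)
        if obj is None:
            return default
    return obj

# e.g. getattr_chain_sketch(module, "quantization_scheme.weights", None)
# yields None both when the module has no scheme and when its weights are
# unquantized, so a single `args is not None` check covers both cases
```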

src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py

Lines changed: 10 additions & 7 deletions

```diff
@@ -10,6 +10,7 @@
     QuantizationStrategy,
     fake_quantize,
 )
+from compressed_tensors.utils import update_offload_parameter
 from loguru import logger
 
 from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD
@@ -95,9 +96,10 @@ def quantize_weight(
 
    # create observer for calculating quantization parameters
    observer = Observer.load_from_registry(
-       quant_args.observer,
-       quantization_args=quant_args,
-       averaging_constant=1.0,  # ignore moving average
+       "memoryless_minmax",
+       base_name="weight",
+       args=quant_args,
+       module=module,
    )
 
    # standardize shape and dtype
@@ -119,22 +121,23 @@ def quantize_weight(
        if actorder == ActivationOrdering.GROUP:
            # permute by activation order first, then update groups
            W, H, perm = _apply_activation_ordering(W, H)
-           scale, zero_point = observer(W, g_idx=None)
+           update_offload_parameter(module, "weight_g_idx", g_idx)
+           scale, zero_point = observer(W)
 
            # use identity g_idx (invert permutation later)
 
        elif actorder == ActivationOrdering.WEIGHT:
            # update groups first, then permute by activation order
-           scale, zero_point = observer(W, g_idx=None)
+           scale, zero_point = observer(W)
            W, H, perm = _apply_activation_ordering(W, H)
 
            # permute g_idx to maintain identity mapping after unpermutation
            g_idx = g_idx[perm]
 
        else:
-           scale, zero_point = observer(W, g_idx=None)
+           scale, zero_point = observer(W)
    else:
-       scale, zero_point = observer(W, g_idx=None)
+       scale, zero_point = observer(W)
 
    # sparsity mask
    sparsity = tensor_sparsity(W)
```
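The switch from `averaging_constant=1.0` to `memoryless_minmax` is worth a gloss. A moving-average minmax observer blends each new observation into running statistics, which suits activations but is unnecessary for GPTQ, where each weight is observed exactly once. A hedged sketch of the distinction, using illustrative classes rather than the library's observers:

```python
import torch

class MovingMinMax:
    # Sketch: blends each observation into running statistics. Setting
    # averaging_constant=1.0 made this effectively memoryless, which is
    # the workaround this commit retires.
    def __init__(self, averaging_constant: float = 0.01):
        self.c = averaging_constant
        self.min_vals = None
        self.max_vals = None

    def __call__(self, x: torch.Tensor):
        x_min, x_max = x.amin(), x.amax()
        if self.min_vals is None:
            self.min_vals, self.max_vals = x_min, x_max
        else:
            self.min_vals = (1 - self.c) * self.min_vals + self.c * x_min
            self.max_vals = (1 - self.c) * self.max_vals + self.c * x_max
        return self.min_vals, self.max_vals

class MemorylessMinMax:
    # Sketch: each call depends only on the current tensor and nothing is
    # stored, which is why the new observers cut PTQ weight-quantization memory
    def __call__(self, x: torch.Tensor):
        return x.amin(), x.amax()
```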

src/llmcompressor/observers/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -11,5 +11,6 @@
 
 from .helpers import *
 from .base import *
+from .moving_base import *
 from .min_max import *
 from .mse import *
```
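Putting the new registry entries to use: assuming the three observers register under the names listed in the Purpose section, loading one follows the same pattern as the updated docs example. The keyword names below follow the docs diff above and are a sketch, not a verified API:

```python
from compressed_tensors.quantization.quant_args import QuantizationArgs
from llmcompressor.observers import Observer

args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)

# names per the "New observers" list above; assumed to be the registry keys
for name in ("memoryless_minmax", "static_minmax", "memoryless_mse"):
    observer = Observer.load_from_registry(
        name,
        base_name="weight",
        quantization_args=args,
    )
```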
