Skip to content

Commit a0b83b4

Browse files
committed
add fp4 test
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 32879da commit a0b83b4

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

src/llmcompressor/observers/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
1111
from compressed_tensors.registry.registry import RegistryMixin
12+
from compressed_tensors.utils import patch_attr
1213

1314
from llmcompressor.observers.helpers import flatten_for_calibration
1415

@@ -82,9 +83,11 @@ def get_global_scale(self, observed: torch.Tensor) -> torch.nn.Parameter:
8283
:param observed: value being observed
8384
:return: calibrated global parameter
8485
"""
85-
observed = observed.reshape((1, 1, -1)) # per tensor reshape
86-
min_vals, max_vals = self.get_min_max(observed)
87-
return generate_gparam(min_vals, max_vals)
86+
# avoid updating running min/max for global scales
87+
with patch_attr(self, "min_vals", None), patch_attr(self, "max_vals", None):
88+
observed = observed.reshape((1, 1, -1)) # per tensor reshape
89+
min_vals, max_vals = self.get_min_max(observed)
90+
return generate_gparam(min_vals, max_vals)
8891

8992
def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]:
9093
if self.module is None:

tests/llmcompressor/observers/test_mse.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def test_mse_observer_symmetric_scale_range():
7373

7474

7575
def test_mse_fp4():
76-
tensor = torch.arange(24, dtype=torch.bfloat16).reshape((4, 6)) / 24
76+
module = torch.nn.Linear(6, 4)
77+
module.weight.data = torch.arange(24, dtype=torch.bfloat16).reshape((4, 6)) / 24
7778

7879
weights = QuantizationArgs(
7980
num_bits=4,
@@ -84,8 +85,15 @@ def test_mse_fp4():
8485
)
8586

8687
observer = weights.observer
87-
observer = Observer.load_from_registry(observer, base_name="weight", args=weights)
88-
scale, zero_point = observer(tensor)
88+
observer = Observer.load_from_registry(
89+
observer, base_name="weight", args=weights, module=module
90+
)
8991

90-
qdq_tensor = fake_quantize(tensor, scale, zero_point, weights)
91-
assert torch.nn.functional.mse_loss(qdq_tensor, tensor) <= 0.002
92+
global_scale = observer.get_global_scale(module.weight)
93+
module.weight_global_scale = global_scale
94+
scale, zero_point = observer(module.weight)
95+
96+
qdq_tensor = fake_quantize(
97+
module.weight, scale, zero_point, weights, global_scale=global_scale
98+
)
99+
assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) <= 0.002

0 commit comments

Comments (0)