[Observers] Refactor for better FP4 support, static and memoryless observers (#1903)
* FP4
* Fix a bug, discovered
[here](#1830 (comment)),
where `dynamic="local"` nvfp4 calculations would increment the observer
twice as fast as normal
* Enable MSE observer to be used with FP4
```pseudocode
mse_quant_error := mean((x - fake_quant(x))**2)
global_scale <- min[min_vals, max_vals, global_scale](mse_quant_error(x))
scale, zp <- min[min_vals, max_vals](mse_quant_error(x, global_scale))
```
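For concreteness, a minimal Python sketch of the grid-search mechanics above. The names `fake_quant`, `mse_quant_error`, and `mse_search_scale` are hypothetical, and the helper uses simplified integer-style rounding rather than a true FP4 cast; in the real flow the same kind of search runs twice, once for the global scale and then again for local scales given that global scale:

```python
import torch

FP4_MAX = 6.0  # largest representable magnitude in FP4 (E2M1)

def fake_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Simplified symmetric quantize-dequantize stand-in (integer-style
    # rounding, not a true FP4 cast)
    q = torch.clamp(torch.round(x / scale), -FP4_MAX, FP4_MAX)
    return q * scale

def mse_quant_error(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # mse_quant_error := mean((x - fake_quant(x))**2)
    return torch.mean((x - fake_quant(x, scale)) ** 2)

def mse_search_scale(x: torch.Tensor, num_steps: int = 20) -> torch.Tensor:
    # Shrink the observed absolute max in steps and keep the candidate
    # scale with the lowest quantization error
    abs_max = x.abs().amax()
    best_scale, best_error = abs_max / FP4_MAX, torch.tensor(float("inf"))
    for step in range(num_steps, 0, -1):
        scale = (abs_max * step / num_steps) / FP4_MAX
        error = mse_quant_error(x, scale)
        if error < best_error:
            best_scale, best_error = scale, error
    return best_scale

# e.g. search one per-group slice; the real flow runs a similar search
# first for the global scale, then per group given that global scale
group_scale = mse_search_scale(torch.randn(16))
```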
* Simplification
* Make supporting attention calibration easier by separating out
weight/activation/attention reshaping
* Improve readability of observer code by removing many levels of
function indirection
* Drop support for calibration with non-divisible group sizes. This is
not really a loss, since [forward
passes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/lifecycle/forward.py#L279)
also make this assumption
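To illustrate that assumption (this is not code from the library): group-wise handling reshapes the last dimension into groups of `group_size`, which only works when the column count divides evenly:

```python
import torch

# Group-wise quantization splits the last dimension into groups of
# group_size; this only works when the column count divides evenly
weight = torch.randn(128, 512)
weight.reshape(128, -1, 16)  # OK: 512 % 16 == 0 -> shape (128, 32, 16)

try:
    torch.randn(128, 500).reshape(128, -1, 16)  # 500 % 16 != 0
except RuntimeError as exc:
    print(f"non-divisible group size fails: {exc}")
```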
* New observers
* `memoryless_minmax` computes min and max values on the fly in a
dynamic-quantization style. This observer is useful for PTQ weight
quantization
* `static_minmax` computes absolute min and max values across all
observations. This observer is useful for PTQ activation quantization
* `memoryless_mse` computes best qparams w.r.t. MSE loss for each
observation. This observer is useful for PTQ weight quantization
* Memory improvements
* All observers no longer store copies of scales and zero points,
reducing the amount of required memory
* Newly introduced "memoryless" observers do not store any quantization
parameters, which greatly reduces the memory requirements for PTQ weight
quantization of very large models
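A rough sketch of the difference between the two flavors, using hypothetical class names that mirror the observer names above; the real implementations live in the observer registry and differ in detail:

```python
import torch

class MemorylessMinMax:
    # Computes min/max from each observation alone and keeps nothing
    # around afterwards -- suited to one-shot PTQ weight quantization
    def get_min_max(self, observed: torch.Tensor):
        return observed.amin(dim=-1), observed.amax(dim=-1)

class StaticMinMax:
    # Accumulates absolute min/max across all observations -- suited to
    # PTQ activation quantization over a calibration set
    def __init__(self):
        self.min_vals = None
        self.max_vals = None

    def get_min_max(self, observed: torch.Tensor):
        batch_min = observed.amin(dim=-1)
        batch_max = observed.amax(dim=-1)
        if self.min_vals is None:
            self.min_vals, self.max_vals = batch_min, batch_max
        else:
            self.min_vals = torch.minimum(self.min_vals, batch_min)
            self.max_vals = torch.maximum(self.max_vals, batch_max)
        return self.min_vals, self.max_vals
```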
| Diagrams |
| - |
| Before |
| <img width="886" height="595" alt="before" src="https://github.com/user-attachments/assets/660d94c2-3ac8-4e05-9e9b-53d21145abac" /> |
| After |
| <img width="1527" height="595" alt="after" src="https://github.com/user-attachments/assets/51a0107e-3fbd-413c-a7a6-03ddc3612169" /> |
* Standardize reshaping using `flatten_for_calibration`
* This function reshapes all observed values to `(num_observations,
*qparams_shape, group_size)` (see the shape example below)
* This function removes the complexity associated with passing "reduce
dims" and trying to handle weights, activations, and attention states all
in the same function
* In the future, this function could also be applied to the quantization
forward pass, although there is probably no need to beyond
standardization
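For intuition, a hedged example of the target layout for the weight case with group quantization; `flatten_weight_for_calibration` is an illustrative stand-in, not the library function:

```python
import torch

def flatten_weight_for_calibration(weight: torch.Tensor, group_size: int) -> torch.Tensor:
    # Weights are a single observation, so num_observations == 1.
    # qparams_shape for group quantization is (num_rows, num_groups),
    # giving a layout of (1, num_rows, num_groups, group_size)
    num_rows, num_columns = weight.shape
    return weight.reshape(1, num_rows, num_columns // group_size, group_size)

# e.g. a (128, 512) weight with group_size=16 -> (1, 128, 32, 16);
# observers can then reduce over the last dimension to get one
# scale/zero point per (row, group)
flat = flatten_weight_for_calibration(torch.randn(128, 512), group_size=16)
assert flat.shape == (1, 128, 32, 16)
```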
* Implement `get_global_scale` on `Observer` base
* This function decouples minmax calculations from regular qparam
calculations (avoiding the double increment bug)
* This function enables the MSE observer to be used with FP4 global
scales
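A hypothetical sketch of that decoupling (class name, methods, and formulas are illustrative, not the library's API): the global-scale path is a pure function of the observed values, so only the regular qparam path advances the observation counter:

```python
import torch

class ObserverSketch:
    # Illustrative base showing why separating global-scale math from
    # regular qparam calculation avoids double-counting observations
    def __init__(self):
        self.num_observations = 0

    def get_min_max(self, observed: torch.Tensor):
        return observed.amin(), observed.amax()

    def get_global_scale(self, observed: torch.Tensor) -> torch.Tensor:
        # Pure function of the observed values: no counter update, so
        # requesting a global scale no longer advances calibration
        min_val, max_val = self.get_min_max(observed)
        return torch.maximum(min_val.abs(), max_val.abs())  # placeholder statistic

    def get_qparams(self, observed: torch.Tensor):
        # Only the regular qparam path counts as an observation
        self.num_observations += 1
        min_val, max_val = self.get_min_max(observed)
        scale = torch.clamp((max_val - min_val) / 255.0, min=1e-12)
        zero_point = torch.round(-min_val / scale)
        return scale, zero_point
```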
* Added additional minmax tests which check exact scale values. These
tests pass on both main and this branch, demonstrating that minmax
observer behavior remains unchanged
* Added additional MSE tests which check exact MSE loss values. These
tests pass on both main and this branch, demonstrating that MSE observer
behavior remains unchanged
* Added FP4 MSE test
```
nvfp4-static-minmax
| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val| 0|none | 0|mmmu_acc|↑ |0.6167|± | N/A|
```
```
nvfp4-minmax
| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val| 0|none | 0|mmmu_acc|↑ |0.6011|± | N/A|
```
---------
Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Dan Huang <[email protected]>
Co-authored-by: dhuangnm <[email protected]>