Skip to content

Commit 32016cb

Browse files
authored
[MXFP4] Add MXFP4 Compressor (#502)
* add mxfp4_formaat
* fix name
1 parent 413addd commit 32016cb

File tree

4 files changed

+12
-2
lines changed

4 files changed

+12
-2
lines changed

src/compressed_tensors/compressors/quantized_compressors/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@
1414
# flake8: noqa
1515

1616
from .base import *
17+
from .fp4_quantized import *
1718
from .naive_quantized import *
18-
from .nvfp4_quantized import *
1919
from .pack_quantized import *

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py renamed to src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ def decompress_weight(
123123
return decompressed_weight
124124

125125

126+
@BaseCompressor.register(name=CompressionFormat.mxfp4_pack_quantized.value)
127+
class MXFP4PackedCompressor(NVFP4PackedCompressor):
128+
"""
129+
Alias for mxfp4 quantized models
130+
"""
131+
132+
pass
133+
134+
126135
@torch.compile(fullgraph=True, dynamic=True)
127136
def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
128137
"""

src/compressed_tensors/config/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class CompressionFormat(Enum):
3434
marlin_24 = "marlin-24"
3535
mixed_precision = "mixed-precision"
3636
nvfp4_pack_quantized = "nvfp4-pack-quantized"
37+
mxfp4_pack_quantized = "mxfp4-pack-quantized"
3738

3839

3940
@unique

tests/test_compressors/quantized_compressors/test_nvfp4_quant.py renamed to tests/test_compressors/quantized_compressors/test_fp4_quant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pytest
1616
import torch
17-
from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import (
17+
from compressed_tensors.compressors.quantized_compressors.fp4_quantized import (
1818
pack_fp4_to_uint8,
1919
unpack_fp4_from_uint8,
2020
)

0 commit comments

Comments (0)