Commit 41a0778

create a separate test for mx and nv serialization (#3251)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]

1 parent 03c2d28 · commit 41a0778

File tree: 3 files changed, +75 −21 lines
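
For context before the diffs: the behavior this commit pins down is the checkpoint round-trip sketched below. This is a minimal sketch assembled only from the APIs that appear in the new test (the file path is a placeholder; it assumes a CUDA device and uses the same MXFP8 config the test constructs):

```python
# Sketch of the save/load flow the new test exercises; "checkpoint.pt" is a
# placeholder path, and the config mirrors the mxfp8 branch of the test.
import torch
import torch.nn as nn

import torchao.prototype.mx_formats  # the top-level import the test verifies is sufficient
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import quantize_

m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(
    m,
    config=MXFPInferenceConfig(
        activation_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
        gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
    ),
)
torch.save(m.state_dict(), "checkpoint.pt")

# In a fresh process, the top-level import above should be all that is
# needed for a weights-only load to succeed:
state_dict = torch.load("checkpoint.pt", weights_only=True)
```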

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 0 additions & 21 deletions
```diff
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-import tempfile
 from contextlib import contextmanager
 
 import pytest
@@ -136,16 +135,6 @@ def test_inference_workflow_mx(
         f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
     )
 
-    # serialization
-    with tempfile.NamedTemporaryFile() as f:
-        torch.save(m_mx.state_dict(), f)
-        f.seek(0)
-
-        # temporary workaround for https://github.com/pytorch/ao/issues/3077
-        torch.serialization.add_safe_globals([getattr])
-
-        _ = torch.load(f, weights_only=True)
-
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
@@ -254,16 +243,6 @@ def test_inference_workflow_nvfp4(
         f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
     )
 
-    # serialization
-    with tempfile.NamedTemporaryFile() as f:
-        torch.save(m_mx.state_dict(), f)
-        f.seek(0)
-
-        # temporary workaround for https://github.com/pytorch/ao/issues/3077
-        torch.serialization.add_safe_globals([getattr])
-
-        _ = torch.load(f, weights_only=True)
-
 
 class VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
```

test/prototype/mx_formats/test_serialization.py

Lines changed: 73 additions & 0 deletions
```diff
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import subprocess
+import tempfile
+
+import pytest
+import torch
+import torch.nn as nn
+
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
+from torchao.quantization import quantize_
+from torchao.utils import (
+    is_sm_at_least_100,
+    torch_version_at_least,
+)
+
+if not torch_version_at_least("2.8.0"):
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not is_sm_at_least_100(), reason="needs CUDA capability 10.0+")
+@pytest.mark.parametrize("recipe_name", ["mxfp8", "nvfp4"])
+def test_serialization(recipe_name):
+    """
+    Ensure that only `import torchao.prototype.mx_formats` is needed to load MX
+    and NV checkpoints.
+    """
+
+    m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
+    fname = None
+    with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
+        if recipe_name == "mxfp8":
+            config = MXFPInferenceConfig(
+                activation_dtype=torch.float8_e4m3fn,
+                weight_dtype=torch.float8_e4m3fn,
+                gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
+            )
+        else:
+            assert recipe_name == "nvfp4", "unsupported"
+            config = NVFP4InferenceConfig(
+                mm_config=NVFP4MMConfig.DYNAMIC,
+                use_triton_kernel=False,
+                use_dynamic_per_tensor_scale=False,
+            )
+
+        quantize_(m, config=config)
+        torch.save(m.state_dict(), f.name)
+        fname = f.name
+
+    assert fname is not None
+
+    code = f"""
+import torch
+import torchao.prototype.mx_formats
+_ = torch.load('{fname}', weights_only=True)
+"""
+
+    subprocess_out = subprocess.run(["python"], input=code, text=True)
+    os.remove(fname)
+    assert subprocess_out.returncode == 0, "failed weights-only load"
```
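
The test shells out to a fresh `python` process so the parent test session's own imports cannot mask a missing registration. A hedged sketch of the failure mode that round-trip guards against (`ckpt.pt` is a placeholder for a previously saved MX/NV state dict; the exact mechanism, safe-globals registration on import, is implied by the `add_safe_globals` workaround removed above rather than stated in this commit):

```python
# Hedged sketch: loading an MX/NV checkpoint without importing
# torchao.prototype.mx_formats is expected to fail a weights-only load,
# since the quantized tensor classes are not allow-listed for unpickling.
import torch

try:
    torch.load("ckpt.pt", weights_only=True)  # placeholder path
except Exception as err:  # pickle.UnpicklingError in practice
    print(f"load failed as expected: {err}")

import torchao.prototype.mx_formats  # importing should register the needed classes

state_dict = torch.load("ckpt.pt", weights_only=True)
```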

torchao/prototype/mx_formats/README.md

Lines changed: 2 additions & 0 deletions
````diff
@@ -73,6 +73,7 @@ Below is a toy training loop. For an example real training loop, see our torchti
 ```python
 import torch
 from torchao.quantization import quantize_
+import torchao.prototype.mx_formats
 from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice, ScaleCalculationMode
 
 # on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels
@@ -105,6 +106,7 @@ import copy
 import torch
 import torch.nn as nn
 from torchao.quantization import quantize_
+import torchao.prototype.mx_formats
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
 )
````
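
The two README additions mirror what the new test enforces: the user-facing snippets now carry the top-level `import torchao.prototype.mx_formats`, so checkpoints saved from them can later be loaded with `weights_only=True` without the per-call `add_safe_globals` workaround removed above.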
