Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions test/prototype/mx_formats/test_inference_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,20 @@ def cuda_kernel_profiler(kernel_pattern):
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize("emulate", [True, False])
@pytest.mark.parametrize("use_inference_mode", [True, False])
@pytest.mark.parametrize("x_rank", [2, 3])
@torch.no_grad()
@skip_if_rocm(
"ROCm float4 gemm require gfx950"
) # TODO(future): deploy gfx950 in ROCM CI
def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
def test_inference_workflow_mx(
elem_dtype,
bias: bool,
compile: bool,
emulate: bool,
use_inference_mode: bool,
x_rank: int,
):
"""
Smoke test for inference compile
"""
Expand Down Expand Up @@ -112,8 +121,15 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
m_mx = torch.compile(m_mx, fullgraph=True)

x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
if x_rank == 3:
x = x.unsqueeze(0)

y_ref = m(x)
y_mx = m_mx(x)
if use_inference_mode:
with torch.inference_mode():
y_mx = m_mx(x)
else:
y_mx = m_mx(x)
sqnr = compute_error(y_ref, y_mx)
SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
assert sqnr >= SQNR_THRESHOLD, (
Expand Down
21 changes: 21 additions & 0 deletions torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,27 @@ def mx_addmm(func, types, args, kwargs):
return _addmm_mx_dispatch(a, b, func, bias=bias)


@implements([aten.linear.default])
def mx_linear(func, types, args, kwargs):
assert isinstance(args[0], torch.Tensor) and isinstance(args[1], MXTensor)
a = args[0]

# make a 2d
orig_a_shape = a.shape
a_2d = a.view(-1, orig_a_shape[-1])

b = args[1].t()
if len(args) > 2:
bias = args[2]
res = _addmm_mx_dispatch(a_2d, b, aten.addmm.default, bias)
else:
res = _addmm_mx_dispatch(a_2d, b, aten.mm.default)

# reshape back to original shape
res = res.view(*orig_a_shape[:-1], res.shape[-1])
return res


@implements([aten.t.default])
def mx_t(func, types, args, kwargs):
# For now, only transpose(input, 0, 1) is supported.
Expand Down
Loading