From f502a9efe3f3b28839bd2de2e7ed590c150dd632 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 23 Oct 2025 21:18:47 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 .../mx_formats/test_inference_workflow.py | 20 ++++++++++++++++++--
 torchao/prototype/mx_formats/mx_tensor.py | 21 +++++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 556a3f8aff..51c066ad3f 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -71,11 +71,20 @@ def cuda_kernel_profiler(kernel_pattern):
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
 @pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize("use_inference_mode", [True, False])
+@pytest.mark.parametrize("x_rank", [2, 3])
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
-def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
+def test_inference_workflow_mx(
+    elem_dtype,
+    bias: bool,
+    compile: bool,
+    emulate: bool,
+    use_inference_mode: bool,
+    x_rank: int,
+):
     """
     Smoke test for inference compile
     """
@@ -112,8 +121,15 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
         m_mx = torch.compile(m_mx, fullgraph=True)
 
     x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    if x_rank == 3:
+        x = x.unsqueeze(0)
+
     y_ref = m(x)
-    y_mx = m_mx(x)
+    if use_inference_mode:
+        with torch.inference_mode():
+            y_mx = m_mx(x)
+    else:
+        y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
     SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
     assert sqnr >= SQNR_THRESHOLD, (
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 3ad7d5f268..67aa9d767a 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -775,6 +775,27 @@ def mx_addmm(func, types, args, kwargs):
     return _addmm_mx_dispatch(a, b, func, bias=bias)
 
 
+@implements([aten.linear.default])
+def mx_linear(func, types, args, kwargs):
+    assert isinstance(args[0], torch.Tensor) and isinstance(args[1], MXTensor)
+    a = args[0]
+
+    # flatten leading dims so the activation is 2d, as _addmm_mx_dispatch expects
+    orig_a_shape = a.shape
+    a_2d = a.view(-1, orig_a_shape[-1])
+
+    b = args[1].t()
+    if len(args) > 2:
+        bias = args[2]
+        res = _addmm_mx_dispatch(a_2d, b, aten.addmm.default, bias=bias)
+    else:
+        res = _addmm_mx_dispatch(a_2d, b, aten.mm.default)
+
+    # reshape back to original shape
+    res = res.view(*orig_a_shape[:-1], res.shape[-1])
+    return res
+
+
 @implements([aten.t.default])
 def mx_t(func, types, args, kwargs):
     # For now, only transpose(input, 0, 1) is supported.
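
Why the new aten.linear.default override in mx_tensor.py is needed: under
torch.inference_mode(), F.linear can reach __torch_dispatch__ as
aten.linear.default instead of being decomposed into t/mm/addmm first, so a
tensor subclass such as MXTensor must handle that op directly. Below is a
minimal standalone sketch of that dispatch behavior, using a TorchDispatchMode
as a stand-in for the subclass; plain PyTorch only, and the printed op
sequences are indicative rather than exhaustive:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class OpLogger(TorchDispatchMode):
        # print every aten op that reaches the Python dispatch layer,
        # then redispatch to the normal implementation
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(func)
            return func(*args, **(kwargs or {}))

    x, w = torch.randn(4, 8), torch.randn(16, 8)
    with OpLogger():
        torch.nn.functional.linear(x, w)  # decomposed: t/mm-style ops
    with torch.inference_mode(), OpLogger():
        torch.nn.functional.linear(x, w)  # aten.linear.default arrives intact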
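
The shape handling in mx_linear, shown as a self-contained sketch with plain
torch tensors: flatten all leading activation dims to 2d, run the mm/addmm
path against the transposed weight, then restore the leading dims on the
output. linear_via_2d is a hypothetical name for illustration, not a torchao
API:

    import torch

    def linear_via_2d(a: torch.Tensor, w: torch.Tensor, bias=None) -> torch.Tensor:
        # flatten all leading dims: (..., K) -> (M, K)
        orig_a_shape = a.shape
        a_2d = a.reshape(-1, orig_a_shape[-1])
        # 2d matmul against the transposed weight, mirroring the mm/addmm split
        if bias is not None:
            res = torch.addmm(bias, a_2d, w.t())
        else:
            res = torch.mm(a_2d, w.t())
        # restore the leading dims: (M, N) -> (..., N)
        return res.view(*orig_a_shape[:-1], res.shape[-1])

    a = torch.randn(2, 128, 32)  # rank-3 activation, matching the new x_rank=3 case
    w = torch.randn(64, 32)
    torch.testing.assert_close(linear_via_2d(a, w), torch.nn.functional.linear(a, w))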