
Commit 37e8add

mxtensor: make scale shape match qdata

Summary:

Given `qdata` of shape (M, K), make the `scale` shape (M, K // block_size).
This will be important to keep the logic sane as we add pre-swizzling in a
future PR.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d8d4e78
ghstack-comment-id: 3413479441
Pull-Request: #3198

1 parent 0a9c5cb, commit 37e8add
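The new invariant can be sketched as follows. This is a minimal illustration, not part of the commit: it assumes a CUDA device, an MX-enabled torchao build, and that `MXTensor.to_mx` falls back to its default scale-calculation mode when called with three arguments; the sizes are arbitrary.

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

block_size = 32
x_hp = torch.randn(128, 64, device="cuda")  # qdata will be (M, K) = (128, 64)
x = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, block_size)

# scale now mirrors qdata: (M, K // block_size)
assert x.qdata.shape == (128, 64)
assert x.scale.shape == (128, 64 // block_size)

# transposing the MXTensor transposes qdata and scale together (see mx_t below)
xt = x.t()
assert xt.qdata.shape == (64, 128)
assert xt.scale.shape == (64 // block_size, 128)
```

These assertions mirror what the new `test_scale_shape_matches_qdata` test below checks.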

3 files changed: +59 −4 lines
test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 54 additions & 0 deletions
```diff
@@ -662,3 +662,57 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
         rtol=0.0,
         msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.parametrize("transpose", [False, True])
+@pytest.mark.parametrize(
+    "shape",
+    (
+        (128, 64),
+        (1, 128, 64),
+    ),
+)
+def test_scale_shape_matches_qdata(transpose, shape):
+    if len(shape) == 3 and transpose:
+        pytest.skip("transpose not yet implemented for 3D MXTensor")
+
+    block_size = 32
+
+    x_hp = torch.randn(*shape, device="cuda")
+    x = MXTensor.to_mx(
+        x_hp,
+        torch.float8_e4m3fn,
+        block_size,
+        ScaleCalculationMode.FLOOR,
+    )
+
+    if len(shape) == 2:
+        m_dim, k_dim = 0, 1
+        if transpose:
+            x_hp = x_hp.t()
+            x = x.t()
+            m_dim, k_dim = 1, 0
+    else:
+        assert len(shape) == 3, "unsupported"
+        m_dim, k_dim = 1, 2
+        if transpose:
+            x_hp = x_hp.transpose(-2, -1)
+            x = x.transpose(-2, -1)
+            m_dim, k_dim = 2, 1
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    actual_padded_m = x.scale.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    actual_padded_k = x.scale.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
+    )
```
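The new test covers 2D and 3D inputs plus the 2D transposed layout. It can be run in isolation with `pytest test/prototype/mx_formats/test_mx_tensor.py -k scale_shape_matches_qdata`, or via the full test plan in the commit message.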

torchao/prototype/mx_formats/kernels.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1264,7 +1264,7 @@ def triton_to_mxfp8_dim1(
 
     return (
         output_col_major.t(),
-        col_scale.view(torch.float8_e8m0fnu),
+        col_scale.view(torch.float8_e8m0fnu).squeeze(-1),
     )
 
 @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
@@ -1293,7 +1293,7 @@ def triton_to_mxfp8_dim1_reference(
     scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
     return (
         x_hp_d1_normalized.t(),
-        scale_e8m0_dim1.unsqueeze(-1),
+        scale_e8m0_dim1,
     )
 
 @triton.jit
```
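Taken together, these two hunks keep the Triton kernel and its reference in agreement: the kernel squeezes the trailing singleton dimension out of the e8m0 scale it returns, and the reference no longer unsqueezes one, so both now hand back a 2D scale consistent with the new qdata-aligned layout.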

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -362,6 +362,7 @@ def to_dtype(
     # unpacking and unscaling
     if is_transposed:
         data_lp = data_lp.t()
+        scale_e8m0 = scale_e8m0.t()
         assert data_lp.is_contiguous()
         orig_shape = (orig_shape[1], orig_shape[0])
 
@@ -688,7 +689,7 @@ def _addmm_mx_dispatch(
         assert b._block_size == 32, f"Invalid block size {b._block_size}"
 
         a_scale = a.scale.view(M, K // a._block_size)
-        b_scale = b.scale.view(N, K // b._block_size)
+        b_scale = b.scale.t().view(N, K // b._block_size)
         a_scale_block = to_blocked(a_scale)
         b_scale_block = to_blocked(b_scale)
 
@@ -759,7 +760,7 @@ def mx_t(func, types, args, kwargs):
     old = args[0]
     new = MXTensor(
         old.qdata.t(),
-        old.scale,
+        old.scale.t(),
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
```
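These three hunks are coupled: `to_dtype` and `mx_t` now transpose the scale together with the data, which is why `_addmm_mx_dispatch` has to transpose `b.scale` back before blocking it. Below is a rough shape-bookkeeping sketch of that last change, using plain tensors rather than the torchao API and assuming `b` reaches the matmul as a transposed `(N, K)` weight (as in a typical linear layer); the sizes are made up.

```python
import torch

K, N, block_size = 256, 64, 32

# weight quantized as (N, K): its scale follows the new (N, K // block_size) layout
w_scale = torch.empty(N, K // block_size)

# after w.t(), mx_t transposes the scale along with qdata: (K // block_size, N)
b_scale_transposed = w_scale.t()

# _addmm_mx_dispatch undoes that transpose to recover the (N, K // block_size)
# layout that to_blocked receives
b_scale = b_scale_transposed.t().view(N, K // block_size)
assert b_scale.shape == (N, K // block_size)
```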
