Commit bbfd981

mxtensor: make scale shape match qdata (#3198)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent b50e37a commit bbfd981
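
For orientation, here is a minimal sketch (not part of this commit) of the invariant the change enforces: the scale tensor keeps the same orientation as qdata, with one scale value per block_size elements along the reduction dim. The import paths below are assumptions inferred from this repo's layout, and a CUDA device is required.

    # Hedged sketch of the scale/qdata shape invariant; import paths are assumed
    # from torchao's prototype mx_formats layout, and CUDA is required.
    import torch
    from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode

    block_size = 32
    x_hp = torch.randn(128, 64, device="cuda")
    x = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, block_size, ScaleCalculationMode.FLOOR)

    # qdata keeps the (M, K) layout of the input; scale has one entry per
    # block_size-wide group along K, so its dims line up with qdata's dims.
    assert x.scale.shape[0] == x_hp.shape[0]                 # M
    assert x.scale.shape[1] == x_hp.shape[1] // block_size   # K // block_size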

File tree

3 files changed: +59 -4 lines changed

  test/prototype/mx_formats/test_mx_tensor.py
  torchao/prototype/mx_formats/kernels.py
  torchao/prototype/mx_formats/mx_tensor.py

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 54 additions & 0 deletions
@@ -662,3 +662,57 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
         rtol=0.0,
         msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.parametrize("transpose", [False, True])
+@pytest.mark.parametrize(
+    "shape",
+    (
+        (128, 64),
+        (1, 128, 64),
+    ),
+)
+def test_scale_shape_matches_qdata(transpose, shape):
+    if len(shape) == 3 and transpose:
+        pytest.skip("transpose not yet implemented for 3D MXTensor")
+
+    block_size = 32
+
+    x_hp = torch.randn(*shape, device="cuda")
+    x = MXTensor.to_mx(
+        x_hp,
+        torch.float8_e4m3fn,
+        block_size,
+        ScaleCalculationMode.FLOOR,
+    )
+
+    if len(shape) == 2:
+        m_dim, k_dim = 0, 1
+        if transpose:
+            x_hp = x_hp.t()
+            x = x.t()
+            m_dim, k_dim = 1, 0
+    else:
+        assert len(shape) == 3, "unsupported"
+        m_dim, k_dim = 1, 2
+        if transpose:
+            x_hp = x_hp.transpose(-2, -1)
+            x = x.transpose(-2, -1)
+            m_dim, k_dim = 2, 1
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    actual_padded_m = x.scale.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x.scale.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    actual_padded_k = x.scale.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x.scale.shape}"
+    )

torchao/prototype/mx_formats/kernels.py

Lines changed: 2 additions & 2 deletions
@@ -1245,7 +1245,7 @@ def triton_to_mxfp8_dim1(

     return (
         output_col_major.t(),
-        col_scale.view(torch.float8_e8m0fnu),
+        col_scale.view(torch.float8_e8m0fnu).squeeze(-1),
     )

 @register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
@@ -1274,7 +1274,7 @@ def triton_to_mxfp8_dim1_reference(
     scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
     return (
         x_hp_d1_normalized.t(),
-        scale_e8m0_dim1.unsqueeze(-1),
+        scale_e8m0_dim1,
     )

 @triton.jit

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 3 additions & 2 deletions
@@ -362,6 +362,7 @@ def to_dtype(
     # unpacking and unscaling
     if is_transposed:
         data_lp = data_lp.t()
+        scale_e8m0 = scale_e8m0.t()
         assert data_lp.is_contiguous()
         orig_shape = (orig_shape[1], orig_shape[0])

@@ -688,7 +689,7 @@ def _addmm_mx_dispatch(
     assert b._block_size == 32, f"Invalid block size {b._block_size}"

     a_scale = a.scale.view(M, K // a._block_size)
-    b_scale = b.scale.view(N, K // b._block_size)
+    b_scale = b.scale.t().view(N, K // b._block_size)
     a_scale_block = to_blocked(a_scale)
     b_scale_block = to_blocked(b_scale)

@@ -759,7 +760,7 @@ def mx_t(func, types, args, kwargs):
     old = args[0]
     new = MXTensor(
         old.qdata.t(),
-        old.scale,
+        old.scale.t(),
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,

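As a hedged usage sketch of what the mx_t change above is intended to guarantee (same assumptions as the earlier sketch: import paths inferred from the repo layout, CUDA required), transposing an MXTensor now transposes qdata and scale together, so the two stay in the same orientation:

    # After this commit, MXTensor.t() transposes qdata and scale together
    # (import paths assumed; requires CUDA).
    import torch
    from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode

    block_size = 32
    x = MXTensor.to_mx(
        torch.randn(128, 64, device="cuda"),
        torch.float8_e4m3fn,
        block_size,
        ScaleCalculationMode.FLOOR,
    )
    xt = x.t()

    # Both tensors swap their two dims, matching the qdata.t() / scale.t()
    # pair in mx_t above.
    assert xt.qdata.shape == (x.qdata.shape[1], x.qdata.shape[0])
    assert xt.scale.shape == (x.scale.shape[1], x.scale.shape[0])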