
Commit 6a2088a

[mxfp8 moe training] simplify e8m0 -> fp32 calc
stack-info: PR: #3201, branch: danielvegamyhre/stack/80
1 parent 710192d commit 6a2088a

File tree: 2 files changed, +12 −9 lines

test/prototype/mx_formats/test_kernels.py

Lines changed: 8 additions & 6 deletions
@@ -492,8 +492,9 @@ def test_triton_mxfp8_dim1_randn(M, K):
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dim0_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dim0_randn(M, K, orig_dtype):
+    x = torch.randn(M, K, dtype=orig_dtype, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -521,18 +522,19 @@ def test_triton_mxfp8_dim0_zeros():
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dequant_dim0(M, K):
-    x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
+    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
         x_data,
         x_scales,
         torch.float8_e4m3fn,
         block_size,
-        torch.bfloat16,
+        orig_dtype,
     )
-    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
+    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, block_size)
    torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
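
For context, a minimal standalone sketch of the round trip these parametrized tests exercise. This is not code from the commit; it assumes a supported CUDA GPU, and the import path below is an assumption based on the file names in this diff:

import torch

# Import path is an assumption based on the file names in this commit.
from torchao.prototype.mx_formats.kernels import (
    triton_mxfp8_dequant_dim0,
    triton_to_mxfp8_dim0,
)

orig_dtype = torch.float32  # the new parametrization also covers torch.bfloat16
x = torch.randn(256, 5120, dtype=orig_dtype, device="cuda")

# Rowwise (dim0) quantization to mxfp8: e4m3 data plus e8m0 block scales.
x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=32)

# Dequantize back to the original dtype; the result should closely track x.
x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, 32)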

torchao/prototype/mx_formats/kernels.py

Lines changed: 4 additions & 3 deletions
@@ -1279,12 +1279,13 @@ def triton_to_mxfp8_dim1_reference(
         scale_e8m0_dim1.unsqueeze(-1),
     )

+@triton_op("torchao::triton_mxfp8_dequant_dim0", mutates_args={})
 def triton_mxfp8_dequant_dim0(
     e4m3_data: torch.Tensor,
     e8m0_scales: torch.Tensor,
     out_dtype: torch.dtype,
     scale_block_size: int = 32,
-) -> None:
+) -> torch.Tensor:
     assert scale_block_size == 32, "scale_block_size must be 32 for now"
     assert out_dtype in (torch.bfloat16, torch.float32), (
         "out_dtype must be bf16 or fp32"
@@ -1300,7 +1301,7 @@ def triton_mxfp8_dequant_dim0(
         triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
         triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
     )
-    _dequant_mxfp8_kernel[grid](
+    wrap_triton(_dequant_mxfp8_kernel)[grid](
        e4m3_data,
        e8m0_scales.to(torch.uint8),
        out_buffer,
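
The `@triton_op` registration above is why the raw launch is now routed through `wrap_triton`: inside a `torch.library.triton_op`, Triton kernels are expected to be wrapped so that torch.compile can trace into the launch instead of treating it as an opaque call. A minimal sketch of that pattern, using a hypothetical toy kernel rather than the torchao one:

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _double_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Hypothetical toy kernel: out = 2 * x, elementwise.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2.0, mask=mask)


@triton_op("example::double", mutates_args={})
def double(x: torch.Tensor) -> torch.Tensor:
    # Assumes a contiguous CUDA float tensor.
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda META: (triton.cdiv(n, META["BLOCK"]),)
    # Inside a triton_op, launch the raw kernel through wrap_triton so
    # torch.compile can trace into it rather than graph-breaking.
    wrap_triton(_double_kernel)[grid](x, out, n, BLOCK=1024)
    return out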
@@ -1371,8 +1372,8 @@ def _dequant_mxfp8_kernel(

 @triton.jit
 def _e8m0_to_fp32(scale_e8m0):
-    e8m0_exponent_bias = 127
     e8m0_nan_val = 255
+    e8m0_exponent_bias = 127
     s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
     s_fp = tl.exp2(s_offset.to(tl.float32))
     s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
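
For readers unfamiliar with e8m0: the scale is just a biased 8-bit exponent, so the helper above computes 2**(scale - 127), with the all-ones encoding (255) reserved for NaN. A plain-PyTorch restatement of the Triton helper, for reference only (not code from this commit):

import torch


def e8m0_to_fp32_reference(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # scale_e8m0: uint8 tensor of biased exponents (e8m0 encoding).
    e8m0_nan_val = 255
    e8m0_exponent_bias = 127
    s_offset = scale_e8m0.to(torch.int16) - e8m0_exponent_bias
    s_fp = torch.exp2(s_offset.to(torch.float32))
    # The all-ones exponent encodes NaN in e8m0.
    return torch.where(scale_e8m0 != e8m0_nan_val, s_fp, torch.full_like(s_fp, float("nan")))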
