
Commit d089c6a

[mxfp8 moe training] simplify e8m0 -> fp32 calc (#3201)
1 parent: 6a62549

2 files changed: 11 additions & 8 deletions


test/prototype/mx_formats/test_kernels.py

Lines changed: 5 additions & 4 deletions
@@ -521,18 +521,19 @@ def test_triton_mxfp8_dim0_zeros():
     )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dequant_dim0(M, K):
-    x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
+    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
         x_data,
         x_scales,
         torch.float8_e4m3fn,
         block_size,
-        torch.bfloat16,
+        orig_dtype,
     )
-    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
+    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, block_size)
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)

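For context, a minimal round-trip sketch of the quantize/dequantize path this test exercises, assuming a CUDA device. The import path comes from the kernels.py diff below; the (data, scales) return convention of triton_to_mxfp8_dim0 is inferred from the reference helper used in the test, so treat those details as assumptions rather than confirmed API:

import torch
from torchao.prototype.mx_formats.kernels import (
    triton_mxfp8_dequant_dim0,
    triton_to_mxfp8_dim0,
)

# Input must be bfloat16; this commit adds an assert for that in
# triton_to_mxfp8_dim0.
x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")

# Quantize along dim0 into e4m3 data plus e8m0 block scales (blocks of 32).
x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=32)

# Dequantize back; out_dtype may now be torch.float32 as well as torch.bfloat16.
x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.float32, 32)
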
torchao/prototype/mx_formats/kernels.py

Lines changed: 6 additions & 4 deletions
@@ -1141,7 +1141,8 @@ def triton_to_mxfp8_dim0(
     * `scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
     """
     assert x.is_contiguous(), "`x` must be contiguous"
-    assert inner_block_size <= 32
+    assert inner_block_size <= 32, "inner_block_size must be <= 32"
+    assert x.dtype == torch.bfloat16, "only bfloat16 inputs are supported"

     # Reshape tensor to 2d if necessary and get shape
     x_orig_shape = x.shape
@@ -1279,12 +1280,13 @@ def triton_to_mxfp8_dim1_reference(
         scale_e8m0_dim1,
     )

+@triton_op("torchao::triton_mxfp8_dequant_dim0", mutates_args={})
 def triton_mxfp8_dequant_dim0(
     e4m3_data: torch.Tensor,
     e8m0_scales: torch.Tensor,
     out_dtype: torch.dtype,
     scale_block_size: int = 32,
-) -> None:
+) -> torch.Tensor:
     assert scale_block_size == 32, "scale_block_size must be 32 for now"
     assert out_dtype in (torch.bfloat16, torch.float32), (
         "out_dtype must be bf16 or fp32"
@@ -1300,7 +1302,7 @@ def triton_mxfp8_dequant_dim0(
         triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
         triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
     )
-    _dequant_mxfp8_kernel[grid](
+    wrap_triton(_dequant_mxfp8_kernel)[grid](
         e4m3_data,
         e8m0_scales.to(torch.uint8),
         out_buffer,
@@ -1371,8 +1373,8 @@ def _dequant_mxfp8_kernel(

 @triton.jit
 def _e8m0_to_fp32(scale_e8m0):
-    e8m0_exponent_bias = 127
     e8m0_nan_val = 255
+    e8m0_exponent_bias = 127
     s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
     s_fp = tl.exp2(s_offset.to(tl.float32))
     s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
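For reference, a plain PyTorch sketch of the e8m0 decode that _e8m0_to_fp32 implements above. An e8m0 scale is a bare 8-bit exponent with bias 127, and the byte value 255 encodes NaN; e8m0_to_fp32_reference is a hypothetical helper for illustration, not part of this commit:

import torch

def e8m0_to_fp32_reference(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Hypothetical mirror of the Triton kernel: 2 ** (byte - 127), byte 255 -> NaN.
    e8m0_exponent_bias = 127
    e8m0_nan_val = 255
    s_offset = scale_e8m0.to(torch.int16) - e8m0_exponent_bias
    s_fp = torch.exp2(s_offset.to(torch.float32))
    return torch.where(
        scale_e8m0 != e8m0_nan_val, s_fp, torch.full_like(s_fp, float("nan"))
    )

For example, e8m0_to_fp32_reference(torch.tensor([127, 128, 255], dtype=torch.uint8)) yields [1.0, 2.0, nan].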
