
Commit 7537d99

[mxfp8 moe training] add triton kernel for mxfp8 dequantization (#3195)
1 parent 41a0778 commit 7537d99

3 files changed: +314, -12 lines
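As context for the diffs below: the new kernel is intended to be used together with the existing to_mx quantization path. A minimal usage sketch, assembled from the benchmark script in this commit (the shape is one of the benchmarked shapes; everything else is illustrative):

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0

x = torch.randn(8192, 7168, dtype=torch.bfloat16, device="cuda")
# quantize to mxfp8: one e8m0 scale per 32 contiguous elements along dim0, data in float8_e4m3fn
e8m0_scales, e4m3_data = to_mx(x, torch.float8_e4m3fn, 32)
# dequantize back to bf16 along dim0 with the new triton kernel
x_hp = triton_mxfp8_dequant_dim0(e4m3_data, e8m0_scales, torch.bfloat16, 32)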
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
from dataclasses import dataclass
9+
from typing import List
10+
11+
import torch
12+
from tabulate import tabulate
13+
from tqdm import tqdm
14+
15+
from benchmarks.utils import benchmark_cuda_function_in_microseconds
16+
from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0
17+
from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
18+
19+
device = torch.device("cuda")
20+
21+
# Needed since changing args to function causes recompiles
22+
torch._dynamo.config.cache_size_limit = 1000
23+
24+
25+
@dataclass(frozen=True)
26+
class ExperimentConfig:
27+
input_shape: tuple[int]
28+
29+
30+
@dataclass(frozen=True)
31+
class ExperimentResult:
32+
# time
33+
torch_us: float
34+
triton_us: float
35+
torch_gbps: float
36+
triton_gbps: float
37+
38+
39+
@dataclass(frozen=True)
40+
class Experiment:
41+
config: ExperimentConfig
42+
result: ExperimentResult
43+
44+
45+
def get_configs() -> List[ExperimentConfig]:
46+
input_shapes = [
47+
# (local_batch_size, seq_len, dim)
48+
(1, 8192, 7168),
49+
(2, 8192, 7168),
50+
(4, 8192, 7168),
51+
(8, 8192, 7168),
52+
]
53+
configs = []
54+
for shape in input_shapes:
55+
configs.append(
56+
ExperimentConfig(
57+
input_shape=shape,
58+
)
59+
)
60+
return configs
61+
62+
63+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
64+
block_size = 32
65+
input_shape = config.input_shape
66+
input_tensor = torch.randn(
67+
*input_shape,
68+
dtype=torch.bfloat16,
69+
device=device,
70+
)
71+
72+
e8m0_scales, e4m3_data = to_mx(input_tensor, torch.float8_e4m3fn, block_size)
73+
74+
# Bench torch dequant
75+
to_dtype_c = torch.compile(to_dtype)
76+
elem_dtype, target_dtype = torch.float8_e4m3fn, torch.bfloat16
77+
torch_output = to_dtype_c(
78+
e4m3_data,
79+
e8m0_scales,
80+
elem_dtype,
81+
block_size,
82+
target_dtype,
83+
)
84+
torch_us = benchmark_cuda_function_in_microseconds(
85+
to_dtype_c,
86+
e4m3_data,
87+
e8m0_scales,
88+
elem_dtype,
89+
block_size,
90+
target_dtype,
91+
)
92+
93+
# Bench triton kernel
94+
_ = triton_mxfp8_dequant_dim0(
95+
e4m3_data,
96+
e8m0_scales,
97+
target_dtype,
98+
block_size,
99+
)
100+
triton_us = benchmark_cuda_function_in_microseconds(
101+
triton_mxfp8_dequant_dim0,
102+
e4m3_data,
103+
e8m0_scales,
104+
target_dtype,
105+
block_size,
106+
)
107+
108+
# mem bw calculations
109+
bytes_per_input_el = torch.finfo(elem_dtype).bits / 8
110+
bytes_per_output_el = torch.finfo(target_dtype).bits / 8
111+
bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
112+
113+
read_bytes = (
114+
e4m3_data.numel() * bytes_per_input_el
115+
+ e8m0_scales.numel() * bytes_per_scale_el
116+
)
117+
write_bytes = torch_output.numel() * bytes_per_output_el
118+
119+
torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_us / 1e6)
120+
triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_us / 1e6)
121+
122+
return ExperimentResult(
123+
torch_us=torch_us,
124+
triton_us=triton_us,
125+
triton_gbps=triton_gbps,
126+
torch_gbps=torch_gbps,
127+
)
128+
129+
130+
def print_results(experiments: List[Experiment]):
131+
headers = [
132+
"input_shape",
133+
"torch_us",
134+
"triton_us",
135+
"torch_gbps",
136+
"triton_gbps",
137+
"triton_speedup",
138+
]
139+
rows = []
140+
for experiment in experiments:
141+
triton_speedup = round(
142+
experiment.result.torch_us / experiment.result.triton_us, 3
143+
)
144+
rows.append(
145+
[
146+
str(experiment.config.input_shape),
147+
experiment.result.torch_us,
148+
experiment.result.triton_us,
149+
round(experiment.result.torch_gbps, 3),
150+
round(experiment.result.triton_gbps, 3),
151+
f"{triton_speedup}x",
152+
]
153+
)
154+
print(tabulate(rows, headers=headers))
155+
156+
157+
def main():
158+
torch.random.manual_seed(123)
159+
configs = get_configs()
160+
results = []
161+
for config in tqdm(configs):
162+
result = run_experiment(config)
163+
results.append(Experiment(config=config, result=result))
164+
165+
# Use Tabulate to print results
166+
print_results(results)
167+
168+
169+
if __name__ == "__main__":
170+
main()
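The gbps numbers the script prints are plain (bytes read + bytes written) / elapsed time. The same arithmetic worked by hand for the smallest benchmarked shape, with a made-up timing purely for illustration:

# same formula as run_experiment above, for input_shape (1, 8192, 7168)
numel = 1 * 8192 * 7168                     # 58,720,256 elements
read_bytes = numel * 1 + (numel // 32) * 1  # 1 byte per e4m3 element + 1 byte per e8m0 scale (one scale per 32 elements)
write_bytes = numel * 2                     # 2 bytes per bf16 output element
example_us = 100.0                          # hypothetical kernel time in microseconds
gbps = ((read_bytes + write_bytes) / 1e9) / (example_us / 1e6)
print(round(gbps, 1))                       # ~1780.0 GB/s for this hypothetical timing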

test/prototype/mx_formats/test_kernels.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@
3737
pack_uint6,
3838
triton_f6_e2m3_to_bf16,
3939
triton_f6_e3m2_to_bf16,
40+
triton_mxfp8_dequant_dim0,
4041
triton_to_mxfp8_dim0,
4142
triton_to_mxfp8_dim1,
4243
triton_to_mxfp8_dim1_reference,
4344
unpack_uint4,
4445
)
45-
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
46+
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
4647
from torchao.prototype.mx_formats.utils import to_blocked
4748
from torchao.utils import (
4849
is_sm_at_least_89,
@@ -513,6 +514,28 @@ def test_triton_mxfp8_dim0_zeros():
513514
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
514515

515516

517+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
518+
@pytest.mark.skipif(
519+
not is_sm_at_least_100(),
520+
reason="mxfp8 requires CUDA capability 10.0 or greater",
521+
)
522+
@pytest.mark.parametrize("M", (256, 2048, 131072))
523+
@pytest.mark.parametrize("K", (256, 5120, 7168))
524+
def test_triton_mxfp8_dequant_dim0(M, K):
525+
x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
526+
block_size = 32
527+
x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
528+
hp_ref = to_dtype(
529+
x_data,
530+
x_scales,
531+
torch.float8_e4m3fn,
532+
block_size,
533+
torch.bfloat16,
534+
)
535+
hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
536+
torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
537+
538+
516539
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
517540
@pytest.mark.parametrize(
518541
"shape",

torchao/prototype/mx_formats/kernels.py

Lines changed: 120 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -880,17 +880,19 @@ def _get_mxfp8_quant_autotune_configs():
880880
# sweep over a small set of shapes, it's likely that this
881881
# can be improved in the future.
882882
results = []
883-
for ROW_TILE_SIZE in (64, 128):
884-
for COL_TILE_SIZE in (64, 128):
885-
for num_warps in (1, 2, 4):
886-
config = triton.Config(
887-
{
888-
"ROW_TILE_SIZE": ROW_TILE_SIZE,
889-
"COL_TILE_SIZE": COL_TILE_SIZE,
890-
},
891-
num_warps=num_warps,
892-
)
893-
results.append(config)
883+
for ROW_TILE_SIZE in (128, 256, 512):
884+
for COL_TILE_SIZE in (128, 256, 512):
885+
for num_warps in (4, 8):
886+
for num_stages in (2, 3):
887+
config = triton.Config(
888+
{
889+
"ROW_TILE_SIZE": ROW_TILE_SIZE,
890+
"COL_TILE_SIZE": COL_TILE_SIZE,
891+
},
892+
num_warps=num_warps,
893+
num_stages=num_stages,
894+
)
895+
results.append(config)
894896
return results
895897

896898
@triton.autotune(
@@ -1277,6 +1279,105 @@ def triton_to_mxfp8_dim1_reference(
12771279
scale_e8m0_dim1,
12781280
)
12791281

1282+
def triton_mxfp8_dequant_dim0(
1283+
e4m3_data: torch.Tensor,
1284+
e8m0_scales: torch.Tensor,
1285+
out_dtype: torch.dtype,
1286+
scale_block_size: int = 32,
1287+
) -> None:
1288+
assert scale_block_size == 32, "scale_block_size must be 32 for now"
1289+
assert out_dtype in (torch.bfloat16, torch.float32), (
1290+
"out_dtype must be bf16 or fp32"
1291+
)
1292+
1293+
# Input shape must be 2D.
1294+
orig_shape = e4m3_data.shape
1295+
e4m3_data = e4m3_data.reshape(-1, orig_shape[-1])
1296+
out_buffer = torch.empty_like(e4m3_data, dtype=out_dtype)
1297+
out_dtype_tl = tl.bfloat16 if out_dtype == torch.bfloat16 else tl.float32
1298+
1299+
grid = lambda META: (
1300+
triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
1301+
triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
1302+
)
1303+
_dequant_mxfp8_kernel[grid](
1304+
e4m3_data,
1305+
e8m0_scales.to(torch.uint8),
1306+
out_buffer,
1307+
e4m3_data.size(0),
1308+
e4m3_data.size(1),
1309+
e8m0_scales.size(0),
1310+
e8m0_scales.size(1),
1311+
out_dtype=out_dtype_tl,
1312+
SCALE_BLOCK_SIZE=scale_block_size,
1313+
)
1314+
return out_buffer.reshape(orig_shape)
1315+
1316+
@triton.autotune(
1317+
configs=_get_mxfp8_quant_autotune_configs(),
1318+
key=["input_num_cols", "SCALE_BLOCK_SIZE"],
1319+
)
1320+
@triton.jit
1321+
def _dequant_mxfp8_kernel(
1322+
e4m3_data,
1323+
e8m0_scales,
1324+
out_buffer,
1325+
input_num_rows,
1326+
input_num_cols,
1327+
scale_num_rows,
1328+
scale_num_cols,
1329+
out_dtype: tl.constexpr,
1330+
SCALE_BLOCK_SIZE: tl.constexpr,
1331+
ROW_TILE_SIZE: tl.constexpr,
1332+
COL_TILE_SIZE: tl.constexpr,
1333+
):
1334+
pid_row = tl.program_id(0)
1335+
pid_col = tl.program_id(1)
1336+
SCALE_BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // SCALE_BLOCK_SIZE
1337+
1338+
# Load block of e4m3 data
1339+
row_offs = pid_row * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)
1340+
col_offs = pid_col * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)
1341+
block_offs = row_offs[:, None] * input_num_cols + col_offs[None, :]
1342+
mask = (row_offs[:, None] < input_num_rows) & (
1343+
col_offs[None, :] < input_num_cols
1344+
)
1345+
e4m3_data_block = tl.load(e4m3_data + block_offs, mask=mask)
1346+
1347+
# Load block of e8m0 scales
1348+
scale_col_offs = pid_col * SCALE_BLOCKS_PER_COL_TILE + tl.arange(
1349+
0, SCALE_BLOCKS_PER_COL_TILE
1350+
)
1351+
scale_block_offs = row_offs[:, None] * scale_num_cols + scale_col_offs[None, :]
1352+
scale_mask = (row_offs[:, None] < scale_num_rows) & (
1353+
scale_col_offs[None, :] < scale_num_cols
1354+
)
1355+
e8m0_scale_block = tl.load(e8m0_scales + scale_block_offs, mask=scale_mask)
1356+
1357+
# Dequantize and return output
1358+
e4m3_data_block_r = e4m3_data_block.reshape(
1359+
ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE
1360+
)
1361+
e8m0_scale_block_r = e8m0_scale_block.reshape(
1362+
ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, 1
1363+
)
1364+
fp32_scale = _e8m0_to_fp32(e8m0_scale_block_r)
1365+
data_hp = e4m3_data_block_r.to(tl.float32) * fp32_scale
1366+
1367+
# Write to output buffer
1368+
out_buffer_block = data_hp.to(out_dtype)
1369+
out_buffer_block = out_buffer_block.reshape(ROW_TILE_SIZE, COL_TILE_SIZE)
1370+
tl.store(out_buffer + block_offs, out_buffer_block, mask=mask)
1371+
1372+
@triton.jit
1373+
def _e8m0_to_fp32(scale_e8m0):
1374+
e8m0_exponent_bias = 127
1375+
e8m0_nan_val = 255
1376+
s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
1377+
s_fp = tl.exp2(s_offset.to(tl.float32))
1378+
s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
1379+
return s_fp.to(tl.float32)
1380+
12801381
@triton.jit
12811382
def triton_scale_swizzle(
12821383
scale_ptr,
@@ -1641,6 +1742,14 @@ def triton_quantize_nvfp4(
16411742
) -> Tuple[torch.Tensor, torch.Tensor]:
16421743
raise AssertionError("needs torch version 2.8+ and triton")
16431744

1745+
def triton_mxfp8_dequant_dim0(
1746+
e4m3_data: torch.Tensor,
1747+
e8m0_scales: torch.Tensor,
1748+
out_dtype: torch.dtype,
1749+
inner_block_size=32,
1750+
) -> torch.Tensor:
1751+
raise AssertionError("needs torch version 2.8+ and triton")
1752+
16441753

16451754
mxfp8_cuda_extension_available = False
16461755
if is_sm_at_least_100():
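For reference, _e8m0_to_fp32 above decodes a raw e8m0 byte as 2 ** (raw - 127), with the raw value 255 reserved for NaN. An equivalent eager-PyTorch sketch (illustrative only, not part of the commit):

import torch

def e8m0_to_fp32_reference(raw_e8m0: torch.Tensor) -> torch.Tensor:
    # e8m0 stores only a biased exponent: value = 2 ** (raw - 127); raw == 255 encodes NaN
    s = torch.exp2(raw_e8m0.to(torch.float32) - 127.0)
    return torch.where(raw_e8m0 == 255, torch.full_like(s, float("nan")), s)

raw = torch.tensor([126, 127, 130, 255], dtype=torch.uint8)
print(e8m0_to_fp32_reference(raw))  # tensor([0.5000, 1.0000, 8.0000, nan])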
