
Commit 357d20f

[mxfp8 moe training] add triton kernel for mxfp8 dequantization
stack-info: PR: #3195, branch: danielvegamyhre/stack/78
1 parent b644211 commit 357d20f
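
The kernel dequantizes mxfp8 tensors (float8_e4m3fn data plus float8_e8m0fnu block scales, block size 32) back to a high-precision dtype along dim0. A minimal round-trip sketch, assuming only the call signatures that appear in the diffs below (to_mx returning scales then data; triton_mxfp8_dequant_dim0 taking data, scales, target dtype, block size); the tests skip below CUDA capability 10.0:

import torch
from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0

block_size = 32
x = torch.randn(8192, 7168, dtype=torch.bfloat16, device="cuda")

# Quantize to mxfp8: e8m0 block scales + e4m3 data (same call the benchmark below uses)
e8m0_scales, e4m3_data = to_mx(x, torch.float8_e4m3fn, block_size)

# Dequantize with the new triton kernel and compare against the existing to_dtype reference
x_triton = triton_mxfp8_dequant_dim0(e4m3_data, e8m0_scales, torch.bfloat16, block_size)
x_ref = to_dtype(e4m3_data, e8m0_scales, torch.float8_e4m3fn, block_size, torch.bfloat16)
torch.testing.assert_close(x_triton, x_ref, rtol=0, atol=0)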

File tree: 6 files changed, +333 -34 lines


benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

Lines changed: 0 additions & 1 deletion
@@ -84,7 +84,6 @@ def default_a2a_fwd_bwd(
 
     loss = F.mse_loss(routed_input, labels)
     loss.backward()
-
     torch.cuda.synchronize()
     return routed_input

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0
from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int]


@dataclass(frozen=True)
class ExperimentResult:
    # time
    torch_us: float
    triton_us: float
    torch_gbps: float
    triton_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    input_shapes = [
        # (local_batch_size, seq_len, dim)
        (1, 8192, 7168),
        (2, 8192, 7168),
        (4, 8192, 7168),
        (8, 8192, 7168),
    ]
    configs = []
    for shape in input_shapes:
        configs.append(
            ExperimentConfig(
                input_shape=shape,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    block_size = 32
    input_shape = config.input_shape
    input_tensor = torch.randn(
        *input_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    e8m0_scales, e4m3_data = to_mx(input_tensor, torch.float8_e4m3fn, block_size)

    # Bench torch dequant
    to_dtype_c = torch.compile(to_dtype)
    elem_dtype, target_dtype = torch.float8_e4m3fn, torch.bfloat16
    torch_output = to_dtype_c(
        e4m3_data,
        e8m0_scales,
        elem_dtype,
        block_size,
        target_dtype,
    )
    torch_us = benchmark_cuda_function_in_microseconds(
        to_dtype_c,
        e4m3_data,
        e8m0_scales,
        elem_dtype,
        block_size,
        target_dtype,
    )

    # Bench triton kernel
    _ = triton_mxfp8_dequant_dim0(
        e4m3_data,
        e8m0_scales,
        target_dtype,
        block_size,
    )
    triton_us = benchmark_cuda_function_in_microseconds(
        triton_mxfp8_dequant_dim0,
        e4m3_data,
        e8m0_scales,
        target_dtype,
        block_size,
    )

    # mem bw calculations
    bytes_per_input_el = torch.finfo(elem_dtype).bits / 8
    bytes_per_output_el = torch.finfo(target_dtype).bits / 8
    bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

    read_bytes = (
        e4m3_data.numel() * bytes_per_input_el
        + e8m0_scales.numel() * bytes_per_scale_el
    )
    write_bytes = torch_output.numel() * bytes_per_output_el

    torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_us / 1e6)
    triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_us / 1e6)

    return ExperimentResult(
        torch_us=torch_us,
        triton_us=triton_us,
        triton_gbps=triton_gbps,
        torch_gbps=torch_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "torch_us",
        "triton_us",
        "torch_gbps",
        "triton_gbps",
        "triton_speedup",
    ]
    rows = []
    for experiment in experiments:
        triton_speedup = round(
            experiment.result.torch_us / experiment.result.triton_us, 3
        )
        rows.append(
            [
                str(experiment.config.input_shape),
                experiment.result.torch_us,
                experiment.result.triton_us,
                round(experiment.result.torch_gbps, 3),
                round(experiment.result.triton_gbps, 3),
                f"{triton_speedup}x",
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
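
As a sanity check on the bandwidth arithmetic above: for the smallest benchmarked shape (1, 8192, 7168), the kernel reads one byte per e4m3 element plus one byte per e8m0 scale (one scale per 32 elements) and writes two bytes per bf16 output element. A worked sketch, where the 50 µs timing is purely illustrative and not a measured result:

M, K, block_size = 8192, 7168, 32   # (1, 8192, 7168) flattened over the leading dims

read_bytes = M * K * 1 + (M * K // block_size) * 1   # e4m3 data + e8m0 scales
write_bytes = M * K * 2                              # bf16 output
total_bytes = read_bytes + write_bytes               # 177,995,776 bytes ~= 0.178 GB

elapsed_us = 50.0                                    # illustrative timing only
gbps = (total_bytes / 1e9) / (elapsed_us / 1e6)      # ~= 3560 GB/s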

test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py

Lines changed: 2 additions & 2 deletions
@@ -23,8 +23,8 @@
     compute_error,
 )
 from torchao.prototype.moe_training.kernels.mxfp8.comms import (
-    mxfp8_on_device_all_to_all_v,
     to_mxfp8_a2a_dequant,
+    to_mxfp8_on_device_a2a_dequant,
 )
 
 from ..testing_utils import generate_split_sizes
@@ -88,7 +88,7 @@ def test_a2a_fwd_bwd(self):
         max_output_tokens_per_rank = tokens_per_ep_rank * self.world_size
 
         # Test forward
-        output, output_splits = mxfp8_on_device_all_to_all_v(
+        output, output_splits = to_mxfp8_on_device_a2a_dequant(
            input_tensor,
            input_splits,
            max_output_tokens_per_rank,

test/prototype/mx_formats/test_kernels.py

Lines changed: 24 additions & 1 deletion
@@ -37,12 +37,13 @@
     pack_uint6,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
+    triton_mxfp8_dequant_dim0,
     triton_to_mxfp8_dim0,
     triton_to_mxfp8_dim1,
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
-from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
+from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     is_sm_at_least_89,
@@ -513,6 +514,28 @@ def test_triton_mxfp8_dim0_zeros():
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
 
 
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="mxfp8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.parametrize("M", (256, 2048, 131072))
+@pytest.mark.parametrize("K", (256, 5120, 7168))
+def test_triton_mxfp8_dequant_dim0(M, K):
+    x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
+    block_size = 32
+    x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
+    hp_ref = to_dtype(
+        x_data,
+        x_scales,
+        torch.float8_e4m3fn,
+        block_size,
+        torch.bfloat16,
+    )
+    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
+    torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "shape",

torchao/prototype/moe_training/kernels/mxfp8/comms.py

Lines changed: 17 additions & 19 deletions
@@ -11,8 +11,14 @@
     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
-from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
+from torchao.prototype.mx_formats.kernels import (
+    triton_mxfp8_dequant_dim0,
+    triton_to_mxfp8_dim0,
+)
+from torchao.prototype.mx_formats.mx_tensor import (
+    to_dtype,
+    to_mx,
+)
 
 
 # This performs dynamic mxfp8 quantization of the input tensor,
@@ -256,7 +262,7 @@ def backward(ctx, grad_output, grad_splits):
 
 
 # Alias
-mxfp8_on_device_all_to_all_v = MXFP8OnDeviceAllToAllV.apply
+to_mxfp8_on_device_a2a_dequant = MXFP8OnDeviceAllToAllV.apply
 
 
 # Triton launcher function
@@ -473,11 +479,9 @@ def forward(
         """
         # Quantize input
         block_size = 32
-        input_scales, input_data = to_mx(
+        input_data, input_scales = triton_to_mxfp8_dim0(
             input,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
@@ -501,14 +505,12 @@ def forward(
         output_data = torch.ops._c10d_functional.wait_tensor(output_data)
 
         # Dequantize output
-        lowp_dtype = output_data.dtype
         hp_dtype = input.dtype
-        hp_output = to_dtype(
+        hp_output = triton_mxfp8_dequant_dim0(
             output_data,
             output_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
             hp_dtype,
+            block_size,
         )
 
         ctx.input_splits = input_splits
@@ -529,11 +531,9 @@ def backward(ctx, grad_output_hp):
 
         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
        )
 
        # Dispatch data (async)
@@ -557,13 +557,11 @@ def backward(ctx, grad_output_hp):
         grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
 
         hp_dtype = grad_output_hp.dtype
-        lowp_dtype = grad_input_data.dtype
-        grad_input_hp = to_dtype(
+        grad_input_hp = triton_mxfp8_dequant_dim0(
             grad_input_data,
             grad_input_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
             hp_dtype,
+            block_size,
         )
         return grad_input_hp, None, None, None
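
Putting the comms changes together: the forward path now quantizes with triton_to_mxfp8_dim0, which returns (data, scales) where to_mx returned (scales, data), and dequantizes with triton_mxfp8_dequant_dim0, which takes the target dtype before the block size, unlike to_dtype. A condensed sketch of that pattern, with the on-device all-to-all dispatch elided and variable names taken from the hunks above:

block_size = 32

# Quantize the high-precision input (data first, then e8m0 scales)
input_data, input_scales = triton_to_mxfp8_dim0(input, inner_block_size=block_size)

# ... async on-device all-to-all dispatch of input_data / input_scales (unchanged, elided) ...

# Dequantize the received tokens back to the input's dtype with the new kernel
hp_output = triton_mxfp8_dequant_dim0(
    output_data,
    output_scales.view(torch.float8_e8m0fnu),
    input.dtype,
    block_size,
)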
