Commit 7211104

extend mxfp8 roofline with more recipes (#3190)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent: 4b79f9e

2 files changed (+61, -3)

benchmarks/float8/float8_roofline.py (20 additions, 2 deletions)

@@ -180,7 +180,7 @@ def get_gemm_times(
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
     else:
-        assert False, "TODO add cutlass mx gemm here"
+        assert False, f"unsupported {float8_recipe_name=} {mx_recipe_name=}"
 
     def do_matmul(A, B):
         return torch._scaled_mm(
@@ -233,6 +233,20 @@ def run(
     print(f"mx_recipe_name: {mx_recipe_name}")
     print(f"enable_fusion_modeling: {enable_fusion_modeling}")
 
+    assert mx_recipe_name in (
+        # real mxfp8_cublas recipe
+        "mxfp8_cublas",
+        # real mxfp8_cublas_rceil recipe
+        "mxfp8_cublas_rceil",
+        # modeling of what mxfp8 with 32x32 block size and without gemm
+        # operand layout restrictions would look like
+        "mxfp8_32x32_flexible_gemm_layout",
+        # modeling of what mxfp8 with 32x32 block size for weight
+        "mxfp8_32x32_weight",
+        # real mxfp4_cutlass recipe
+        "mxfp4_cutlass",
+    ), f"unsupported {mx_recipe_name=}"
+
     M, K, N = sympy.symbols("M K N")
 
     fp8_ovhd_time_sympy = get_float8_mem_sympy(
@@ -309,7 +323,11 @@ def run(
         rb_fp8_gemm_ratio = -1
 
     if do_benchmarks:
-        assert mx_recipe_name != "mxfp4_cutlass", "unsupported"
+        assert mx_recipe_name not in (
+            "mxfp4_cutlass",
+            "mxfp8_32x32_flexible_gemm_layout",
+            "mxfp8_32x32_weight",
+        ), f"do_benchmarks unsupported with {mx_recipe_name=}"
 
         # TODO(future): make the bf16 gemm times exactly match the e2e
         # benchmarks, there is a slight deviation, probably related to gemm
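
For context on the new guard in run(): the two 32x32 recipes are analytical-modeling-only, and mxfp4_cutlass has no gemm benchmark path in this script, so do_benchmarks is rejected for all three. A minimal sketch of that gating, with constant and function names that are illustrative only (they do not exist in the file):

    # Illustrative sketch only: mirrors the two assertions added to run() above.
    SUPPORTED_MX_RECIPES = (
        "mxfp8_cublas",
        "mxfp8_cublas_rceil",
        "mxfp8_32x32_flexible_gemm_layout",
        "mxfp8_32x32_weight",
        "mxfp4_cutlass",
    )
    # recipes that can be roofline-modeled but not benchmarked by this script
    NON_BENCHMARKABLE_MX_RECIPES = (
        "mxfp4_cutlass",
        "mxfp8_32x32_flexible_gemm_layout",
        "mxfp8_32x32_weight",
    )

    def check_mx_recipe(mx_recipe_name: str, do_benchmarks: bool) -> None:
        assert mx_recipe_name in SUPPORTED_MX_RECIPES, f"unsupported {mx_recipe_name=}"
        if do_benchmarks:
            assert mx_recipe_name not in NON_BENCHMARKABLE_MX_RECIPES, (
                f"do_benchmarks unsupported with {mx_recipe_name=}"
            )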

torchao/testing/training/roofline_utils.py (41 additions, 1 deletion)

@@ -187,13 +187,53 @@ def get_tensor_memory_traffic_ovhd_s(
         else:
             assert False, "unsupported"
 
+    elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense
+        #    across dim0 and dim1
+        # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in
+        #    PyTorch right now)
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw]
+
+    elif mx_recipe_name == "mxfp8_32x32_weight":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 weights, so the format makes sense
+        #    across dim0 and dim1. input and grad_output still 1x32.
+
+        if tensor_role in ("input", "grad_output"):
+            # TODO(future): update all of the mx rooflines to just read once
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_bf16 -> x_mxfp8_dim1
+            if fuse_with_prev:
+                kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+            else:
+                kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        elif tensor_role == "weight":
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
+
+        else:
+            assert False, "unsupported"
+
+        res_bytes = [kernel_1_rw, kernel_2_rw]
+
     else:
         assert mx_recipe_name in (
             "mxfp8_emulated",
             "mxfp8_cublas",
             "mxfp8_cublas_rceil",
             "mxfp4_cutlass",
-        ), "unsupported"
+        ), f"unsupported {mx_recipe_name=}"
         # For now, assume that we can't profitably fuse kernel 1 and kernel 2
         # x_bf16 = ...
         # kernel 1: x_bf16 -> x_mxfp8_dim0
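
To make the new traffic model concrete, here is a small worked sketch of the mxfp8_32x32_weight byte counts for the weight tensor, assuming BYTES_PER_EL_BF16 = 2 and BYTES_PER_EL_FLOAT8 = 1; scale traffic and the conversion from bytes to seconds via memory bandwidth are not shown, and the helper name is illustrative only:

    # Sketch only: reproduces the arithmetic from the diff above for the
    # un-fused weight path; constants assumed, helper name not from the file.
    BYTES_PER_EL_BF16 = 2
    BYTES_PER_EL_FLOAT8 = 1

    def mxfp8_32x32_weight_traffic_bytes(numel: int) -> int:
        # kernel 1: x_bf16 -> x_mxfp8_dim0 (read bf16, write fp8)
        kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
        # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1 (read fp8, write fp8)
        kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
        return kernel_1_rw + kernel_2_rw

    # e.g. a 4096x4096 weight: (2+1)*numel + 2*numel = 5*numel bytes = 80 MiB
    print(mxfp8_32x32_weight_traffic_bytes(4096 * 4096))

By comparison, the mxfp8_32x32_flexible_gemm_layout branch above models only kernel 1 (3*numel bytes un-fused, or numel bytes when fused with the previous op), since a 32x32 block format is valid across both dim0 and dim1 and no dim1 re-cast is needed.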
