@@ -187,13 +187,53 @@ def get_tensor_memory_traffic_ovhd_s(
         else:
             assert False, "unsupported"
 
+    elif mx_recipe_name == "mxfp8_32x32_flexible_gemm_layout":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 everywhere, so the format makes sense
+        #    across dim0 and dim1
+        # 2. mxfp8 gemm with TN, NT, TT, NN formats supported (not in
+        #    PyTorch right now)
+        # x_bf16 = ...
+        # kernel 1:               x_bf16 -> x_mxfp8_dim0
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+        res_bytes = [kernel_1_rw]
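+        # Worked example (assuming BYTES_PER_EL_BF16 = 2 and
+        # BYTES_PER_EL_FLOAT8 = 1): unfused traffic is 3 * numel bytes
+        # (bf16 read + fp8 write) and fused traffic is 1 * numel bytes
+        # (fp8 write only), e.g. a 16384x16384 tensor models
+        # 3 * 2**28 bytes = 768 MiB unfused.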
+
+    elif mx_recipe_name == "mxfp8_32x32_weight":
+        # modeling the following:
+        # 1. mxfp8 scaling with 32x32 weights, so the format makes sense
+        #    across dim0 and dim1. input and grad_output still 1x32.
+
+        if tensor_role in ("input", "grad_output"):
+            # TODO(future): update all of the mx rooflines to just read once
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_bf16 -> x_mxfp8_dim1
+            if fuse_with_prev:
+                kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+            else:
+                kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        elif tensor_role == "weight":
+            # kernel 1: x_bf16 -> x_mxfp8_dim0
+            # kernel 2: x_mxfp8_dim0 -> x_mxfp8_dim1
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+            kernel_2_rw = BYTES_PER_EL_FLOAT8 * numel * 2
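+            # the 2x here models an fp8 read of the dim0 copy plus an
+            # fp8 write of the dim1 copy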
+
+        else:
+            assert False, "unsupported"
+
+        res_bytes = [kernel_1_rw, kernel_2_rw]
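+        # Net effect in the unfused case (assuming the byte-size constants
+        # above): the weight path models 3 * numel + 2 * numel = 5 * numel
+        # bytes, vs 3 * numel + 3 * numel = 6 * numel for input/grad_output,
+        # since kernel 2 casts from the already-quantized fp8 dim0 copy
+        # instead of re-reading the bf16 tensor.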
+
     else:
         assert mx_recipe_name in (
             "mxfp8_emulated",
             "mxfp8_cublas",
             "mxfp8_cublas_rceil",
             "mxfp4_cutlass",
-        ), "unsupported"
+        ), f"unsupported {mx_recipe_name=}"
         # For now, assume that we can't profitably fuse kernel 1 and kernel 2
         # x_bf16 = ...
         # kernel 1:               x_bf16 -> x_mxfp8_dim0