@@ -1141,7 +1141,8 @@ def triton_to_mxfp8_dim0(
     * `scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
     """
     assert x.is_contiguous(), "`x` must be contiguous"
-    assert inner_block_size <= 32
+    assert inner_block_size <= 32, "inner_block_size must be <= 32"
+    assert x.dtype == torch.bfloat16, "only bfloat16 inputs are supported"
 
     # Reshape tensor to 2d if necessary and get shape
     x_orig_shape = x.shape
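
A minimal usage sketch of the tightened input contract (the import path is a guess at where these prototype kernels live, and the (data, scale) return pair is inferred from the docstring above rather than shown in this diff):

import torch
# Assumed location of the prototype mxfp8 Triton kernels; adjust to the actual module.
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

# The new assertions require a contiguous bfloat16 input.
x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
data_e4m3, scale_e8m0 = triton_to_mxfp8_dim0(x, inner_block_size=32)

# A float32 input now fails fast with a descriptive message:
# AssertionError: only bfloat16 inputs are supported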
@@ -1279,12 +1280,13 @@ def triton_to_mxfp8_dim1_reference(
         scale_e8m0_dim1,
     )
 
+@triton_op("torchao::triton_mxfp8_dequant_dim0", mutates_args={})
 def triton_mxfp8_dequant_dim0(
     e4m3_data: torch.Tensor,
     e8m0_scales: torch.Tensor,
     out_dtype: torch.dtype,
     scale_block_size: int = 32,
-) -> None:
+) -> torch.Tensor:
     assert scale_block_size == 32, "scale_block_size must be 32 for now"
     assert out_dtype in (torch.bfloat16, torch.float32), (
         "out_dtype must be bf16 or fp32"
@@ -1300,7 +1302,7 @@ def triton_mxfp8_dequant_dim0(
         triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
         triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
     )
-    _dequant_mxfp8_kernel[grid](
+    wrap_triton(_dequant_mxfp8_kernel)[grid](
         e4m3_data,
         e8m0_scales.to(torch.uint8),
         out_buffer,
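
A hedged sketch of what the `@triton_op` registration and `wrap_triton` wrapper above enable: the dequant routine becomes a custom op in the `torchao` namespace, so it can be called through the dispatcher and traced by `torch.compile` without treating the Triton launch as an opaque call. The tensors below are assumed to come from `triton_to_mxfp8_dim0`, and `torch.library.triton_op`/`wrap_triton` require a recent PyTorch build:

import torch

# data_e4m3: float8_e4m3fn payload, scale_e8m0: matching e8m0 scales (from the dim0 cast).
hp = torch.ops.torchao.triton_mxfp8_dequant_dim0(
    data_e4m3, scale_e8m0, torch.bfloat16, 32
)

# The wrapped kernel launch also composes with torch.compile:
dequant_compiled = torch.compile(
    lambda d, s: torch.ops.torchao.triton_mxfp8_dequant_dim0(d, s, torch.bfloat16, 32)
)
hp_compiled = dequant_compiled(data_e4m3, scale_e8m0)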
@@ -1371,8 +1373,8 @@ def _dequant_mxfp8_kernel(
 
 @triton.jit
 def _e8m0_to_fp32(scale_e8m0):
-    e8m0_exponent_bias = 127
     e8m0_nan_val = 255
+    e8m0_exponent_bias = 127
     s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
     s_fp = tl.exp2(s_offset.to(tl.float32))
     s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
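
For reference, the reordered constants above leave the decode unchanged: an e8m0 byte is an unsigned biased exponent, so the scale is 2**(e8m0 - 127), with the all-ones encoding 255 reserved for NaN. A plain-PyTorch equivalent, written here only as an illustrative cross-check and not part of this diff:

import torch

def e8m0_to_fp32_ref(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Mirrors _e8m0_to_fp32: 2**(e8m0 - 127), mapping the NaN encoding (255) to NaN.
    e8m0_nan_val = 255
    e8m0_exponent_bias = 127
    s_fp = torch.exp2(scale_e8m0.to(torch.int16).float() - e8m0_exponent_bias)
    return torch.where(scale_e8m0 != e8m0_nan_val, s_fp, torch.full_like(s_fp, float("nan")))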