@@ -1141,7 +1141,8 @@ def triton_to_mxfp8_dim0(
11411141        * `scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0 
11421142        """ 
11431143        assert  x .is_contiguous (), "`x` must be contiguous" 
1144-         assert  inner_block_size  <=  32 
1144+         assert  inner_block_size  <=  32 , "inner_block_size must be <= 32" 
1145+         assert  x .dtype  ==  torch .bfloat16 , "only bfloat16 inputs are supported" 
11451146
11461147        # Reshape tensor to 2d if necessary and get shape 
11471148        x_orig_shape  =  x .shape 
@@ -1279,12 +1280,13 @@ def triton_to_mxfp8_dim1_reference(
12791280            scale_e8m0_dim1 ,
12801281        )
12811282
1283+     @triton_op ("torchao::triton_mxfp8_dequant_dim0" , mutates_args = {}) 
12821284    def  triton_mxfp8_dequant_dim0 (
12831285        e4m3_data : torch .Tensor ,
12841286        e8m0_scales : torch .Tensor ,
12851287        out_dtype : torch .dtype ,
12861288        scale_block_size : int  =  32 ,
1287-     ) ->  None :
1289+     ) ->  torch . Tensor :
12881290        assert  scale_block_size  ==  32 , "scale_block_size must be 32 for now" 
12891291        assert  out_dtype  in  (torch .bfloat16 , torch .float32 ), (
12901292            "out_dtype must be bf16 or fp32" 
@@ -1300,7 +1302,7 @@ def triton_mxfp8_dequant_dim0(
13001302            triton .cdiv (e4m3_data .shape [0 ], META ["ROW_TILE_SIZE" ]),
13011303            triton .cdiv (e4m3_data .shape [1 ], META ["COL_TILE_SIZE" ]),
13021304        )
1303-         _dequant_mxfp8_kernel [grid ](
1305+         wrap_triton ( _dequant_mxfp8_kernel ) [grid ](
13041306            e4m3_data ,
13051307            e8m0_scales .to (torch .uint8 ),
13061308            out_buffer ,
@@ -1371,8 +1373,8 @@ def _dequant_mxfp8_kernel(
13711373
13721374    @triton .jit  
13731375    def  _e8m0_to_fp32 (scale_e8m0 ):
1374-         e8m0_exponent_bias  =  127 
13751376        e8m0_nan_val  =  255 
1377+         e8m0_exponent_bias  =  127 
13761378        s_offset  =  scale_e8m0 .to (tl .int16 ) -  e8m0_exponent_bias 
13771379        s_fp  =  tl .exp2 (s_offset .to (tl .float32 ))
13781380        s_fp  =  tl .where (scale_e8m0  !=  e8m0_nan_val , s_fp , float ("nan" ))
0 commit comments