@@ -492,8 +492,9 @@ def test_triton_mxfp8_dim1_randn(M, K):
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dim0_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dim0_randn(M, K, orig_dtype):
+    x = torch.randn(M, K, dtype=orig_dtype, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -521,18 +522,19 @@ def test_triton_mxfp8_dim0_zeros():
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dequant_dim0(M, K):
-    x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
+    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
         x_data,
         x_scales,
         torch.float8_e4m3fn,
         block_size,
-        torch.bfloat16,
+        orig_dtype,
     )
-    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
+    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, block_size)
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)


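For context: both tests are now parametrized over orig_dtype, so the quantize and dequantize kernels are exercised with float32 as well as bfloat16 inputs, and the dequant path must honor the caller-supplied dtype rather than hard-coding bfloat16. A minimal standalone sketch of the round trip these tests cover (the import path is an assumption; the diff does not show this file's imports):

import torch

# Assumed import location for the kernels named in the diff; adjust to
# wherever this test file actually imports them from.
from torchao.prototype.mx_formats.kernels import (
    triton_to_mxfp8_dim0,
    triton_mxfp8_dequant_dim0,
)

def mxfp8_roundtrip(M=256, K=256, orig_dtype=torch.float32, block_size=32):
    # Quantize: data is stored as float8_e4m3fn with one shared scale per
    # `block_size` contiguous elements along the inner dimension.
    x = torch.randn(M, K, dtype=orig_dtype, device="cuda")
    x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=block_size)
    # Dequantize back to the original dtype; with the new orig_dtype
    # parametrization this is now checked for float32 and bfloat16.
    x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, block_size)
    assert x_hp.dtype == orig_dtype and x_hp.shape == x.shape
    return x, x_hp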