     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.kernels import (
+    triton_mxfp8_dequant_dim0,
+    triton_to_mxfp8_dim0,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
 
 
@@ -473,11 +476,9 @@ def forward(
         """
         # Quantize input
         block_size = 32
-        input_scales, input_data = to_mx(
+        input_data, input_scales = triton_to_mxfp8_dim0(
             input,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
@@ -501,20 +502,17 @@ def forward(
         output_data = torch.ops._c10d_functional.wait_tensor(output_data)
 
         # Dequantize output
-        lowp_dtype = output_data.dtype
         hp_dtype = input.dtype
-        hp_output = to_dtype(
+        triton_hp_output = triton_mxfp8_dequant_dim0(
             output_data,
-            output_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            output_scales,
             hp_dtype,
+            block_size,
         )
-
         ctx.input_splits = input_splits
         ctx.output_splits = output_splits
         ctx.group = group
-        return hp_output
+        return triton_hp_output
 
     @staticmethod
     def backward(ctx, grad_output_hp):
@@ -529,11 +527,9 @@ def backward(ctx, grad_output_hp):
 
         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
        )
 
         # Dispatch data (async)
@@ -557,13 +553,11 @@ def backward(ctx, grad_output_hp):
         grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
 
         hp_dtype = grad_output_hp.dtype
-        lowp_dtype = grad_input_data.dtype
-        grad_input_hp = to_dtype(
+        grad_input_hp = triton_mxfp8_dequant_dim0(
             grad_input_data,
-            grad_input_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            grad_input_scales,
             hp_dtype,
+            block_size,
         )
         return grad_input_hp, None, None, None
 
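For reference, a minimal sketch of the quantize/dequantize round trip this change moves to, with the kernel signatures inferred from the call sites above; the device, dtype, and shapes are assumptions (the triton_* kernels are Triton GPU kernels, so a CUDA tensor is assumed):

    import torch

    from torchao.prototype.mx_formats.kernels import (
        triton_mxfp8_dequant_dim0,
        triton_to_mxfp8_dim0,
    )

    block_size = 32
    # Assumed example input: a 2D tensor whose row length is a multiple of 32,
    # on a CUDA device.
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

    # Quantize with the dim0 kernel. Note the (data, scales) return order,
    # the reverse of to_mx's (scales, data).
    x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=block_size)

    # Dequantize back to the original high-precision dtype. The scales are
    # passed through as returned, without the .view(torch.float8_e8m0fnu)
    # cast that the to_dtype path required.
    x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, x.dtype, block_size)

    # Sanity check, expected to hold under the assumptions above.
    assert x_hp.shape == x.shape and x_hp.dtype == x.dtype

The two call-site differences the diff accounts for: the quantize kernel's return order is swapped relative to to_mx, and the dequantize kernel takes the raw scales followed by (hp_dtype, block_size) rather than (scales viewed as float8_e8m0fnu, lowp_dtype, block_size).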