     _is_fbgemm_gpu_genai_available,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    is_sm_at_least_100,
     torch_version_at_least,
 )
 
@@ -49,6 +50,28 @@ def forward(self, x):
         return x
 
 
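+# Minimal N-d conv wrapper (Conv1d/Conv2d/Conv3d); the fp8 conv test below only exercises the 3D variant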
+class ToyConvModel(torch.nn.Module):
+    def __init__(
+        self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device
+    ):
+        super().__init__()
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        self.conv = convs[dim](
+            in_channels,
+            out_channels,
+            kernel_size,
+            bias=bias,
+            padding=padding,
+            dtype=dtype,
+            device=device,
+        )
+        if dim == 3:
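+            # convert Conv3d weights to channels_last_3d so they match the
+            # memory format used for the input in the fp8 conv test below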
+            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
 # TODO: move tests in test_affine_quantized_float.py here after we migrated all implementations
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -148,6 +171,90 @@ def test_fp8_linear_variants(
                 f"Quantization error is too high got a SQNR of {error}"
             )
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0"
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("granularity", [PerTensor()])
+    @common_utils.parametrize(
+        "kernel_preference",
+        [KernelPreference.AUTO],
+    )
+    # only test 3D conv for now
+    # sizes is (N, C_in, C_out, D, H, W); the input tensor is (N, C_in, D, H, W)
+    @common_utils.parametrize(
+        "sizes",
+        [
+            (4, 16, 64, 32, 32, 32),
+        ],
+    )
+    def test_fp8_conv_variants(
+        self,
+        dtype: torch.dtype,
+        compile: bool,
+        granularity,
+        kernel_preference: KernelPreference,
+        sizes: Tuple,
+    ):
+        if (
+            isinstance(granularity, PerTensor)
+            and kernel_preference == KernelPreference.FBGEMM
+        ):
+            self.skipTest(
+                "per tensor with fbgemm kernel preference does not work yet"
+            )
+
+        if kernel_preference == KernelPreference.FBGEMM and (
+            (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_90())
+        ):
+            self.skipTest(
+                "Requires fbgemm_gpu_genai to run fbgemm kernel preference test"
+            )
+
+        dim = 3
+        N, C_in, C_out, D, H, W = sizes
+        kernel_size = 3
+
+        # Note: input is converted to channels-last (channels_last_3d) memory format
+        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
+        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+
+        # Create a conv model with the test dtype
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device="cuda",
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+            kernel_preference=kernel_preference,
+        )
+
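+        # quantize only the Conv3d module; filter_fn decides which modules quantize_ touches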
+        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+
+        if compile:
+            quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+        output_original = model(input_tensor)
+        output_quantized = quantized_model(input_tensor)
+
+        error = compute_error(output_original, output_quantized)
+        assert error > 20, (
+            f"Quantization error is too high, got a SQNR of {error}"
+        )
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),