@@ -73,7 +73,8 @@ def test_mse_observer_symmetric_scale_range():
7373
7474
7575def  test_mse_fp4 ():
76-     tensor  =  torch .arange (24 , dtype = torch .bfloat16 ).reshape ((4 , 6 )) /  24 
76+     module  =  torch .nn .Linear (6 , 4 )
77+     module .weight .data  =  torch .arange (24 , dtype = torch .bfloat16 ).reshape ((4 , 6 )) /  24 
7778
7879    weights  =  QuantizationArgs (
7980        num_bits = 4 ,
@@ -84,8 +85,15 @@ def test_mse_fp4():
8485    )
8586
8687    observer  =  weights .observer 
87-     observer  =  Observer .load_from_registry (observer , base_name = "weight" , args = weights )
88-     scale , zero_point  =  observer (tensor )
88+     observer  =  Observer .load_from_registry (
89+         observer , base_name = "weight" , args = weights , module = module 
90+     )
8991
90-     qdq_tensor  =  fake_quantize (tensor , scale , zero_point , weights )
91-     assert  torch .nn .functional .mse_loss (qdq_tensor , tensor ) <=  0.002 
92+     global_scale  =  observer .get_global_scale (module .weight )
93+     module .weight_global_scale  =  global_scale 
94+     scale , zero_point  =  observer (module .weight )
95+ 
96+     qdq_tensor  =  fake_quantize (
97+         module .weight , scale , zero_point , weights , global_scale = global_scale 
98+     )
99+     assert  torch .nn .functional .mse_loss (qdq_tensor , module .weight ) <=  0.002 
0 commit comments