@@ -287,16 +287,24 @@ class MockAttention(torch.nn.Module):
                 strategy="tensor",
             ),
             torch.tensor([0.0]),
-            torch.tensor([11.0]),
+            torch.tensor([23.0]),
             torch.tensor(
                 [
                     [
-                        [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]],
-                        [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
+                        [
+                            [0.0000, 0.0000, 3.0625, 3.0625],
+                            [3.0625, 6.1250, 6.1250, 6.1250],
+                            [9.1875, 9.1875, 9.1875, 12.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
                     ]
                 ]
             ),
-            0.19,
+            0.81,
         ),
         # static token is not supported
         # channel is not supported
@@ -310,35 +318,45 @@ class MockAttention(torch.nn.Module):
                 symmetric=True,
                 strategy="attn_head",
             ),
-            torch.tensor([[[0.0]], [[6.0]]]),
-            torch.tensor([[[5.0]], [[11.0]]]),
+            torch.tensor([[[0.0]], [[12.0]]]),
+            torch.tensor([[[11.0]], [[23.0]]]),
             torch.tensor(
                 [
                     [
-                        [[0.0000, 1.3359, 2.0000], [2.6719, 4.0000, 4.6875]],
-                        [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
+                        [
+                            [0.0000, 1.4688, 1.4688, 2.9375],
+                            [4.4062, 4.4062, 5.8750, 7.3438],
+                            [7.3438, 8.8125, 10.2500, 10.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
                     ]
                 ]
             ),
-            0.13,
+            0.55,
         ),
     ],
 )
 def test_static_attention_quantization(
     args, exp_min_val, exp_max_val, exp_quant, exp_loss
 ):
     """
-    input = tensor([[[[ 0.,  1.,  2.],
-                      [ 3.,  4.,  5.]],
+    input = tensor([[[[ 0.,  1.,  2.,  3.],
+                      [ 4.,  5.,  6.,  7.],
+                      [ 8.,  9., 10., 11.]],
 
-                      [[ 6.,  7.,  8.],
-                      [ 9., 10., 11.]]]])
+                     [[12., 13., 14., 15.],
+                      [16., 17., 18., 19.],
+                      [20., 21., 22., 23.]]]])
     """
-    # set up activation (and identity weight)
-    batch_size, num_heads, seq_len, head_dim = 1, 2, 2, 3
+    # set up attention
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
     input = torch.arange(
-        (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16
-    ).reshape((batch_size, seq_len, num_heads, head_dim))
+        (batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16
+    ).reshape((batch_size, num_heads, seq_len, head_dim))
     attention = MockAttention()
 
     # initialize quantization parameters
@@ -366,7 +384,5 @@ def test_static_attention_quantization(
         assert torch.equal(attention.k_observer.max_vals, exp_max_val)
 
     # check forward pass
-    print(output)
-    print(torch.nn.functional.mse_loss(output, input))
     assert torch.allclose(output, exp_quant.to(output.dtype))
     assert torch.nn.functional.mse_loss(output, input) <= exp_loss
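
For reference, the new exp_quant values above are consistent with 4-bit symmetric fake quantization computed in bfloat16, with the scale taken as max_abs / ((2**num_bits - 1) / 2). The sketch below is a minimal reproduction under that assumption only; the fake_quantize helper and its reduce_dims argument are hypothetical and not part of the library. Reducing over the whole tensor should reproduce the strategy="tensor" expectations, while reducing over every dim except the head dim (dim 1) should reproduce strategy="attn_head".

import torch


def fake_quantize(x: torch.Tensor, bits: int = 4, reduce_dims=None) -> torch.Tensor:
    # hypothetical helper, for illustration only
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1  # -8, 7 for 4 bits
    if reduce_dims is None:
        max_abs = x.abs().amax()  # one scale for the whole tensor ("tensor")
    else:
        max_abs = x.abs().amax(dim=reduce_dims, keepdim=True)  # one scale per head ("attn_head")
    scale = (max_abs / ((2 ** bits - 1) / 2)).to(torch.bfloat16)  # e.g. 23 / 7.5 -> ~3.0625 in bfloat16
    return (torch.clamp(torch.round(x / scale), qmin, qmax) * scale).to(torch.bfloat16)


x = torch.arange(1 * 2 * 3 * 4, dtype=torch.bfloat16).reshape(1, 2, 3, 4)
print(fake_quantize(x))                         # should match exp_quant for strategy="tensor"
print(fake_quantize(x, reduce_dims=(0, 2, 3)))  # should match exp_quant for strategy="attn_head"
print(torch.nn.functional.mse_loss(fake_quantize(x), x))  # roughly 0.79, under the exp_loss bound of 0.81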