@@ -80,12 +80,7 @@ def test_mha_to_sha(self):
         mod = convert_linear_to_conv2d(LlamaAttention(0, args, True))

         # Prepare inputs
-        hidden_states = torch.randint(
-            low=0,
-            high=100,
-            size=(args.max_batch_size, args.ar_len, args.dim),
-            dtype=torch.float32,
-        )
+        hidden_states = torch.randn(args.max_batch_size, args.ar_len, args.dim)
         freqs_cos = torch.randn(args.ar_len, 1)
         freqs_sin = torch.randn(args.ar_len, 1)
         atten_mask = CausalAttentionMask(
@@ -113,6 +108,9 @@ def test_mha_to_sha(self):
             v_cache,
         )

+        # Run original module for reference
+        refs = mod(*sample_input)
+
         # Export the module and convert linear to conv2d
         edge_program = to_edge(torch.export.export(mod, sample_input))
         new_ep = edge_program.exported_program()
@@ -141,6 +139,16 @@ def test_mha_to_sha(self):
         # Check graph structure: WQ, WK, WV should be converted to SHA
         self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be split")

+        # Execute new graph and compare with reference
+        outs = graph_module(
+            *new_ep.state_dict.values(), *new_ep.constants.values(), *sample_input
+        )
+        for i, (out, ref) in enumerate(zip(outs, refs)):
+            self.assertTrue(
+                torch.allclose(out, ref, rtol=1e-6, atol=1e-6),
+                f"Output {i} mismatch: got {out}, expected {ref}",
+            )
+

 if __name__ == "__main__":
     unittest.main()
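
Side note (not part of the diff): the new assertions call graph_module with *new_ep.state_dict.values() and *new_ep.constants.values() prepended to sample_input because torch.export lifts parameters, buffers, and tensor constants into explicit graph inputs, so the raw graph module expects them ahead of the user inputs and returns a flat tuple of outputs. The sketch below illustrates that calling convention on a toy module; Toy and its shapes are made up for illustration, and it assumes a recent PyTorch where torch.export.export, ExportedProgram.module(), and ExportedProgram.graph_module are available.

    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.proj(x)

    mod = Toy()
    x = torch.randn(2, 4)
    ep = torch.export.export(mod, (x,))

    # ep.module() re-packages the program so it takes only the user inputs again.
    assert torch.allclose(mod(x), ep.module()(x), rtol=1e-6, atol=1e-6)

    # The underlying graph module expects the lifted parameters first (mirroring
    # graph_module(*new_ep.state_dict.values(), ..., *sample_input) in the test)
    # and returns a flat tuple of outputs, hence the [0].
    raw_outs = ep.graph_module(*ep.state_dict.values(), x)
    assert torch.allclose(raw_outs[0], mod(x), rtol=1e-6, atol=1e-6)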