Skip to content

Commit 0b33455

Browse files
committed
Check numeric result with reference output
1 parent a666afa commit 0b33455

File tree

2 files changed: +15 −7 lines changed

backends/qualcomm/tests/test_passes.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,7 @@ def test_mha_to_sha(self):
8080
mod = convert_linear_to_conv2d(LlamaAttention(0, args, True))
8181

8282
# Prepare inputs
83-
hidden_states = torch.randint(
84-
low=0,
85-
high=100,
86-
size=(args.max_batch_size, args.ar_len, args.dim),
87-
dtype=torch.float32,
88-
)
83+
hidden_states = torch.randn(args.max_batch_size, args.ar_len, args.dim)
8984
freqs_cos = torch.randn(args.ar_len, 1)
9085
freqs_sin = torch.randn(args.ar_len, 1)
9186
atten_mask = CausalAttentionMask(
@@ -113,6 +108,9 @@ def test_mha_to_sha(self):
113108
v_cache,
114109
)
115110

111+
# Run original module for reference
112+
refs = mod(*sample_input)
113+
116114
# Export the module and convert linear to conv2d
117115
edge_program = to_edge(torch.export.export(mod, sample_input))
118116
new_ep = edge_program.exported_program()
@@ -141,6 +139,16 @@ def test_mha_to_sha(self):
141139
# Check graph structure: WQ, WK, WV should be converted to SHA
142140
self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be split")
143141

142+
# Execute new graph and compare with reference
143+
outs = graph_module(
144+
*new_ep.state_dict.values(), *new_ep.constants.values(), *sample_input
145+
)
146+
for i, (out, ref) in enumerate(zip(outs, refs)):
147+
self.assertTrue(
148+
torch.allclose(out, *ref, rtol=1e-6, atol=1e-6),
149+
f"Output {i} mismatch: got {out}, expected {ref}",
150+
)
151+
144152

145153
if __name__ == "__main__":
146154
unittest.main()

backends/qualcomm/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def __init__(self, weight, bias=None):
157157

158158
def forward(self, x):
159159
rank = x.dim()
160-
x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1)
160+
x = x.reshape(*x.shape, 1) if rank == 3 else x.reshape(1, *x.shape, 1)
161161
x = torch.transpose(x, 1, 2)
162162
res = self.conv(x)
163163
res = torch.transpose(res, 1, 2)

0 commit comments

Comments (0)