Commit e3f24d4

fix shapes
Signed-off-by: Kyle Sayers <[email protected]>
1 parent bf1b9ba commit e3f24d4

File tree (2 files changed, +42 -21 lines):
  tests/mock_observer.py
  tests/test_quantization/lifecycle/test_static_lifecycle.py

tests/mock_observer.py
Lines changed: 7 additions & 2 deletions

@@ -77,6 +77,8 @@ def flatten_for_quantization(
 
 
 def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (num_rows, num_cols)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (1, 1, num_weight_elems)
         return value.reshape((1, 1, -1))

@@ -117,6 +119,8 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs)
 
 
 def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, seq_len, hidden_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
         # (batch_size * seq_len, 1, hidden_dim)
         return value.reshape((-1, 1, value.size(-1)))

@@ -144,10 +148,11 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationA
 
 
 def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+    # value.shape = (batch_size, num_heads, seq_len, head_dim)
+
     if args.strategy == QuantizationStrategy.TENSOR:
-        # (batch_size, seq_len, num_heads, head_dim)
         # (batch_size * seq_len, 1, num_heads * head_dim)
-        return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
+        return value.transpose(1, 2).flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
 
     if args.strategy == QuantizationStrategy.TOKEN:
         raise ValueError("Token quantization cannot be applied to attention")
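
For context, the added "# value.shape = (batch_size, num_heads, seq_len, head_dim)" comment is what motivates the extra transpose: attention states are laid out head-first, so flattening dims 0 and 1 directly would merge the batch dimension with heads rather than with sequence positions. A minimal shape walkthrough of the new op chain (sizes are arbitrary and chosen only to trace the shapes, not taken from the test below):

import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8
value = torch.randn(batch_size, num_heads, seq_len, head_dim)

flat = (
    value.transpose(1, 2)  # (batch_size, seq_len, num_heads, head_dim)
    .flatten(0, 1)         # (batch_size * seq_len, num_heads, head_dim)
    .flatten(-2, -1)       # (batch_size * seq_len, num_heads * head_dim)
    .unsqueeze(-2)         # (batch_size * seq_len, 1, num_heads * head_dim)
)
assert flat.shape == (batch_size * seq_len, 1, num_heads * head_dim)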

tests/test_quantization/lifecycle/test_static_lifecycle.py
Lines changed: 35 additions & 19 deletions

@@ -287,16 +287,24 @@ class MockAttention(torch.nn.Module):
                 strategy="tensor",
             ),
             torch.tensor([0.0]),
-            torch.tensor([11.0]),
+            torch.tensor([23.0]),
             torch.tensor(
                 [
                     [
-                        [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]],
-                        [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
+                        [
+                            [0.0000, 0.0000, 3.0625, 3.0625],
+                            [3.0625, 6.1250, 6.1250, 6.1250],
+                            [9.1875, 9.1875, 9.1875, 12.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
                     ]
                 ]
             ),
-            0.19,
+            0.81,
         ),
         # static token is not supported
        # channel is not supported

@@ -310,35 +318,45 @@ class MockAttention(torch.nn.Module):
                 symmetric=True,
                 strategy="attn_head",
             ),
-            torch.tensor([[[0.0]], [[6.0]]]),
-            torch.tensor([[[5.0]], [[11.0]]]),
+            torch.tensor([[[0.0]], [[12.0]]]),
+            torch.tensor([[[11.0]], [[23.0]]]),
             torch.tensor(
                 [
                     [
-                        [[0.0000, 1.3359, 2.0000], [2.6719, 4.0000, 4.6875]],
-                        [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]],
+                        [
+                            [0.0000, 1.4688, 1.4688, 2.9375],
+                            [4.4062, 4.4062, 5.8750, 7.3438],
+                            [7.3438, 8.8125, 10.2500, 10.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
                     ]
                 ]
             ),
-            0.13,
+            0.55,
         ),
     ],
 )
 def test_static_attention_quantization(
     args, exp_min_val, exp_max_val, exp_quant, exp_loss
 ):
     """
-    input = tensor([[[[ 0.,  1.,  2.],
-                      [ 3.,  4.,  5.]],
+    input = tensor([[[[ 0.,  1.,  2.,  3.],
+                      [ 4.,  5.,  6.,  7.],
+                      [ 8.,  9., 10., 11.]],
 
-                     [[ 6.,  7.,  8.],
-                      [ 9., 10., 11.]]]])
+                     [[12., 13., 14., 15.],
+                      [16., 17., 18., 19.],
+                      [20., 21., 22., 23.]]]])
     """
-    # set up activation (and identity weight)
-    batch_size, num_heads, seq_len, head_dim = 1, 2, 2, 3
+    # set up attention
+    batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
     input = torch.arange(
-        (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16
-    ).reshape((batch_size, seq_len, num_heads, head_dim))
+        (batch_size * num_heads * seq_len * head_dim), dtype=torch.bfloat16
+    ).reshape((batch_size, num_heads, seq_len, head_dim))
     attention = MockAttention()
 
     # initialize quantization parameters

@@ -366,7 +384,5 @@ def test_static_attention_quantization(
     assert torch.equal(attention.k_observer.max_vals, exp_max_val)
 
     # check forward pass
-    print(output)
-    print(torch.nn.functional.mse_loss(output, input))
     assert torch.allclose(output, exp_quant.to(output.dtype))
     assert torch.nn.functional.mse_loss(output, input) <= exp_loss
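
As a quick sanity check on the updated expectations (a standalone sketch, not part of the test): the input is now arange(24) reshaped to (batch_size, num_heads, seq_len, head_dim) = (1, 2, 3, 4), so head 0 holds the values 0 through 11 and head 1 holds 12 through 23. That is where the new per-tensor min/max of 0/23 and the new per-head min/max pairs (0, 11) and (12, 23) come from; the parametrized exp_min_val/exp_max_val tensors carry the same values in a (num_heads, 1, 1) shape.

import torch

batch_size, num_heads, seq_len, head_dim = 1, 2, 3, 4
input = torch.arange(
    batch_size * num_heads * seq_len * head_dim, dtype=torch.bfloat16
).reshape((batch_size, num_heads, seq_len, head_dim))

# "tensor" strategy: one min/max over all values
assert input.amin().item() == 0.0 and input.amax().item() == 23.0

# "attn_head" strategy: one min/max per attention head
assert input.amin(dim=(0, 2, 3)).tolist() == [0.0, 12.0]
assert input.amax(dim=(0, 2, 3)).tolist() == [11.0, 23.0]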
