-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: quantization.py
More file actions
65 lines (47 loc) · 1.42 KB
/
quantization.py
File metadata and controls
65 lines (47 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torch.autograd import Function
"""
Quantization functions for MD-SNN.
- w_q: Symmetric uniform quantization for weights and membrane potentials (Eq. 3)
- quantizeBN: BN parameter quantization following WAGEUBN (Eq. 4)
"""
bitsBN = 8 # BN parameters always quantized to 8-bit (Sec. IV-A2)
def delta(bits):
    """Return the quantization step size 2**(1 - bits) of the signed uniform grid."""
    return 2.0 ** (1 - bits)
def clip(x, bits):
    """Clamp `x` into the representable symmetric range [-(1 - step), 1 - step].

    At 32 bits or more the tensor is treated as full precision and returned
    untouched.
    """
    if bits >= 32:
        return x
    # Largest representable magnitude on the `bits`-bit grid.
    bound = 1.0 - 2.0 ** (1 - bits)
    return torch.clamp(x, -bound, bound)
def quant(x, bits):
    """Snap `x` onto the uniform grid whose step is 2**(1 - bits).

    At 32 bits or more the tensor is treated as full precision and returned
    untouched.
    """
    if bits >= 32:
        return x
    step = 2.0 ** (1 - bits)
    return torch.round(x / step) * step
class QBN(Function):
    """Straight-through-estimator (STE) quantizer for Batch Normalization parameters.

    Forward quantizes to the module-level ``bitsBN`` bit-width and clips to
    the representable range; backward passes the incoming gradient through
    unchanged (identity STE).
    """

    @staticmethod
    def forward(ctx, x):
        # 32 bits or more is treated as full precision: identity.
        if bitsBN >= 32:
            return x
        # Snap to the fixed-point grid, then clip into the valid range.
        return clip(quant(x, bitsBN), bitsBN)

    @staticmethod
    def backward(ctx, grad_output):
        # STE: quantization is transparent to the gradient.
        return grad_output
# Bind .apply from the class itself: instantiating a torch.autograd.Function
# and going through the instance is the deprecated legacy API.
quantizeBN = QBN.apply
def w_q(w, b):
    """Symmetric uniform quantization with tanh transformation (Eq. 3).

    The input is squashed with tanh, normalized by its maximum magnitude,
    mapped onto a symmetric integer grid with 2**(b-1) - 1 positive levels,
    rounded, and rescaled back. The straight-through estimator (STE) makes
    the non-differentiable round() behave as the identity in backward.

    Args:
        w: Input tensor (weights or membrane potentials).
        b: Target bit-width.

    Returns:
        Quantized tensor with STE gradient.
    """
    w = torch.tanh(w)
    # detach() instead of the legacy .data: same value, autograd-safe.
    # clamp_min guards against division by zero when w is all zeros
    # (tanh(0) == 0 -> max abs == 0 would otherwise yield NaNs).
    alpha = w.detach().abs().max().clamp_min(1e-8)
    w = torch.clamp(w / alpha, min=-1, max=1)
    levels = 2 ** (b - 1) - 1  # e.g. 127 for b == 8
    w = w * levels
    # STE: forward uses round(), backward sees the identity.
    w_hat = (w.round() - w).detach() + w
    return w_hat * alpha / levels