-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdistiller.py
More file actions
39 lines (30 loc) · 1.34 KB
/
distiller.py
File metadata and controls
39 lines (30 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
import torch.nn.functional as F
def distillation_loss(source, target):
    """Return the summed squared error between student and teacher features.

    Args:
        source: student feature tensor.
        target: teacher feature tensor (same shape as ``source``).

    Returns:
        Scalar tensor: the sum of element-wise squared differences.
    """
    # reduction="sum" is equivalent to the original
    # mse_loss(..., reduction="none").sum(), in a single fused call;
    # F is the module-level alias already imported in this file.
    return F.mse_loss(source, target, reduction="sum")
class Distiller(nn.Module):
def __init__(self, args, t_net, s_net):
super(Distiller, self).__init__()
self.t_net = t_net
self.s_net = s_net
self.args = args
def forward(self, x):
""" t_feats, s_feats shape = [T, B, C, H, W] * feat_num """
t_feats, t_sp, t_out = self.t_net.extract_feature(x)
s_feats, s_sp, s_out = self.s_net.extract_feature(x)
feat_num = len(t_feats)
loss_distill = 0
time_window = s_feats[0].shape[0]
tem_1 = tem_2 = 6
for i in range(feat_num):
for j in range(time_window):
if self.args.loss == "CLM":
loss_distill += distillation_loss(s_feats[i][j], t_feats[i][j].detach()) / 2 ** (feat_num - i - 1)
elif self.args.loss == "CLK":
loss_distill += nn.KLDivLoss()(
F.log_softmax(s_feats[i][j].reshape(s_feats[i][j].shape[0], -1) / tem_1, dim=1),
F.softmax(t_feats[i][j].detach().reshape(t_feats[i][j].shape[0], -1) / tem_2, dim=1)
) * (tem_1 * tem_2)
return s_out, t_out, loss_distill