-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
85 lines (63 loc) · 3 KB
/
test.py
File metadata and controls
85 lines (63 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import multiprocessing
import os
import dsntnn
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from handlers.data_loaders.loader import get_test_loader
from network.coordinate_network import CoordinatesRegressionNetwork
def test(device, network, input_size, batch_size, num_threads):
    """Evaluate ``network`` on the test loader and collect per-batch losses.

    The network is switched to eval mode for the duration of the run, so any
    BatchNorm/Dropout layers use inference behaviour and BatchNorm running
    statistics are not updated by test data. Its previous train/eval mode is
    restored afterwards, even if evaluation raises.

    Args:
        device: torch device each batch is moved to.
        network: model called as ``network(images)``; must return a
            ``(coordinates, heatmaps)`` pair.
        input_size: forwarded to ``get_test_loader``.
        batch_size: forwarded to ``get_test_loader``.
        num_threads: forwarded to ``get_test_loader`` (presumably the number
            of loader workers -- confirm against the loader).

    Returns:
        Three lists with one float per batch: the combined average loss, the
        mean Euclidean coordinate loss, and the mean heatmap regularization
        loss.
    """
    loader = get_test_loader(input_size, batch_size, num_threads)
    sample_average_loss = []
    sample_coordinates_loss = []
    sample_heatmaps_loss = []

    was_training = network.training
    # Without eval(), layers like BatchNorm/Dropout would run in training mode
    # during evaluation and skew the reported losses.
    network.eval()
    try:
        with torch.no_grad():
            for sample in tqdm(loader):
                images = sample['image'].to(device)
                joints = sample['joints'].to(device)
                coordinates, heatmaps = network(images)
                euclidian_loss, regularization_loss, average_loss = \
                    calculate_losses(coordinates, heatmaps, joints)
                sample_average_loss.append(average_loss.item())
                sample_coordinates_loss.append(torch.mean(euclidian_loss).item())
                sample_heatmaps_loss.append(torch.mean(regularization_loss).item())
                # Release references to this batch's tensors before the next batch is loaded.
                del sample, images, joints, coordinates, heatmaps
    finally:
        network.train(was_training)
    return sample_average_loss, sample_coordinates_loss, sample_heatmaps_loss
def create_arg_parser():
    """Return the command-line parser for this evaluation script.

    Options:
        -model        str, no default (None if omitted); forwarded to
                      CoordinatesRegressionNetwork as its second argument.
        --input_size  int, default 224; forwarded to the test loader.
        --batch_size  int, default 32; forwarded to the test loader.
        --t7          str, default ""; path of a saved state dict to load,
                      or "" to skip loading.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-model', type=str)
    int_options = (('--input_size', 224), ('--batch_size', 32))
    for flag, fallback in int_options:
        parser.add_argument(flag, type=int, default=fallback)
    parser.add_argument('--t7', type=str, default="")
    return parser
def create_pytorch_device():
    """Pick the compute device and the data-loading thread count.

    Restricts CUDA to GPU 0 and enables cuDNN with benchmark mode. Uses
    ``cuda:0`` when CUDA is available and falls back to the CPU otherwise,
    so the script can still run on machines without a GPU.

    Returns:
        ``(device, num_threads)`` where ``num_threads`` is half the CPU count,
        rounded down (0 on a single-core machine).
    """
    # CUDA reads CUDA_VISIBLE_DEVICES only when it initialises, so set it
    # before anything (including is_available()) can touch CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_threads = multiprocessing.cpu_count() // 2
    torch.backends.cudnn.enabled = True
    # Lets cuDNN benchmark and cache the fastest conv algorithms; pays off
    # when input sizes stay fixed across batches.
    cudnn.benchmark = True
    return device, num_threads
def create_network(device, model, t7):
    """Build the coordinate-regression network and optionally load weights.

    Args:
        device: torch device the network is moved to. It is also used as
            ``map_location`` when loading ``t7``.
        model: forwarded as the second argument of CoordinatesRegressionNetwork.
        t7: path to a saved state dict for the unwrapped network, or "" /
            None to keep the freshly constructed weights.

    Returns:
        The network wrapped in ``torch.nn.DataParallel``.
    """
    # 16 is CoordinatesRegressionNetwork's first argument -- presumably the
    # number of joints; confirm against the network definition.
    net = CoordinatesRegressionNetwork(16, model).to(device)
    net = torch.nn.DataParallel(net)
    if t7:
        # map_location remaps tensors saved on another device (e.g. cuda:1,
        # which is hidden by CUDA_VISIBLE_DEVICES=0, or any GPU when running
        # on CPU) onto `device`. Without it, torch.load tries to restore them
        # to their original device and fails.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        pre_trained = torch.load(t7, map_location=device)
        # Weights are loaded into the inner module, not the DataParallel wrapper.
        net.module.load_state_dict(pre_trained)
    # Mark every parameter trainable, as the original script did.
    for param in net.parameters():
        param.requires_grad = True
    return net
def calculate_losses(coordinates, heatmaps, joints):
    """Compute the dsntnn losses for one batch of predictions.

    Args:
        coordinates: predicted coordinates returned by the network.
        heatmaps: predicted heatmaps returned by the network.
        joints: ground-truth joint coordinates (assumed to be in the
            coordinate convention dsntnn expects -- confirm normalisation).

    Returns:
        ``(euclidean, regularization, average)``:
        - euclidean: dsntnn's Euclidean losses, not reduced to a scalar here.
        - regularization: dsntnn's Jensen-Shannon heatmap regularization losses.
        - average: ``dsntnn.average_loss`` of their element-wise sum.
    """
    coord_terms = dsntnn.euclidean_losses(coordinates, joints)
    # 1.0 is passed as js_reg_losses' third argument (sigma_t, the target
    # Gaussian's std, in dsntnn's API -- confirm for the installed version).
    reg_terms = dsntnn.js_reg_losses(heatmaps, joints, 1.0)
    combined = coord_terms + reg_terms
    return coord_terms, reg_terms, dsntnn.average_loss(combined)
if __name__ == "__main__":
    # Evaluate a (optionally pre-trained) network on the test set and report
    # the mean of each per-batch loss.
    parser = create_arg_parser()
    args = parser.parse_args()
    device, num_threads = create_pytorch_device()
    net = create_network(device, args.model, args.t7)
    sample_average_loss, sample_coordinates_loss, sample_heatmaps_loss = \
        test(device, net, args.input_size, args.batch_size, num_threads)
    # np.mean returns a NumPy float. Concatenating it onto a str with `+`
    # raises TypeError, so format it into the message instead.
    print(f"Average loss: {np.mean(sample_average_loss)}")
    print(f"Coordinates loss: {np.mean(sample_coordinates_loss)}")
    print(f"Heatmaps loss: {np.mean(sample_heatmaps_loss)}")