-
Notifications
You must be signed in to change notification settings - Fork 46
Open
Description
def test(args,obj_name, model,anomaly_names):
model.eval()
dataset = MVTec_classification_test(args,obj_name,anomaly_names)
dataloader = DataLoader(dataset, batch_size=100,
shuffle=False, num_workers=0)
for i_batch, sample_batched in enumerate(dataloader):
image, label = sample_batched
image = image.cuda()
label = label.cuda()
y_pred = model(image)
prediction = torch.argmax(y_pred, 1)
correct = (prediction == label).sum().float()
print("Accuracy: %.4f"%(correct/len(label)))
return correct/len(label)
返回的是最后一个batch的准确率
Metadata
Metadata
Assignees
Labels
No labels