
How should train_classification.py handle a test set with more than 100 samples? #107

@HU13-SVG

Description

import torch
from torch.utils.data import DataLoader

def test(args, obj_name, model, anomaly_names):
    model.eval()

    dataset = MVTec_classification_test(args, obj_name, anomaly_names)
    dataloader = DataLoader(dataset, batch_size=100,
                            shuffle=False, num_workers=0)

    for i_batch, sample_batched in enumerate(dataloader):
        image, label = sample_batched
        image = image.cuda()
        label = label.cuda()
        y_pred = model(image)
        prediction = torch.argmax(y_pred, 1)
        # Reassigned on every batch, so earlier batches' counts are discarded.
        correct = (prediction == label).sum().float()
        print("Accuracy: %.4f" % (correct / len(label)))
    # Uses only the last batch's `correct` and `label`.
    return correct / len(label)

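What it returns is only the accuracy of the last batch. With more than 100 test samples the loop runs over several batches, but `correct` is overwritten on every iteration and `len(label)` is the size of the final batch, so the earlier batches never contribute to the returned value.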
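A minimal sketch of one way to fix this, assuming the dataset yields `(image, label)` pairs as `MVTec_classification_test` appears to: keep running totals of correct predictions and samples across all batches, and divide only once after the loop. `evaluate_accuracy` is a hypothetical name, not a function from the repo; it takes an already-built dataset instead of `args`, `obj_name`, and `anomaly_names`, and falls back to the CPU when CUDA is unavailable (the model is assumed to already be on `device`).

```python
import torch
from torch.utils.data import DataLoader


def evaluate_accuracy(model, dataset, batch_size=100, device=None):
    """Hypothetical helper: accuracy over the whole dataset, not just the last batch."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=False, num_workers=0)

    correct = 0  # running count of correct predictions over all batches
    total = 0    # running count of samples seen over all batches
    with torch.no_grad():  # gradients are not needed for evaluation
        for image, label in dataloader:
            image = image.to(device)
            label = label.to(device)
            prediction = torch.argmax(model(image), 1)
            correct += (prediction == label).sum().item()
            total += label.size(0)

    accuracy = correct / total
    print("Accuracy: %.4f" % accuracy)
    return accuracy
```

Because the division happens once over the summed counts, a final partial batch (for example 50 samples after two full batches of 100) is weighted by its actual size rather than replacing the previous result.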
