-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
47 lines (35 loc) · 1.63 KB
/
utils.py
File metadata and controls
47 lines (35 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
def replicate_input(x):
    """Return an independent copy of tensor ``x`` that is cut off from autograd.

    ``detach`` is applied before ``clone`` so the copy operation itself is not
    recorded in the autograd graph; the result shares no storage with ``x``.
    """
    detached = x.detach()
    return detached.clone()
def to_one_hot(y, num_classes=10):
    """Convert integer class labels to a one-hot matrix.

    Every element of ``y`` is treated as one label, whatever ``y``'s shape:
    the labels are flattened in row-major order, so the result is always 2-D
    with shape ``(y.numel(), num_classes)``.

    Args:
        y: Integer label tensor. It is used directly as the ``scatter_`` index,
            so it must be int64 (``torch.long``) and every value must lie in
            ``[0, num_classes)``.
        num_classes: Width of the one-hot dimension.

    Returns:
        A new tensor of shape ``(y.numel(), num_classes)`` with ``y``'s dtype
        and device, holding 1 at each label's column and 0 elsewhere. ``y``
        itself is not modified.

    Link: https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/24
    """
    # `reshape` rather than `view`: a dense but non-contiguous label tensor
    # (e.g. a transposed matrix) cannot be viewed as a single column.
    # `y` is only read as the scatter index, so no defensive copy is needed.
    index = y.detach().reshape(-1, 1)
    y_one_hot = y.new_zeros((index.size(0), num_classes)).scatter_(1, index, 1)
    return y_one_hot
def log_success_indices(clean_x_test, clean_y_test, device, logger, model, x_adv, wandb,
                        batch_size=1):
    """Evaluate ``model`` on adversarial inputs and log the attack success rate.

    A sample counts as a successful attack when the model's arg-max prediction
    on ``x_adv[i]`` differs from the label ``clean_y_test[i]``.

    Args:
        clean_x_test: Clean inputs; only ``clean_x_test.shape[0]`` is used, as
            the number of samples to evaluate.
        clean_y_test: Label tensor, sliced along dim 0 in step with ``x_adv``.
        device: Device each batch of inputs and labels is moved to.
        logger: Receives the success indices and the ASR / RA summary.
        model: Called on each batch; its output is reduced with ``argmax``
            over dim 1 (presumably per-class scores of shape (B, C) — confirm).
        x_adv: Adversarial inputs, sliced along dim 0 in step with the labels.
        wandb: The ``wandb`` module, or any object exposing ``log`` and
            ``Table`` with the same call signatures.
        batch_size: Samples per forward pass. The default of 1 matches the
            previous behaviour; larger values give the same indices with fewer
            model calls, provided per-sample outputs do not depend on the rest
            of the batch (e.g. the model is in eval mode).

    Returns:
        The indices (in increasing order) of samples whose prediction differs
        from the label.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    num_samples = clean_x_test.shape[0]
    a_list = []
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            stop = min(start + batch_size, num_samples)
            x_adv_curr = x_adv[start:stop].to(device)
            y_curr = clean_y_test[start:stop].to(device)
            predicted = model(x_adv_curr).argmax(dim=1)
            # Flatten the labels so (B,) and (B, 1) layouts both compare
            # element-wise instead of broadcasting to a (B, B) matrix.
            fooled = (predicted != y_curr.reshape(-1)).nonzero(as_tuple=True)[0]
            a_list.extend((fooled + start).tolist())

    logger.info(f"evaluated {num_samples} samples")
    logger.info(f"attack succeeded for indices {a_list}")
    if num_samples == 0:
        # ASR is undefined without samples (it would divide by zero).
        logger.warning("no samples evaluated; skipping ASR/RA logging")
        return a_list

    asr = len(a_list) / num_samples
    logger.info(f"ASR: {asr}, RA: {1 - asr}")
    wandb.log({"ASR": asr, "RA": 1 - asr, "success_indices": a_list})
    # wandb.Table takes a list of rows, so each index becomes a one-cell row.
    table = wandb.Table(data=[[idx] for idx in a_list],
                        columns=["List of successfully attacked indices"])
    wandb.log({"successful indices": table})
    return a_list