# train_fomo.py — trains a FOMO-style person detector (MobileNetV2 backbone)
# on a 5000-image random sample of the COCO val2017 split.
# NOTE(review): GitHub page-scrape residue (UI text and gutter line numbers)
# removed from the top of this file; it was not valid Python.
import math
import os
import random

import torch
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from torchvision.datasets import CocoDetection
from torchvision.models import MobileNet_V2_Weights

from models_architecture.models.fomo_mobilenetv2 import FOMOMobileNetV2
# Preprocessing: resize to the 224x224 MobileNetV2 input size, convert to a
# tensor, and normalize with the standard ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Set to your local COCO path
coco_root = "coco"
coco_val = CocoDetection(
    root=os.path.join(coco_root, "val2017"),
    annFile=os.path.join(coco_root, "annotations", "instances_val2017.json"),
    transform=transform,
)

# Draw a random 5000-image sample of val2017 (shuffle indices, take a prefix).
subset_size = 5000
indices = list(range(len(coco_val)))
random.shuffle(indices)
subset = Subset(coco_val, indices[:subset_size])

# Split: 70% train, 15% val, 15% test
n = len(subset)
train_cut = int(0.7 * n)
val_cut = int(0.85 * n)
train_set = Subset(subset, list(range(train_cut)))
val_set = Subset(subset, list(range(train_cut, val_cut)))
test_set = Subset(subset, list(range(val_cut, n)))
def generate_heatmap_from_coco(boxes, labels, image_size=(224, 224), grid_size=8):
    """Rasterize COCO person boxes into a binary occupancy heatmap.

    Args:
        boxes: COCO-format [x, y, w, h] boxes in pixel coordinates.
        labels: category ids parallel to `boxes`; only 1 ('person') is kept.
        image_size: (height, width) of the network input, in pixels.
        grid_size: pixels per heatmap cell (the model's output stride).

    Returns:
        Tensor of shape (1, image_size[0] // grid_size, image_size[1] // grid_size);
        cells overlapped by any person box are 1.0, all others 0.0.
    """
    rows = image_size[0] // grid_size  # vertical cell count (y axis)
    cols = image_size[1] // grid_size  # horizontal cell count (x axis)
    heatmap = torch.zeros((1, rows, cols))
    for box, label in zip(boxes, labels):
        if label != 1:  # 1 = 'person' in COCO
            continue
        x, y, w, h = box
        # Start cell (floor) and exclusive end cell (ceil), both clamped to
        # the grid.  Ceil (instead of the original floor) guarantees that a
        # box narrower/shorter than one cell still marks the cell it is in;
        # the original produced an empty slice and dropped such boxes.
        x0 = max(0, min(int(x // grid_size), cols - 1))
        y0 = max(0, min(int(y // grid_size), rows - 1))
        x1 = max(x0 + 1, min(math.ceil((x + w) / grid_size), cols))
        y1 = max(y0 + 1, min(math.ceil((y + h) / grid_size), rows))
        # BUGFIX: dim 1 of the heatmap is rows (y) and dim 2 is columns (x);
        # the original indexed [0, x0:x1, y0:y1], transposing every box.
        heatmap[0, y0:y1, x0:x1] = 1.0
    return heatmap
class FOMOCocoDataset(Dataset):
    """Wraps a COCO detection subset, yielding (image, person heatmap) pairs."""

    def __init__(self, coco_subset):
        # Underlying dataset of (image, annotation-list) samples.
        self.data = coco_subset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, annotations = self.data[idx]
        bboxes = []
        categories = []
        for ann in annotations:
            bboxes.append(ann['bbox'])
            categories.append(ann['category_id'])
        target = generate_heatmap_from_coco(bboxes, categories)
        return image, target
# Batched loaders over the three splits; only the training split is shuffled.
train_loader = DataLoader(FOMOCocoDataset(train_set), batch_size=16, shuffle=True)
val_loader = DataLoader(FOMOCocoDataset(val_set), batch_size=16)
test_loader = DataLoader(FOMOCocoDataset(test_set), batch_size=16)
# Single-class (person) FOMO head on an ImageNet-pretrained MobileNetV2
# backbone; FOMOMobileNetV2 is imported from models_architecture above.
model = FOMOMobileNetV2(num_classes=1, weights=MobileNet_V2_Weights.DEFAULT)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# BCEWithLogitsLoss: the model is expected to output raw logits — the sigmoid
# is applied inside the loss for numerical stability.
criterion = torch.nn.BCEWithLogitsLoss()
# ---- Training configuration ----
epoch_number = 100            # maximum number of epochs
best_val_loss = float('inf')  # best validation loss seen so far
patience = 5                  # epochs without improvement before early stopping
counter = 0                   # epochs elapsed since the last improvement

# Training loop with validation-loss tracking, best-checkpoint saving, and
# early stopping.
for epoch in range(epoch_number):
    model.train()
    total_train_loss = 0
    for imgs, heatmaps in train_loader:
        preds = model(imgs)
        loss = criterion(preds, heatmaps)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
    # Validation pass — no gradients needed.
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for imgs, heatmaps in val_loader:
            preds = model(imgs)
            loss = criterion(preds, heatmaps)
            total_val_loss += loss.item()
    # BUGFIX: losses are accumulated once per *batch*, so average over the
    # number of batches.  The original divided both sums by subset_size
    # (5000), although the train/val splits hold 3500/750 samples and the
    # sums are per-batch — the printed losses were wrong by a large factor.
    avg_train_loss = total_train_loss / max(len(train_loader), 1)
    avg_val_loss = total_val_loss / max(len(val_loader), 1)
    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    # Save the best model whenever validation loss improves.
    if total_val_loss < best_val_loss:
        print("Validation loss decreased: saving model...")
        best_val_loss = total_val_loss
        torch.save(model.state_dict(), "fomo_mobilenetv2_best.pth")
        counter = 0  # reset patience counter on improvement
    else:
        counter += 1
    # Early stopping: no improvement for `patience` consecutive epochs.
    # (Message now reports `patience` instead of a hard-coded "5".)
    if counter >= patience:
        print(f"Early stopping: no improvement in validation loss for {patience} epochs")
        break
def evaluate_fomo(model, dataloader, threshold=0.5):
    """Evaluate per-cell detections of a FOMO model on a labeled loader.

    Runs the model over `dataloader`, thresholds the sigmoid of its raw
    logits, and accumulates cell-wise confusion counts against the binary
    target heatmaps. Prints a summary and returns the metrics.

    Args:
        model: module mapping an image batch to per-cell logits.
        dataloader: iterable of (images, binary target heatmap) batches.
        threshold: sigmoid probability above which a cell counts positive.

    Returns:
        dict with keys "accuracy", "precision", "recall", "f1".
    """
    model.eval()
    TP = FP = TN = FN = 0
    with torch.no_grad():
        for imgs, targets in dataloader:
            outputs = torch.sigmoid(model(imgs))
            preds = (outputs > threshold).float()
            TP += ((preds == 1) & (targets == 1)).sum().item()
            FP += ((preds == 1) & (targets == 0)).sum().item()
            TN += ((preds == 0) & (targets == 0)).sum().item()
            FN += ((preds == 0) & (targets == 1)).sum().item()
    precision = TP / (TP + FP + 1e-6)
    recall = TP / (TP + FN + 1e-6)
    f1_score = 2 * (precision * recall) / (precision + recall + 1e-6)
    # BUGFIX: guard the accuracy denominator — the original raised
    # ZeroDivisionError on an empty dataloader (all counts zero).
    total = TP + TN + FP + FN
    accuracy = (TP + TN) / total if total else 0.0
    print("\nEvaluation Metrics:")
    print(f"Accuracy : {accuracy*100:.2f}%")
    print(f"Precision: {precision:.4f}")
    print(f"Recall : {recall:.4f}")
    print(f"F1 Score : {f1_score:.4f}")
    print(f"(TP={TP}, FP={FP}, TN={TN}, FN={FN})")
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1_score}
# Final evaluation on the held-out test split.
# NOTE(review): threshold is lowered from the 0.5 default to 0.3 — presumably
# to favor recall on small/sparse person cells; confirm against a validation
# threshold sweep.
evaluate_fomo(model, test_loader, 0.3)