-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
22 lines (16 loc) · 805 Bytes
/
train.py
File metadata and controls
22 lines (16 loc) · 805 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
def train_one_epoch(epoch_index, model, trainloader, optimizer, criterion, device, log_interval=10):
    """Train ``model`` for one full pass over ``trainloader``.

    Args:
        epoch_index: Zero-based epoch number, used only in log messages.
        model: The torch module to train (switched to train mode here).
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.
        log_interval: Print the running average loss every this many batches
            (was hard-coded to 10).

    Returns:
        The mean per-batch loss over the whole epoch as a float, or 0.0 if
        ``trainloader`` yielded no batches.
    """
    model.train()
    running_loss = 0.0   # accumulator for the periodic log line only
    total_loss = 0.0     # accumulator for the epoch-wide average
    num_batches = 0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()              # Clear gradients left over from the previous step
        outputs = model(inputs)            # Forward pass
        loss = criterion(outputs, labels)  # Compute the loss for this batch
        loss.backward()                    # Backpropagate: compute gradients w.r.t. the loss
        optimizer.step()                   # Update the parameters using those gradients
        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss
        num_batches += 1
        if i % log_interval == log_interval - 1:
            print(f'[Epoch {epoch_index + 1}, Batch {i + 1}] loss: {running_loss / log_interval:.3f}')
            running_loss = 0.0
    # Guard against an empty loader so callers never hit a ZeroDivisionError.
    return total_loss / num_batches if num_batches else 0.0