-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
111 lines (94 loc) · 3.22 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.transforms.functional as FT
from tqdm import tqdm
from torch.utils.data import DataLoader
from model import YoloV1
from dataset import VOCDataset
from loss import YoloLoss
from utils import (
intersection_over_union,
non_max_suppression,
mean_average_precision,
cellboxes_to_boxes,
get_bboxes,
plot_image,
save_checkpoint,
load_checkpoint
)
# Fixed seed so torch-driven randomness (weight initialisation, DataLoader
# shuffling) is repeatable from run to run.
seed = 3301
torch.manual_seed(seed)

# Hyperparameters for the model.
LEARNING_RATE = 2e-5
# PyTorch's device string for NVIDIA GPUs is "cuda" (the earlier "cude" typo
# made every .to(DEVICE) call fail on machines where CUDA is available).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
WEIGHT_DECAY = 0  # no regularization, so the tiny training set is fit quickly
EPOCHS = 100
NUM_WORKERS = 2
PIN_MEMORY = True
LOAD_MODEL = False  # set True to resume from LOAD_MODEL_FILE
LOAD_MODEL_FILE = "overfit.pth.tar"
IMG_DIR = "data/images"
LABEL_DIR = "data/labels"
class Compose:
    """Chain image transforms while passing bounding boxes through untouched.

    Each transform in ``transforms`` is called with the image alone, in list
    order. ``bboxes`` is returned exactly as received. Presumably the box
    coordinates are relative to image size, so resizing does not invalidate
    them (TODO: confirm in dataset.VOCDataset).
    """

    def __init__(self, transforms):
        # Sequence of single-argument callables applied to the image.
        self.transforms = transforms

    def __call__(self, img, bboxes):
        """Return ``(transformed_img, bboxes)`` after applying every transform."""
        current = img
        for step in self.transforms:
            current = step(current)
        return current, bboxes
# Shared preprocessing for both datasets: resize every image to 448x448
# (presumably the input size model.YoloV1 expects — confirm), then convert it
# to a tensor. Bounding boxes pass through Compose unchanged.
transform = Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ]
)
def train_fn(train_loader, model, optimizer, loss_fn):
    """Run one training epoch and print the mean per-batch loss.

    Args:
        train_loader: Iterable of ``(x, y)`` batches; each is moved to DEVICE.
        model: Module mapping ``x`` to predictions accepted by ``loss_fn``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(out, y)`` returning a scalar loss tensor.

    Returns:
        The mean of the per-batch losses, or ``None`` if the loader yielded
        no batches (in which case a notice is printed instead).
    """
    # main() calls get_bboxes before every epoch, which may switch the model
    # to eval mode (TODO confirm in utils); make sure training mode is active.
    model.train()
    loop = tqdm(train_loader, leave=True)
    batch_losses = []
    for x, y in loop:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        loss = loss_fn(out, y)
        batch_losses.append(loss.item())
        # Clear gradients from the previous step, backpropagate, then update.
        # Without backward() no gradients exist and step() changes nothing;
        # without zero_grad() gradients would accumulate across batches.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Show the current batch loss on the progress bar.
        loop.set_postfix(loss=loss.item())
    if not batch_losses:
        # Avoid ZeroDivisionError on an empty loader.
        print("Mean loss was undefined (no batches)")
        return None
    mean_loss = sum(batch_losses) / len(batch_losses)
    print(f"Mean loss was {mean_loss}")
    return mean_loss
def main():
    """Train YoloV1 on the small VOC subset, reporting train mAP every epoch.

    Builds the model, Adam optimizer and YoloLoss, and optionally restores
    them from LOAD_MODEL_FILE. Then, for EPOCHS epochs, it:

    - computes train-set mAP at IoU threshold 0.5;
    - saves a checkpoint (followed by a 10 s pause) whenever mAP exceeds 0.9;
    - runs one training epoch via ``train_fn``.
    """
    import time  # only used for the pause after a checkpoint is saved

    model = YoloV1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )
    loss_fn = YoloLoss()

    if LOAD_MODEL:
        # map_location remaps saved tensors onto the current device, so a
        # checkpoint written on a GPU can be restored on a CPU-only machine.
        load_checkpoint(
            torch.load(LOAD_MODEL_FILE, map_location=DEVICE), model, optimizer
        )

    train_dataset = VOCDataset(
        "data/8examples.csv",
        transform=transform,
        img_dir=IMG_DIR,
        label_dir=LABEL_DIR,
    )
    # NOTE(review): test_dataset is constructed but never wrapped in a
    # DataLoader or evaluated; kept to preserve any construction side effects.
    test_dataset = VOCDataset(
        "data/test.csv",
        transform=transform,
        img_dir=IMG_DIR,
        label_dir=LABEL_DIR,
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=True,
        # Only 8 training examples (fewer than BATCH_SIZE), so the partial
        # batch must be kept; use True once the dataset exceeds BATCH_SIZE.
        drop_last=False,
    )

    for _ in range(EPOCHS):
        pred_boxes, target_boxes = get_bboxes(
            train_loader, model, iou_threshold=0.5, threshold=0.4
        )
        mean_avg_prec = mean_average_precision(
            pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
        )
        print(f"Train mAP: {mean_avg_prec}")

        if mean_avg_prec > 0.9:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint, filename=LOAD_MODEL_FILE)
            # Deliberate 10 s pause after saving (presumably so the milestone
            # is noticeable in the console); consider removing for long runs.
            time.sleep(10)

        train_fn(train_loader, model, optimizer, loss_fn)
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()