-
Notifications
You must be signed in to change notification settings - Fork 1
/
ff_training.py
88 lines (76 loc) · 2.51 KB
/
ff_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import time
import sys
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.datasets as dst
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
import model
def train(dataloader, net, optimizer=None, loss_f=None, device=None):
    """Run one training epoch of ``net`` over ``dataloader``.

    For every ``(x, y)`` batch the tensors are moved to ``device``, the loss
    ``loss_f(net(x), y)`` is back-propagated and ``optimizer.step()`` is
    applied. Accuracy is accumulated from the predictions made *before* each
    optimizer step.

    Args:
        dataloader: Iterable of ``(x, y)`` tensor batches.
        net: Model to train; it is switched to training mode first.
        optimizer: Optimizer to step. Defaults to the module-level
            ``optimizer`` global (set by the ``__main__`` script), so the
            original ``train(dataloader, net)`` call keeps working.
        loss_f: Loss callable ``loss_f(output, y)``. Defaults to the
            module-level ``loss_f`` global.
        device: Device each batch is moved to. Defaults to the module-level
            ``device`` global.

    Returns:
        Fraction of samples whose argmax-over-dim-1 prediction equals ``y``.

    Raises:
        KeyError: If a dependency is not passed and no module-level global of
            that name exists.
        ValueError: If ``dataloader`` yields no samples.
    """
    # Resolve dependencies at call time so the script's per-phase
    # ``optimizer`` reassignment is still picked up by two-argument calls.
    module_globals = globals()
    if optimizer is None:
        optimizer = module_globals["optimizer"]
    if loss_f is None:
        loss_f = module_globals["loss_f"]
    if device is None:
        device = module_globals["device"]

    net.train()
    total = 0
    correct = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        output = net(x)
        loss = loss_f(output, y)
        loss.backward()
        optimizer.step()
        # assumes output is (batch, num_classes) scores — argmax over dim 1
        # yields the predicted class index per sample
        correct += y.eq(output.argmax(dim=1)).sum().item()
        total += y.numel()
    if total == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")
    return correct / total
def test(dataloader, net, device=None):
    """Evaluate ``net`` on ``dataloader`` without tracking gradients.

    Args:
        dataloader: Iterable of ``(x, y)`` tensor batches.
        net: Model to evaluate; it is switched to eval mode and left there.
        device: Device each batch is moved to. Defaults to the module-level
            ``device`` global (set by the ``__main__`` script), so the
            original ``test(dataloader, net)`` call keeps working.

    Returns:
        Fraction of samples whose argmax-over-dim-1 prediction equals ``y``.

    Raises:
        KeyError: If ``device`` is not passed and no module-level ``device``
            global exists.
        ValueError: If ``dataloader`` yields no samples.
    """
    if device is None:
        device = globals()["device"]

    net.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            output = net(x)
            # assumes output is (batch, num_classes) scores — argmax over
            # dim 1 yields the predicted class index per sample
            correct += y.eq(output.argmax(dim=1)).sum().item()
            total += y.numel()
    if total == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")
    return correct / total


# This helper's name matches pytest's default ``test*`` collection prefix;
# opt out so importing it into a pytest module does not make pytest try to
# run it as a test (it would fail looking for ``dataloader``/``net`` fixtures).
test.__test__ = False
if __name__ == "__main__":
    # parameters
    batchSize = 128
    lr = 1e-4
    model_path = 'models/vgg16.pth'
    data_dir = 'data/cifar10/'
    # Epochs per training phase; lr is divided by 10 after each phase.
    epochs = [80, 20]

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs;
    # the checkpoint directory is derived from model_path so they stay in sync.
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # Fall back to CPU so the script still runs on machines without CUDA.
    # NOTE: `device`, `loss_f` and `optimizer` must stay module-level names:
    # train()/test() are called with two arguments and read them as globals.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model and loss
    net = model.VGG()
    loss_f = nn.CrossEntropyLoss()
    net.to(device)
    loss_f.to(device)

    # data: augment only the training split
    transform_train = tfs.Compose([tfs.RandomCrop(32, padding=4),
                                   tfs.RandomHorizontalFlip(),
                                   tfs.ToTensor()])
    data = dst.CIFAR10(data_dir, download=True, train=True,
                       transform=transform_train)
    data_test = dst.CIFAR10(data_dir, download=True, train=False,
                            transform=tfs.Compose([tfs.ToTensor()]))
    dataloader = DataLoader(data, batch_size=batchSize, shuffle=True)
    dataloader_test = DataLoader(data_test, batch_size=batchSize, shuffle=False)

    count = 0
    for epoch in epochs:
        # A fresh SGD optimizer per phase applies the reduced lr; note this
        # also discards the momentum buffers accumulated in the previous phase.
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
        for _ in range(epoch):
            beg = time.time()
            count += 1
            train_acc = train(dataloader, net)
            test_acc = test(dataloader_test, net)
            run_time = time.time() - beg
            print('Epoch {}, Time {:.2f}, Train: {:.5f}, Test: {:.5f}'.format(
                count, run_time, train_acc, test_acc), flush=True)
        lr /= 10
    torch.save(net.state_dict(), model_path)