-
Notifications
You must be signed in to change notification settings - Fork 15
/
cifar100_classif.py
72 lines (67 loc) · 3.24 KB
/
cifar100_classif.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
from utils.util import AverageMeter, accuracy
from data.cifarloader import CIFAR100Loader
from models.vgg import VGG
import os
def train(model, train_loader, args):
    """Train `model` on `train_loader` for `args.epochs` epochs with Adam.

    Uses a MultiStepLR schedule (milestones/gamma from `args`). After every
    epoch the model is evaluated on the module-level `eva_loader`, and once
    training finishes the weights are saved to `args.model_dir`.

    Relies on the module-level `device` and `eva_loader` set in __main__.
    """
    optimizer = Adam(model.parameters(), lr=args.lr)
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    # .to(device) works on both CPU and GPU; the original `.cuda(device)`
    # raised when CUDA was unavailable even though `device` could be "cpu".
    criterion = nn.CrossEntropyLoss().to(device)
    for epoch in range(args.epochs):
        loss_record = AverageMeter()
        acc_record = AverageMeter()
        model.train()
        for batch_idx, (x, label, _) in enumerate(train_loader):
            x, target = x.to(device), label.to(device)
            optimizer.zero_grad()
            _, output = model(x)
            loss = criterion(output, target)
            acc = accuracy(output, target)
            loss.backward()
            optimizer.step()
            acc_record.update(acc[0].item(), x.size(0))
            loss_record.update(loss.item(), x.size(0))
        # Since PyTorch 1.1, scheduler.step() must follow the epoch's
        # optimizer.step() calls; stepping it first (as the original did)
        # skips the initial learning-rate value of the schedule.
        exp_lr_scheduler.step()
        print('Train Epoch: {} Avg Loss: {:.4f} \t Avg Acc: {:.4f}'.format(epoch, loss_record.avg, acc_record.avg))
        test(model, eva_loader, args)
    torch.save(model.state_dict(), args.model_dir)
    print("model saved to {}.".format(args.model_dir))
def test(model, test_loader, args):
    """Evaluate `model` on `test_loader` and print the average top-1 accuracy.

    Relies on the module-level `device` set in __main__. `args` is accepted
    for signature symmetry with `train` and is currently unused.
    """
    model.eval()
    acc_record = AverageMeter()
    # no_grad() stops autograd from recording the forward pass during
    # evaluation; the original built a graph per batch, wasting memory.
    with torch.no_grad():
        for batch_idx, (x, label, _) in enumerate(test_loader):
            x, target = x.to(device), label.to(device)
            _, output = model(x)
            acc = accuracy(output, target)
            acc_record.update(acc[0].item(), x.size(0))
    print('Test: Avg Acc: {:.4f}'.format(acc_record.avg))
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description='cls',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--epochs', default=180, type=int)
    parser.add_argument('--milestones', default=[100, 150], type=int, nargs='+')
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--num_classes', default=80, type=int)
    parser.add_argument('--model_name', type=str, default='vgg6_cifar100_classif_80')
    parser.add_argument('--dataset_root', type=str, default='data/datasets/CIFAR/')
    parser.add_argument('--exp_root', type=str, default='./data/experiments/')
    args = parser.parse_args()

    args.cuda = torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")

    # Checkpoints go under <exp_root>/<script name>/<model_name>.pth.
    runner_name = os.path.basename(__file__).split(".")[0]
    model_dir = os.path.join(args.exp_root, runner_name)
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(model_dir, exist_ok=True)
    args.model_dir = os.path.join(model_dir, args.model_name + '.pth')

    # Honor --batch_size: the original hard-coded 128 in both loaders,
    # silently ignoring the command-line flag.
    train_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='train', labeled=True, aug='once', shuffle=True)
    eva_loader = CIFAR100Loader(root=args.dataset_root, batch_size=args.batch_size, split='test', labeled=True, aug=None, shuffle=False)

    model = VGG(n_layer='5+1', out_dim=args.num_classes).to(device)
    train(model, train_loader, args)
    test(model, eva_loader, args)