-
Notifications
You must be signed in to change notification settings - Fork 4
/
NetworkFunction.py
91 lines (72 loc) · 2.71 KB
/
NetworkFunction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from torch import nn
import torch
from tqdm import tqdm
from utils import *
def mp_test(test_dataloader, model, net_arch, presim_len, sim_len, device):
    """Evaluate a converted SNN after a warm-up ("pre-simulation") phase.

    For every batch the model is called ``presim_len + sim_len`` times on the
    same input. Outputs of the first ``presim_len`` calls are ignored. From
    then on, outputs are summed. After each summed step, the number of
    correct argmax predictions is added to the matching slot of the result.
    ``reset_net(model)`` is called after every batch, matching ``eval_snn``,
    so one batch's state cannot leak into the next.

    Args:
        test_dataloader: iterable of ``(img, label)`` batches.
        model: network called once per simulation step on the same input.
            It presumably carries per-step state that ``reset_net`` clears
            (TODO confirm against utils).
        net_arch: currently unused; kept so existing callers keep working.
        presim_len: number of leading steps whose outputs are ignored.
        sim_len: number of steps whose outputs are accumulated and scored.
        device: CUDA device forwarded to ``.cuda()``.

    Returns:
        Tensor of shape ``(sim_len,)`` on ``device``. Entry ``k`` is the
        number of correctly classified samples over the whole loader when
        predicting from outputs summed over steps
        ``presim_len .. presim_len + k``.
    """
    new_tot = torch.zeros(sim_len).cuda(device)
    model = model.cuda(device)
    model.eval()
    with torch.no_grad():
        for img, label in tqdm(test_dataloader):
            img = img.cuda(device)
            label = label.cuda(device)
            new_spikes = 0
            for t in range(presim_len + sim_len):
                out = model(img)
                if t < presim_len:
                    # Warm-up step: the model still runs (its state evolves),
                    # but the output is not counted.
                    continue
                new_spikes += out
                pred = new_spikes.max(1)[1]
                new_tot[t - presim_len] += (label == pred).sum().item()
            # Clear model state before the next, independent batch
            # (mirrors eval_snn).
            reset_net(model)
    return new_tot
def _train_one_epoch(train_dataloader, model, optimizer, loss_fn, device):
    """Run one optimisation pass over ``train_dataloader``; return the mean per-sample loss."""
    model.train()
    loss_sum = 0.0
    num_samples = 0
    for img, label in tqdm(train_dataloader):
        img = img.cuda(device)
        label = label.cuda(device)
        optimizer.zero_grad()
        loss = loss_fn(model(img), label)
        loss.backward()
        optimizer.step()
        batch_size = len(img)
        # loss_fn is expected to return a batch mean (nn.CrossEntropyLoss's
        # default reduction). Re-weight by batch size so a smaller final batch
        # is not over-counted and the result is a true per-sample mean.
        loss_sum += loss.item() * batch_size
        num_samples += batch_size
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return loss_sum / max(num_samples, 1)


def train_ann(train_dataloader, test_dataloader, model, epochs, lr, wd, device, save_name):
    """Train an ANN classifier and checkpoint it whenever test accuracy improves.

    Setup: SGD (momentum 0.9, weight decay ``wd``) with a cosine-annealed
    learning rate over ``epochs`` epochs, and cross-entropy loss.

    After each epoch the model is scored with ``eval_ann``. When the score
    strictly exceeds the best so far, ``model.state_dict()`` is written to
    ``save_name``.

    Args:
        train_dataloader: iterable of ``(img, label)`` training batches.
        test_dataloader: iterable of ``(img, label)`` evaluation batches.
        model: the network to train; it is moved to ``device``.
        epochs: number of training epochs (also the cosine schedule length).
        lr: initial learning rate.
        wd: weight decay.
        device: CUDA device forwarded to ``.cuda()``.
        save_name: path the best ``state_dict`` is saved to.

    Returns:
        The model after the final epoch. This is not necessarily the best
        checkpoint; load ``save_name`` for that.
    """
    model = model.cuda(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()
    best_acc = 0
    for epoch in range(epochs):
        train_loss = _train_one_epoch(train_dataloader, model, optimizer, loss_fn, device)
        acc = eval_ann(test_dataloader, model, device)
        print(f"ANNs training Epoch {epoch}: Train_loss: {train_loss} Acc: {acc}")
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), save_name)
        scheduler.step()
    return model
def eval_ann(test_dataloader, model, device):
    """Count how many test samples the ANN classifies correctly.

    The model is switched to eval mode and moved to ``device``. Then, with
    gradients disabled, each batch's argmax prediction is compared with its
    labels.

    Returns:
        int: total number of correct predictions (a count, not a fraction).
    """
    model.eval()
    model.cuda(device)
    correct = 0
    with torch.no_grad():
        for images, targets in tqdm(test_dataloader):
            logits = model(images.cuda(device))
            predicted = logits.max(dim=1).indices
            hits = predicted == targets.cuda(device)
            correct += hits.sum().item()
    return correct
def eval_snn(test_dataloader, model, sim_len, device):
    """Score an SNN's accumulated output after every simulation step.

    Each batch is fed to ``model`` ``sim_len`` times. Outputs are summed
    across steps. After step ``t``, the argmax of the running sum is compared
    with the labels. ``reset_net(model)`` is called once each batch finishes.

    Returns:
        Tensor of shape ``(sim_len,)`` on ``device``. Entry ``t`` is the
        number of correctly classified samples over the whole loader when
        predicting from outputs summed over steps ``0..t``.
    """
    model = model.cuda(device)
    model.eval()
    correct_per_step = torch.zeros(sim_len).cuda(device)
    with torch.no_grad():
        for images, targets in tqdm(test_dataloader):
            images = images.cuda(device)
            targets = targets.cuda(device)
            accumulated = 0
            for step in range(sim_len):
                accumulated += model(images)
                guesses = accumulated.max(dim=1).indices
                correct_per_step[step] += (guesses == targets).sum().item()
            reset_net(model)
    return correct_per_step