-
Notifications
You must be signed in to change notification settings - Fork 0
/
start.py
35 lines (30 loc) · 1.14 KB
/
start.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import time
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from model.sbnet import NMSLBlock,SBNet
import argparse
# Parsed command-line options (argparse namespace). Populated by the
# `__main__` entry point at the bottom of the file before main() is called.
args = None


def main(opts=None):
    """Run a time-bounded synthetic training loop of SBNet on the GPU.

    Builds ``SBNet(bnum, inpf)`` on CUDA and trains it with SGD
    (lr=0.001, momentum=0.9) and cross-entropy loss on freshly generated
    random inputs of shape (batch, 3, 100, 100) and random integer labels,
    until ``time`` seconds of wall-clock time have elapsed.

    Args:
        opts: namespace exposing ``bnum``, ``inpf``, ``time`` and ``batch``.
            Defaults to the module-level ``args`` so the original
            zero-argument call ``main()`` keeps working.

    Returns:
        The number of optimizer steps that were taken.
    """
    if opts is None:
        opts = args
    model = SBNet(opts.bnum, opts.inpf).cuda()
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    steps = 0
    t0 = time.perf_counter()
    while time.perf_counter() - t0 < opts.time:
        # Generate the batch directly on the GPU: avoids a host allocation
        # and a blocking host-to-device copy on every iteration.
        inp = torch.rand(opts.batch, 3, 100, 100, device='cuda')
        # `high` is exclusive, so labels lie in 0..8.
        # NOTE(review): presumably meant to stay below SBNet's number of
        # output classes — confirm against model.sbnet.
        tuth = torch.randint(low=0, high=9, size=(opts.batch,), device='cuda')
        optimizer.zero_grad()
        x = model(inp)
        loss = criterion(x, tuth)
        loss.backward()
        optimizer.step()
        steps += 1
    # CUDA kernels are launched asynchronously; wait for all queued work so
    # it has finished by the time main() returns.
    torch.cuda.synchronize()
    return steps
if __name__=='__main__':
parser = argparse.ArgumentParser(description='gogogo')
parser.add_argument('--bnum',default=10, type=int, help='模型的block数量')
parser.add_argument('--inpf',default=1000, type=int, help='参与计算的feature通道数')
parser.add_argument('--time',default=20, type=int, help='程序运行时间(s)')
parser.add_argument('--batch', default=8, type=int, help='生成的input_batch')
args = parser.parse_args()
main()