Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to torch1.6 #558

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions models/hub/yolov3-spp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,25 @@ backbone:

# YOLOv3-SPP head
head:
[[-1, 1, Bottleneck, [1024, False]], # 11
[[-1, 1, Bottleneck, [1024, False]],
[-1, 1, SPP, [512, [5, 9, 13]]],
[-1, 1, Conv, [1024, 3, 1]],
[-1, 1, Conv, [512, 1, 1]],
[-1, 1, Conv, [1024, 3, 1]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 16 (P5/32-large)
[-1, 1, Conv, [1024, 3, 1]], # 15

[-3, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 8], 1, Concat, [1]], # cat backbone P4
[-1, 1, Bottleneck, [512, False]],
[-1, 1, Bottleneck, [512, False]],
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, Conv, [512, 3, 1]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 24 (P4/16-medium)
[-1, 1, Conv, [512, 3, 1]], # 22

[-3, 1, Conv, [128, 1, 1]],
[-1, 1, Conv, [128, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P3
[-1, 1, Bottleneck, [256, False]],
[-1, 2, Bottleneck, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 30 (P3/8-small)
[-1, 2, Bottleneck, [256, False]], # 27

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[27, 22, 15], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
15 changes: 6 additions & 9 deletions models/hub/yolov5-fpn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,17 @@ backbone:

# YOLOv5 FPN head
head:
[[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 11 (P5/32-large)
[[-1, 3, BottleneckCSP, [1024, False]], # 10

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 1, Conv, [512, 1, 1]],
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 16 (P4/16-medium)
[-1, 3, BottleneckCSP, [512, False]], # 14

[-2, 1, nn.Upsample, [None, 2, 'nearest']],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 1, Conv, [256, 1, 1]],
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 21 (P3/8-small)
[-1, 3, BottleneckCSP, [256, False]], # 18

[[], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
[[18, 14, 10], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
24 changes: 10 additions & 14 deletions models/hub/yolov5-panet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,28 @@ backbone:
[-1, 9, BottleneckCSP, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, BottleneckCSP, [1024, False]], # 9
]

# YOLOv5 PANet head
head:
[[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, Conv, [512, 1, 1]], # 10

[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, Conv, [256, 1, 1]], # 14
[-1, 3, BottleneckCSP, [512, False]], # 13

[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, BottleneckCSP, [256, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 18 (P3/8-small)
[-1, 3, BottleneckCSP, [256, False]], # 17

[-2, 1, Conv, [256, 3, 2]],
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, BottleneckCSP, [512, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 22 (P4/16-medium)
[-1, 3, BottleneckCSP, [512, False]], # 20

[-2, 1, Conv, [512, 3, 2]],
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, BottleneckCSP, [1024, False]],
[-1, 1, nn.Conv2d, [na * (nc + 5), 1, 1]], # 26 (P5/32-large)
[-1, 3, BottleneckCSP, [1024, False]], # 23

[[], 1, Detect, [nc, anchors]], # Detect(P5, P4, P3)
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
5 changes: 4 additions & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
from copy import deepcopy
from torch.cuda import amp

from models.experimental import *

Expand Down Expand Up @@ -78,7 +79,8 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None): # model, input channels,
torch_utils.initialize_weights(self)
self.info()
print('')


@amp.autocast()
def forward(self, x, augment=False, profile=False):
if augment:
img_size = x.shape[-2:] # height, width
Expand All @@ -99,6 +101,7 @@ def forward(self, x, augment=False, profile=False):
else:
return self.forward_once(x, profile) # single-scale inference, train

@amp.autocast()
def forward_once(self, x, profile=False):
y, dt = [], [] # outputs
for m in self.model:
Expand Down
3 changes: 0 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ scipy
tqdm
# pycocotools>=2.0

# Nvidia Apex (optional) for mixed precision training --------------------------
# git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . --user && cd .. && rm -rf apex

# Conda commands (in place of pip) ---------------------------------------------
# conda update -yn base -c defaults conda
# conda install -yc anaconda numpy opencv matplotlib tqdm pillow ipython
Expand Down
49 changes: 22 additions & 27 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@
from utils.datasets import *
from utils.utils import *

mixed_precision = True
try: # Mixed precision training https://github.com/NVIDIA/apex
from apex import amp
except:
print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
mixed_precision = False # not installed

# Hyperparameters
hyp = {'optimizer': 'SGD', # ['adam', 'SGD', None] if none, default is SGD
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
Expand Down Expand Up @@ -154,10 +147,6 @@ def train(hyp, tb_writer, opt, device):

del ckpt

# Mixed precision training https://github.com/NVIDIA/apex
if mixed_precision:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
Expand Down Expand Up @@ -227,6 +216,10 @@ def train(hyp, tb_writer, opt, device):
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % dataloader.num_workers)
print('Starting training for %g epochs...' % epochs)

# Creates a GradScaler once at the beginning of training.
scaler = amp.GradScaler()

# torch.autograd.set_detect_anomaly(True)
for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
model.train()
Expand Down Expand Up @@ -285,34 +278,36 @@ def train(hyp, tb_writer, opt, device):
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

# Forward
pred = model(imgs)

# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
return results
# Mixed precision training
with amp.autocast():
pred = model(imgs)

# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
return results

# Backward
if mixed_precision:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
# Backward passes under autocast are not recommended.
# Backward ops run in the same dtype autocast chose for corresponding forward ops.
scaler.scale(loss).backward()

# Optimize
if ni % accumulate == 0:
optimizer.step()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if ema is not None:
ema.update(model)

# Print
if rank in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
pbar.set_description(s)
Expand Down