Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

PyTorch 1.6.0 update with native AMP #573

Merged
merged 7 commits into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 36 additions & 44 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

Expand All @@ -14,13 +15,6 @@
from utils.datasets import *
from utils.utils import *

mixed_precision = True
try: # Mixed precision training https://github.com/NVIDIA/apex
from apex import amp
except:
print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
mixed_precision = False # not installed

# Hyperparameters
hyp = {'optimizer': 'SGD', # ['adam', 'SGD', None] if none, default is SGD
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
Expand Down Expand Up @@ -63,6 +57,7 @@ def train(hyp, tb_writer, opt, device):
yaml.dump(vars(opt), f, sort_keys=False)

# Configure
cuda = device.type != 'cpu'
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict
Expand Down Expand Up @@ -113,7 +108,7 @@ def train(hyp, tb_writer, opt, device):
optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
del pg0, pg1, pg2

# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
Expand Down Expand Up @@ -160,24 +155,20 @@ def train(hyp, tb_writer, opt, device):

del ckpt

# Mixed precision training https://github.com/NVIDIA/apex
if mixed_precision:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

# DP mode
if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
if cuda and rank == -1 and torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)

# SyncBatchNorm
if opt.sync_bn and device.type != 'cpu' and rank != -1:
if opt.sync_bn and cuda and rank != -1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
print('Using SyncBatchNorm()')

# Exponential moving average
ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None

# DDP mode
if device.type != 'cpu' and rank != -1:
if cuda and rank != -1:
model = DDP(model, device_ids=[rank], output_device=rank)

# Trainloader
Expand Down Expand Up @@ -223,6 +214,7 @@ def train(hyp, tb_writer, opt, device):
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
if rank in [0, -1]:
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % dataloader.num_workers)
Expand All @@ -232,15 +224,14 @@ def train(hyp, tb_writer, opt, device):
model.train()

# Update image weights (optional)
# When in DDP mode, the generated indices will be broadcast to synchronize the dataset.
if dataset.image_weights:
# Generate indices.
# Generate indices
if rank in [-1, 0]:
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights,
k=dataset.n) # rand weighted idx
# Broadcast.
# Broadcast if DDP
if rank != -1:
indices = torch.zeros([dataset.n], dtype=torch.int)
if rank == 0:
Expand All @@ -263,7 +254,7 @@ def train(hyp, tb_writer, opt, device):
optimizer.zero_grad()
for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
ni = i + nb * epoch # number integrated batches (since train start)
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0
imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0

# Warmup
if ni <= nw:
Expand All @@ -284,35 +275,34 @@ def train(hyp, tb_writer, opt, device):
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

# Forward
pred = model(imgs)
# Autocast
with amp.autocast():
# Forward
pred = model(imgs)

# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
return results
# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
# if not torch.isfinite(loss):
# print('WARNING: non-finite loss, ending training ', loss_items)
# return results

# Backward
if mixed_precision:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
scaler.scale(loss).backward()

# Optimize
if ni % accumulate == 0:
optimizer.step()
scaler.step(optimizer) # optimizer.step
scaler.update()
optimizer.zero_grad()
if ema is not None:
ema.update(model)

# Print
if rank in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
pbar.set_description(s)
Expand All @@ -330,7 +320,7 @@ def train(hyp, tb_writer, opt, device):
# Scheduler
scheduler.step()

# Only the first process in DDP mode is allowed to log or save checkpoints.
# DDP process 0 or single-GPU
if rank in [-1, 0]:
# mAP
if ema is not None:
Expand Down Expand Up @@ -377,7 +367,7 @@ def train(hyp, tb_writer, opt, device):

# Save last, best and delete
torch.save(ckpt, last)
if best_fitness == fi:
if best_fitness == fi:
torch.save(ckpt, best)
del ckpt
# end epoch ----------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -429,10 +419,12 @@ def train(hyp, tb_writer, opt, device):
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
opt = parser.parse_args()

# Resume
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
print(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights

if opt.local_rank in [-1, 0]:
check_git_status()
opt.cfg = check_file(opt.cfg) # check file
Expand All @@ -442,21 +434,20 @@ def train(hyp, tb_writer, opt, device):
with open(opt.hyp) as f:
hyp.update(yaml.load(f, Loader=yaml.FullLoader)) # update hyps
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
device = torch_utils.select_device(opt.device, batch_size=opt.batch_size)
opt.total_batch_size = opt.batch_size
opt.world_size = 1
if device.type == 'cpu':
mixed_precision = False
elif opt.local_rank != -1:
# DDP mode

# DDP mode
if opt.local_rank != -1:
assert torch.cuda.device_count() > opt.local_rank
torch.cuda.set_device(opt.local_rank)
device = torch.device("cuda", opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://') # distributed backend

opt.world_size = dist.get_world_size()
assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
opt.batch_size = opt.total_batch_size // opt.world_size

print(opt)

# Train
Expand All @@ -466,11 +457,12 @@ def train(hyp, tb_writer, opt, device):
tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
else:
tb_writer = None

train(hyp, tb_writer, opt, device)

# Evolve hyperparameters (optional)
else:
assert opt.local_rank == -1, "DDP mode currently not implemented for Evolve!"
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'

tb_writer = None
opt.notest, opt.nosave = True, True # only test/save final epoch
Expand Down
4 changes: 2 additions & 2 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def init_seeds(seed=0):
cudnn.benchmark = True


def select_device(device='', apex=False, batch_size=None):
def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3'
cpu_request = device.lower() == 'cpu'
if device and not cpu_request: # if device requested other than 'cpu'
Expand All @@ -36,7 +36,7 @@ def select_device(device='', apex=False, batch_size=None):
if ng > 1 and batch_size: # check that batch_size is compatible with device_count
assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
x = [torch.cuda.get_device_properties(i) for i in range(ng)]
s = 'Using CUDA ' + ('Apex ' if apex else '') # apex for mixed precision https://github.com/NVIDIA/apex
s = 'Using CUDA '
for i in range(0, ng):
if i == 1:
s = ' ' * len(s)
Expand Down