Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding mixed precision training support #390

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions layers/modules/multibox_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ def forward(self, predictions, targets):
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

# Hard Negative Mining
loss_c[pos] = 0 # filter out pos boxes for now
#loss_c[pos] = 0 # filter out pos boxes for now
#loss_c = loss_c.view(num, -1)
loss_c = loss_c.view(num, -1)
loss_c[pos] = 0 # filter out pos boxes for now
_, loss_idx = loss_c.sort(1, descending=True)
_, idx_rank = loss_idx.sort(1)
num_pos = pos.long().sum(1, keepdim=True)
Expand All @@ -111,7 +113,11 @@ def forward(self, predictions, targets):

# Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N

N = num_pos.data.sum()
#N = num_pos.data.sum()
N = num_pos.data.sum().double()
loss_l = loss_l.double()
loss_c = loss_c.double()

loss_l /= N
loss_c /= N
return loss_l, loss_c
44 changes: 36 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
import torch.utils.data as data
import numpy as np
import argparse

try:
from apex import amp
except ImportError:
amp = None

def str2bool(v):
return v.lower() in ("yes", "true", "t", "1")
Expand Down Expand Up @@ -51,6 +54,15 @@ def str2bool(v):
help='Use visdom for loss visualization')
parser.add_argument('--save_folder', default='weights/',
help='Directory for saving checkpoint models')
# Mixed precision training parameters
parser.add_argument('--apex', default=False, type=str2bool,
help='Use apex for mixed precision training')
parser.add_argument('--apex-opt-level', default='O1', type=str,
help='For apex mixed precision training'
'O0 for FP32 training, O1 for mixed precision training.'
'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
)

args = parser.parse_args()


Expand All @@ -69,6 +81,13 @@ def str2bool(v):


def train():
if args.apex:
if sys.version_info < (3, 0):
raise RuntimeError("Apex currently only supports Python 3. Aborting.")
if amp is None:
raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training.")

if args.dataset == 'COCO':
if args.dataset_root == VOC_ROOT:
if not os.path.exists(COCO_ROOT):
Expand Down Expand Up @@ -121,7 +140,10 @@ def train():
weight_decay=args.weight_decay)
criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
False, args.cuda)

if args.apex:
ssd_net, optimizer = amp.initialize(ssd_net, optimizer,
opt_level=args.apex_opt_level
)
net.train()
# loss counters
loc_loss = 0
Expand Down Expand Up @@ -177,18 +199,24 @@ def train():
optimizer.zero_grad()
loss_l, loss_c = criterion(out, targets)
loss = loss_l + loss_c
loss.backward()
if args.apex:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
t1 = time.time()
loc_loss += loss_l.data[0]
conf_loss += loss_c.data[0]
loc_loss += loss_l.data #[0]
conf_loss += loss_c.data #[0]

if iteration % 10 == 0:
print('timer: %.4f sec.' % (t1 - t0))
print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')
print('timer: %.4f sec. Throughput: %.4f' % (t1 - t0, args.batch_size/(t1-t0)/10))
#print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')
print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data), end=' ')

if args.visdom:
update_vis_plot(iteration, loss_l.data[0], loss_c.data[0],
#update_vis_plot(iteration, loss_l.data[0], loss_c.data[0],
update_vis_plot(iteration, loss_l.data, loss_c.data,
iter_plot, epoch_plot, 'append')

if iteration != 0 and iteration % 5000 == 0:
Expand Down