Prefetch #2

Merged: 3 commits, May 12, 2021
24 changes: 13 additions & 11 deletions dataset/collation.py
@@ -4,25 +4,27 @@
 def collate_fn_coco(batch):
     images, annos = tuple(zip(*batch))
     t_images = torch.empty((0, 3, 300, 300))
-    b_bboxes = []
-    b_labels = []
+    num_objs = torch.empty((len(images)), dtype=torch.int64)
+    b_bboxes = torch.empty((0, 4), dtype=torch.float32)
+    b_labels = torch.empty((0), dtype=torch.int64)
     for i, image in enumerate(images):
         r_width = 1 / image.shape[0]
         r_height = 1 / image.shape[1]
         t_image = torch.unsqueeze(image, dim=0)
         t_images = torch.cat((t_images, t_image))
+        num_objs[i] = len(annos[i])
         boxes = torch.empty((len(annos[i]), 4), dtype=torch.float32)
         labels = torch.empty((len(annos[i])), dtype=torch.int64)
-        for num_obj, anno in enumerate(annos[i]):
-            boxes[num_obj][0] = anno['bbox'][0] * r_width
-            boxes[num_obj][1] = anno['bbox'][1] * r_height
-            boxes[num_obj][2] = (anno['bbox'][0] + anno['bbox'][2]) * r_width
-            boxes[num_obj][3] = (anno['bbox'][1] + anno['bbox'][3]) * r_height
-            labels[num_obj] = anno['category_id']
-        b_bboxes.append(boxes)
-        b_labels.append(labels)
+        for obj, anno in enumerate(annos[i]):
+            boxes[obj][0] = anno['bbox'][0] * r_width
+            boxes[obj][1] = anno['bbox'][1] * r_height
+            boxes[obj][2] = (anno['bbox'][0] + anno['bbox'][2]) * r_width
+            boxes[obj][3] = (anno['bbox'][1] + anno['bbox'][3]) * r_height
+            labels[obj] = anno['category_id']
+        b_bboxes = torch.cat((b_bboxes, boxes), axis=0)
+        b_labels = torch.cat((b_labels, labels), axis=0)

-    return t_images, b_bboxes, b_labels
+    return t_images, b_bboxes, b_labels, num_objs

 COLLATE_FN = {'coco': collate_fn_coco}
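
For context, a minimal usage sketch of the new collate output, not part of this PR; the coco_dataset variable and batch size are placeholders. For a batch of B images the function above yields t_images of shape (B, 3, 300, 300), b_bboxes of shape (total_objects, 4), b_labels of shape (total_objects,), and num_objs of shape (B,), where total_objects is the sum of num_objs.

# Hypothetical wiring sketch (names are placeholders, not from this PR).
from torch.utils.data import DataLoader

loader = DataLoader(coco_dataset, batch_size=8, shuffle=True,
                    pin_memory=True, collate_fn=COLLATE_FN['coco'])
t_images, b_bboxes, b_labels, num_objs = next(iter(loader))
# b_bboxes rows line up with b_labels entries; num_objs records how many
# of those rows belong to each image, so the flat tensors can be split apart later.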

12 changes: 9 additions & 3 deletions dataset/obj_data_prefetcher.py
@@ -15,12 +15,14 @@ def __init__(self, loader):

     def preload(self):
         try:
-            self.next_input, self.next_box, self.next_target = next(self.loader)
-            self.next_box = torch.as_tensor(self.next_box, device='cuda')
+            self.next_input, self.next_box, \
+                self.next_target, self.num_obj_in_images \
+                = next(self.loader)
         except StopIteration:
             self.next_input = None
             self.next_box = None
             self.next_target = None
+            self.num_obj_in_images = None
             return
         # if record_stream() doesn't work, another option is to make sure device inputs are created
         # on the main stream.
@@ -33,6 +35,7 @@ def preload(self):
             self.next_input = self.next_input.cuda(non_blocking=True)
             self.next_box = self.next_box.cuda(non_blocking=True)
             self.next_target = self.next_target.cuda(non_blocking=True)
+            self.num_obj_in_images = self.num_obj_in_images.cuda(non_blocking=True)
             # more code for the alternative if record_stream() doesn't work:
             # copy_ will record the use of the pinned source tensor in this side stream.
             # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
@@ -52,11 +55,14 @@ def next(self):
         image = self.next_input
         box = self.next_box
         target = self.next_target
+        num_obj_in_images = self.num_obj_in_images
         if image is not None:
             image.record_stream(torch.cuda.current_stream())
         if box is not None:
             box.record_stream(torch.cuda.current_stream())
         if target is not None:
             target.record_stream(torch.cuda.current_stream())
+        if num_obj_in_images is not None:
+            num_obj_in_images.record_stream(torch.cuda.current_stream())
         self.preload()
-        return image, box, target
+        return image, box, target, num_obj_in_images
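
For reference, a minimal sketch of the consumption pattern the updated next() supports, mirroring the train.py change below; train_loader and run_training_step are placeholders, not part of this PR.

# Hypothetical usage sketch: drain the prefetcher until next() returns None values.
from dataset.obj_data_prefetcher import ObjDataPrefetcher

prefetcher = ObjDataPrefetcher(train_loader)  # train_loader: a placeholder DataLoader
images, boxes, labels, num_objs = prefetcher.next()
while images is not None:
    # all four tensors were already moved to the GPU on the prefetch stream
    run_training_step(images, boxes, labels, num_objs)  # placeholder step function
    images, boxes, labels, num_objs = prefetcher.next()
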
91 changes: 52 additions & 39 deletions train.py
@@ -1,3 +1,6 @@
+"""
+train.py
+"""
 import os
 import argparse

@@ -10,13 +13,15 @@
 from net.ssd import SSD300
 from net.ssd import MultiBoxLoss
 from dataset.obj_dataloader import ObjTorchLoader
+from dataset.obj_data_prefetcher import ObjDataPrefetcher


 DEVICE = None

 def train_for_one_step(model, criterion,
                        optimizer, loss_container,
-                       inputs, boxes, labels):
+                       inputs, b_boxes, b_labels,
+                       num_objs):
     # zero the parameter gradients
     # optimizer.zero_grad()
     for param in model.parameters():
@@ -25,10 +30,13 @@ def train_for_one_step(model, criterion,
     # forward + backward + optimize
     pred_locs, pred_cls_prob = model(inputs.to(DEVICE))

-    num_obj = len(boxes)
-    for n in range(num_obj):
-        boxes[n] = boxes[n].to(DEVICE)
-        labels[n] = labels[n].to(DEVICE)
+    serial = 0
+    boxes = []
+    labels = []
+    for num_obj in num_objs:
+        boxes.append(b_boxes[serial:serial + num_obj].to(DEVICE))
+        labels.append(b_labels[serial:serial + num_obj].to(DEVICE))
+        serial = num_obj
     loss = criterion(pred_locs.to(DEVICE), pred_cls_prob.to(DEVICE),
                      boxes, labels)
     loss.backward()
@@ -39,61 +47,67 @@


 def train_for_one_epoch(model, criterion, optimizer,
-                        dataset, epoch, step_per_epoch,
+                        data_fetcher, epoch, step_per_epoch,
                         writer=None):
     step = 0
     losses = [0.]
     step_to_draw_loss = 10
-    for images, boxes, labels in dataset.train_loader:
+    images, boxes, labels, num_obj_in_images = data_fetcher.next()
+    while images is not None:
         train_for_one_step(model, criterion,
                            optimizer, losses,
-                           images, boxes, labels)
+                           images, boxes, labels,
+                           num_obj_in_images)

         if writer is not None and step % step_to_draw_loss == 0:
             writer.add_scalar("Loss/train", losses[0], step + epoch * step_per_epoch)

+        images, boxes, labels, num_obj_in_images = data_fetcher.next()
         step += 1


 def train_loop(training_setup, epoch):
     model = training_setup['model']
     criterion = training_setup['criterion']
-    optim = training_setup['optimizer']
+    optimizer = training_setup['optimizer']
     dataset = training_setup['dataset']
     writer = training_setup['writer']

     step_per_epoch = len(dataset.train_loader)
+    train_fetcher = ObjDataPrefetcher(dataset.train_loader)
     for e in range(epoch):
         model.train()
-        train_for_one_epoch(model, criterion, optim, dataset, e, step_per_epoch, writer=writer)
+        train_for_one_epoch(model, criterion, optimizer,
+                            train_fetcher, e, step_per_epoch,
+                            writer=writer)

     # eval_result = eval_for_one_epoch(training_setup['dataset'], training_setup['model'])


-def training_setup(args):
-    training_setup = {}
+def setting_up(args):
+    setup = {}
     backbone = IntermediateNetwork('resnet50', [5, 6]).to(DEVICE)
-    training_setup['model'] = SSD300(backbone, args.num_classes).to(DEVICE)
-    training_setup['preprocessor'] = transforms.Compose([transforms.Resize((300, 300)),
-                                                         transforms.ToTensor(),
-                                                         transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                                                              std=[0.229, 0.224, 0.225])])
-    training_setup['dataset'] = ObjTorchLoader(args.dataset_name,
-                                               transform=training_setup['preprocessor'],
-                                               collate_fn_name=args.dataset_name,
-                                               train_batch_size=args.train_batch_size,
-                                               test_batch_size=args.test_batch_size)
-    prior_boxes = training_setup['model'].priors_cxcy
-    training_setup['criterion'] = MultiBoxLoss(prior_boxes)
-    training_setup['optimizer'] = optim.SGD(training_setup['model'].parameters(),
-                                            lr=args.lr)
-    training_setup['writer'] = None
-    training_setup['save_model_path'] = args.save_model_path
-
-    return training_setup
-
-
-def get_argments():
+    setup['model'] = SSD300(backbone, args.num_classes).to(DEVICE)
+    setup['preprocessor'] = transforms.Compose([transforms.Resize((300, 300)),
+                                                transforms.ToTensor(),
+                                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                                                     std=[0.229, 0.224, 0.225])])
+    setup['dataset'] = ObjTorchLoader(args.dataset_name,
+                                      transform=setup['preprocessor'],
+                                      collate_fn_name=args.dataset_name,
+                                      train_batch_size=args.train_batch_size,
+                                      test_batch_size=args.test_batch_size)
+    prior_boxes = setup['model'].priors_cxcy
+    setup['criterion'] = MultiBoxLoss(prior_boxes)
+    setup['optimizer'] = optim.SGD(setup['model'].parameters(),
+                                   lr=args.lr)
+    setup['writer'] = None
+    setup['save_model_path'] = args.save_model_path
+
+    return setup
+
+
+def get_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model_name', default='resnet18')
     parser.add_argument('--dataset_name', default='coco')
@@ -108,18 +122,17 @@ def get_argments():
     return parser.parse_args()


-def train(args):
-    setup = training_setup(args)
-    train_loop(setup, epoch=args.training_epoch)
+def train(arguments):
+    setup = setting_up(arguments)
+    train_loop(setup, epoch=arguments.training_epoch)


 if __name__ == '__main__':
-    args = get_argments()
+    args = get_arguments()
     DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
     os.environ['CUDA_VISIBLE_DEVICES'] = '0'
     torch.backends.cudnn.benchmark = True
     import time
     st = time.time()
     train(args)
-    print(time.time() - st)
-
+    print(time.time() - st)
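
As a side note, a minimal sketch, assuming the collate output above, of regrouping the flat box and label tensors per image with torch.split; the helper name and device argument are illustrative only, and torch.split keeps the running offset cumulative for any batch size.

# Hypothetical helper (not part of this PR): split the flat tensors back into
# per-image lists using the per-image object counts from the collate function.
import torch

def regroup_by_image(b_boxes, b_labels, num_objs, device):
    sizes = num_objs.tolist()  # e.g. [2, 5, 1] objects for a batch of three images
    boxes = [b.to(device) for b in torch.split(b_boxes, sizes, dim=0)]
    labels = [l.to(device) for l in torch.split(b_labels, sizes, dim=0)]
    return boxes, labels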