Refactor dataloader getter in engine UnitTest
zhiqwang committed Apr 13, 2021
1 parent f122aa7 commit 6e325a8
Showing 3 changed files with 61 additions and 42 deletions.
46 changes: 45 additions & 1 deletion test/dataset_utils.py
@@ -1,9 +1,19 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import random
from pathlib import Path

import torch
from torch.utils.data import Dataset

from torchvision import ops

from yolort.data.coco import CocoDetection
from yolort.data.transforms import (
collate_fn,
default_train_transforms,
default_val_transforms,
)
from yolort.utils import prepare_coco128


class DummyCOCODetectionDataset(Dataset):
"""
@@ -72,3 +82,37 @@ def __getitem__(self, idx: int):
boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
image_id = torch.tensor([idx])
return img, {"image_id": image_id, "boxes": boxes, "labels": labels}


def _get_data_loader(mode: str = 'train', batch_size: int = 4):
# Prepare the datasets for training
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
coco128_path = data_path / coco128_dirname
image_root = coco128_path / 'images' / 'train2017'
annotation_file = coco128_path / 'annotations' / 'instances_train2017.json'

if not annotation_file.is_file():
prepare_coco128(data_path, dirname=coco128_dirname)

if mode == 'train':
dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
elif mode == 'val':
dataset = CocoDetection(image_root, annotation_file, default_val_transforms())
else:
raise NotImplementedError(f"Currently does not support {mode} mode")

# Use a sequential sampler so that the results are reproducible across runs
sampler = torch.utils.data.SequentialSampler(dataset)

loader = torch.utils.data.DataLoader(
dataset,
batch_size,
sampler=sampler,
drop_last=False,
collate_fn=collate_fn,
num_workers=0,
)

return loader
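
A minimal usage sketch (not part of the commit) of the helper added above: it mirrors how the refactored tests consume _get_data_loader, assuming collate_fn groups the images and targets of a batch into parallel sequences, as in the tests below.

# Hypothetical usage example, not part of this commit
data_loader = _get_data_loader(mode='train', batch_size=4)
images, targets = next(iter(data_loader))  # one batch: 4 images and 4 target dicts
assert len(images) == 4 and len(targets) == 4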
15 changes: 7 additions & 8 deletions test/test_data_pipeline.py
@@ -7,16 +7,16 @@

from yolort.data import COCOEvaluator, DetectionDataModule
from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms
from yolort.data.transforms import default_train_transforms
from yolort.utils import prepare_coco128

from .dataset_utils import DummyCOCODetectionDataset
from .dataset_utils import DummyCOCODetectionDataset, _get_data_loader

from typing import Dict


class DataPipelineTester(unittest.TestCase):
def test_vanilla_dataloader(self):
def test_vanilla_dataset(self):
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
@@ -33,11 +33,9 @@ def test_vanilla_dataloader(self):
self.assertIsInstance(image, Tensor)
self.assertIsInstance(target, Dict)

batch_size = 4
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
def test_vanilla_dataloader(self):
batch_size = 8
data_loader = _get_data_loader(mode='train', batch_size=batch_size)
# Test the dataloader
images, targets = next(iter(data_loader))

@@ -76,6 +74,7 @@ def test_prepare_coco128(self):
annotation_file = data_path / coco128_dirname / 'annotations' / 'instances_train2017.json'
self.assertTrue(annotation_file.is_file())

@unittest.skip("Not yet fully implemented")
def test_coco_evaluator(self):
coco_evaluator = COCOEvaluator()
pass
42 changes: 9 additions & 33 deletions test/test_engine.py
@@ -1,5 +1,4 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import unittest

import torch
@@ -10,16 +9,12 @@
import pytorch_lightning as pl

from yolort.data import DetectionDataModule
from yolort.data.coco import CocoDetection
from yolort.data.transforms import collate_fn, default_train_transforms

from yolort.models.yolo import yolov5_darknet_pan_s_r31
from yolort.models.transform import nested_tensor_from_tensor_list
from yolort.models import yolov5s

from yolort.utils import prepare_coco128

from .dataset_utils import DummyCOCODetectionDataset
from .dataset_utils import DummyCOCODetectionDataset, _get_data_loader

from typing import Dict

@@ -65,26 +60,9 @@ def test_train_with_vanilla_module(self):
# Define the device
device = torch.device('cpu')

# Prepare the datasets for training
# Acquire the images and labels from the coco128 dataset
data_path = Path('data-bin')
coco128_dirname = 'coco128'
coco128_path = data_path / coco128_dirname
image_root = coco128_path / 'images' / 'train2017'
annotation_file = coco128_path / 'annotations' / 'instances_train2017.json'

if not annotation_file.is_file():
prepare_coco128(data_path, dirname=coco128_dirname)

batch_size = 4

dataset = CocoDetection(image_root, annotation_file, default_train_transforms())
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
data_loader = torch.utils.data.DataLoader(
dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0)
train_dataloader = _get_data_loader(mode='train')
# Sample a pair of images/targets
images, targets = next(iter(data_loader))
images, targets = next(iter(train_dataloader))
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

@@ -109,19 +87,17 @@ def test_train_one_epoch(self):
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, data_module)

def test_test_dataloaders(self):
# Config dataset
num_samples = 128
batch_size = 4
# Setup the DataModule
train_dataset = DummyCOCODetectionDataset(num_samples=num_samples)
data_module = DetectionDataModule(train_dataset, batch_size=batch_size)
@unittest.skip("Not yet fully implemented")
def test_test_with_dataloader(self):
# Get dataloader to test
val_dataloader = _get_data_loader(mode='val')

# Load model
model = yolov5s(pretrained=True)
model.eval()
# Trainer
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, test_dataloaders=data_module.val_dataloader(batch_size=batch_size))
trainer.test(model, test_dataloaders=val_dataloader)

def test_predict_with_vanilla_model(self):
# Set image inputs
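
As a quick way to exercise the refactored tests (an assumption about local invocation, not stated in the commit), the suite can be run with pytest, which also collects unittest.TestCase classes:

python -m pytest test/test_data_pipeline.py test/test_engine.py -v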