Skip to content

Commit

Permalink
Add coco metrics evaluation CLI (#148)
Browse files Browse the repository at this point in the history
* init commit

* Move `train_net` to tools

* Make eval_metric as CLI tools

* Optimize the evaluation log

* Support evaluating with torchvision's models

* Fixing unit-test
  • Loading branch information
zhiqwang authored Aug 19, 2021
1 parent f34194c commit e1efb7c
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 14 deletions.
12 changes: 6 additions & 6 deletions test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ def test_vanilla_coco_evaluator():
coco = data_helper.get_coco_api_from_dataset(val_dataloader.dataset)
coco_evaluator = COCOEvaluator(coco)
# Load model
model = yolov5s(pretrained=True, score_thresh=0.001)
model = yolov5s(pretrained=True)
model.eval()
for images, targets in val_dataloader:
preds = model(images)
coco_evaluator.update(preds, targets)

results = coco_evaluator.compute()
assert results['AP'] > 38.1
assert results['AP50'] > 59.9
assert results['AP'] > 37.8
assert results['AP50'] > 59.6


def test_test_epoch_end():
Expand All @@ -118,15 +118,15 @@ def test_test_epoch_end():
val_dataloader = data_helper.get_dataloader(data_root=data_path, mode='val')

# Load model
model = yolov5s(pretrained=True, score_thresh=0.001, annotation_path=annotation_file)
model = yolov5s(pretrained=True, annotation_path=annotation_file)

# test step
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, test_dataloaders=val_dataloader)
# test epoch end
results = model.evaluator.compute()
assert results['AP'] > 38.1
assert results['AP50'] > 59.9
assert results['AP'] > 37.8
assert results['AP50'] > 59.6


def test_predict_with_vanilla_model():
Expand Down
136 changes: 136 additions & 0 deletions tools/eval_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from pathlib import Path
import io
import contextlib
import argparse
import torch
import torchvision

import yolort

from yolort.data import COCOEvaluator, _helper as data_helper
from yolort.data.coco import COCODetection
from yolort.data.transforms import default_val_transforms, collate_fn


def get_parser():
parser = argparse.ArgumentParser('Evaluation CLI for yolort', add_help=True)

parser.add_argument('--num_gpus', default=0, type=int, metavar='N',
help='Number of gpu utilizing (default: 0)')
# Model architecture
parser.add_argument('--arch', default='yolov5s',
help='Model structure to train')
parser.add_argument('--num_classes', default=80, type=int,
help='Number of classes')
parser.add_argument('--image_size', default=640, type=int,
help='Image size for evaluation (default: 640)')
# Dataset Configuration
parser.add_argument('--image_path', default='./data-bin/coco128/images/train2017',
help='Root path of the dataset containing images')
parser.add_argument('--annotation_path', default=None,
help='Path of the annotation file')
parser.add_argument('--eval_type', default='yolov5',
help='Type of category id maps, yolov5 use continuous [1, 80], '
'torchvision use discrete ids range in [1, 90]')
parser.add_argument('--batch_size', default=32, type=int,
help='Images per gpu, the total batch size is $NGPU x batch_size')
parser.add_argument('--num_workers', default=8, type=int, metavar='N',
help='Number of data loading workers (default: 8)')

parser.add_argument('--output_dir', default='.',
help='Path where to save')
return parser


def prepare_data(image_root, annotation_file, batch_size, num_workers):
"""
Setup the coco dataset and dataloader for validation
"""
# Define the dataloader
data_set = COCODetection(image_root, annotation_file, default_val_transforms())

# We adopt the sequential sampler in order to repeat the experiment
sampler = torch.utils.data.SequentialSampler(data_set)

data_loader = torch.utils.data.DataLoader(
data_set,
batch_size,
sampler=sampler,
drop_last=False,
collate_fn=collate_fn,
num_workers=num_workers,
)

return data_set, data_loader


def eval_metric(args):
if args.num_gpus == 0:
device = torch.device('cpu')
print('Set CPU mode.')
elif args.num_gpus == 1:
device = torch.device('cuda')
print('Set GPU mode.')
else:
raise NotImplementedError('Currently not supported multi-GPUs mode')

# Prepare the dataset and dataloader for evaluation
image_path = Path(args.image_path)
annotation_path = Path(args.annotation_path)

print('Loading annotations into memory...')
with contextlib.redirect_stdout(io.StringIO()):
data_set, data_loader = prepare_data(
image_path,
annotation_path,
args.batch_size,
args.num_workers,
)

coco_gt = data_helper.get_coco_api_from_dataset(data_set)
coco_evaluator = COCOEvaluator(coco_gt, eval_type=args.eval_type)

# Model Definition and Initialization
if args.eval_type == 'yolov5':
model = yolort.models.__dict__[args.arch](
pretrained=True,
num_classes=args.num_classes,
)
elif args.eval_type == 'torchvision':
model = torchvision.models.detection.__dict__[args.arch](
pretrained=True,
num_classes=args.num_classes,
)
else:
raise NotImplementedError(f'Currently not supports eval type: {args.eval_type}')

model = model.eval()
model = model.to(device)

# COCO evaluation
print('Computing the mAP...')
with torch.no_grad():
for images, targets in data_loader:
images = [image.to(device) for image in images]
preds = model(images)
coco_evaluator.update(preds, targets)

results = coco_evaluator.compute()

# Format the results
# coco_evaluator.derive_coco_results()

# mAP results
print(f"The evaluated mAP 0.5:095 is {results['AP']:0.3f}, "
f"and mAP 0.5 is {results['AP50']:0.3f}.")


def cli_main():
parser = get_parser()
args = parser.parse_args()
print(f'Command Line Args: {args}')
eval_metric(args)


if __name__ == "__main__":
cli_main()
6 changes: 3 additions & 3 deletions yolort/train.py → tools/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import pytorch_lightning as pl

from .data import COCODetectionDataModule
from . import models
from yolort import models
from yolort.data import COCODetectionDataModule


def get_args_parser():
parser = argparse.ArgumentParser('You only look once detector', add_help=False)
parser = argparse.ArgumentParser('You only look once detector', add_help=True)

parser.add_argument('--arch', default='yolov5s',
help='model structure to train')
Expand Down
14 changes: 11 additions & 3 deletions yolort/data/coco_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
self,
coco_gt: Union[str, PosixPath, COCO],
iou_type: str = 'bbox',
eval_type: str = 'yolov5',
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
Expand All @@ -49,6 +50,8 @@ def __init__(
- PosixPath: a json file in COCO's result format, and is wrapped with Path.
- COCO: COCO api
iou_type (str): iou type to compute.
eval_type (str): The categories predicted by yolov5 are continuous [1-80], which is
different from torchvision's discrete 91 categories. Default: yolov5.
"""
super().__init__(
compute_on_step=compute_on_step,
Expand All @@ -63,10 +66,15 @@ def __init__(
elif isinstance(coco_gt, COCO):
coco_gt = copy.deepcopy(coco_gt)
else:
raise NotImplementedError(f"Currently not support type {type(coco_gt)}")
raise NotImplementedError(f'Currently not supports type {type(coco_gt)}')

self.coco_gt = coco_gt
self.contiguous_to_json_category = coco_gt.getCatIds()
if eval_type == 'yolov5':
self.category_id_maps = coco_gt.getCatIds()
elif eval_type == 'torchvision':
self.category_id_maps = list(range(coco_gt.getCatIds()[-1] + 1))
else:
raise NotImplementedError(f'Currently not supports eval type {eval_type}')

self.iou_type = iou_type
self.coco_eval = COCOeval(coco_gt, iouType=iou_type)
Expand Down Expand Up @@ -198,7 +206,7 @@ def prepare_for_coco_detection(self, predictions):
[
{
"image_id": original_id,
"category_id": self.contiguous_to_json_category[labels[k]],
"category_id": self.category_id_maps[labels[k]],
"bbox": box,
"score": scores[k],
}
Expand Down
4 changes: 2 additions & 2 deletions yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __init__(
iou_thresh: float = 0.5,
criterion: Optional[Callable[..., Dict[str, Tensor]]] = None,
# Post Process parameter
score_thresh: float = 0.05,
nms_thresh: float = 0.5,
score_thresh: float = 0.005,
nms_thresh: float = 0.45,
detections_per_img: int = 300,
post_process: Optional[nn.Module] = None,
):
Expand Down

0 comments on commit e1efb7c

Please sign in to comment.