From e1efb7c7fb88713c1d8db21584b32f09152f6283 Mon Sep 17 00:00:00 2001
From: Zhiqiang Wang
Date: Thu, 19 Aug 2021 18:31:58 +0800
Subject: [PATCH] Add coco metrics evaluation CLI (#148)

* init commit
* Move `train_net` to tools
* Make `eval_metric` a CLI tool
* Optimize the evaluation log
* Support evaluating with torchvision's models
* Fix unit tests
---
 test/test_engine.py                   |  12 +--
 tools/eval_metric.py                  | 136 ++++++++++++++++++++++++++
 yolort/train.py => tools/train_net.py |   6 +-
 yolort/data/coco_eval.py              |  14 ++-
 yolort/models/yolo.py                 |   4 +-
 5 files changed, 158 insertions(+), 14 deletions(-)
 create mode 100644 tools/eval_metric.py
 rename yolort/train.py => tools/train_net.py (96%)

diff --git a/test/test_engine.py b/test/test_engine.py
index 2e21fdc3..246a3762 100644
--- a/test/test_engine.py
+++ b/test/test_engine.py
@@ -96,15 +96,15 @@ def test_vanilla_coco_evaluator():
     coco = data_helper.get_coco_api_from_dataset(val_dataloader.dataset)
     coco_evaluator = COCOEvaluator(coco)
     # Load model
-    model = yolov5s(pretrained=True, score_thresh=0.001)
+    model = yolov5s(pretrained=True)
     model.eval()

     for images, targets in val_dataloader:
         preds = model(images)
         coco_evaluator.update(preds, targets)

     results = coco_evaluator.compute()
-    assert results['AP'] > 38.1
-    assert results['AP50'] > 59.9
+    assert results['AP'] > 37.8
+    assert results['AP50'] > 59.6


 def test_test_epoch_end():
@@ -118,15 +118,15 @@ def test_test_epoch_end():
     val_dataloader = data_helper.get_dataloader(data_root=data_path, mode='val')

     # Load model
-    model = yolov5s(pretrained=True, score_thresh=0.001, annotation_path=annotation_file)
+    model = yolov5s(pretrained=True, annotation_path=annotation_file)

     # test step
     trainer = pl.Trainer(max_epochs=1)
     trainer.test(model, test_dataloaders=val_dataloader)
     # test epoch end
     results = model.evaluator.compute()
-    assert results['AP'] > 38.1
-    assert results['AP50'] > 59.9
+    assert results['AP'] > 37.8
+    assert results['AP50'] > 59.6


 def test_predict_with_vanilla_model():
diff --git a/tools/eval_metric.py b/tools/eval_metric.py
new file mode 100644
index 00000000..e8744074
--- /dev/null
+++ b/tools/eval_metric.py
@@ -0,0 +1,136 @@
+from pathlib import Path
+import io
+import contextlib
+import argparse
+import torch
+import torchvision
+
+import yolort
+
+from yolort.data import COCOEvaluator, _helper as data_helper
+from yolort.data.coco import COCODetection
+from yolort.data.transforms import default_val_transforms, collate_fn
+
+
+def get_parser():
+    parser = argparse.ArgumentParser('Evaluation CLI for yolort', add_help=True)
+
+    parser.add_argument('--num_gpus', default=0, type=int, metavar='N',
+                        help='Number of GPUs to use (default: 0)')
+    # Model architecture
+    parser.add_argument('--arch', default='yolov5s',
+                        help='Model architecture to evaluate')
+    parser.add_argument('--num_classes', default=80, type=int,
+                        help='Number of classes')
+    parser.add_argument('--image_size', default=640, type=int,
+                        help='Image size for evaluation (default: 640)')
+    # Dataset Configuration
+    parser.add_argument('--image_path', default='./data-bin/coco128/images/train2017',
+                        help='Root path of the dataset containing images')
+    parser.add_argument('--annotation_path', default=None,
+                        help='Path of the annotation file')
+    parser.add_argument('--eval_type', default='yolov5',
+                        help='Type of category id mapping: yolov5 uses contiguous ids in '
+                             '[1, 80], torchvision uses discrete ids in [1, 90]')
+    parser.add_argument('--batch_size', default=32, type=int,
+                        help='Images per GPU; the total batch size is $NGPU x batch_size')
+    parser.add_argument('--num_workers', default=8, type=int, metavar='N',
+                        help='Number of data loading workers (default: 8)')
+
+    parser.add_argument('--output_dir', default='.',
+                        help='Path where to save outputs')
+    return parser
+
+
+def prepare_data(image_root, annotation_file, batch_size, num_workers):
+    """
+    Set up the COCO dataset and dataloader for validation.
+    """
+    # Define the dataloader
+    data_set = COCODetection(image_root, annotation_file, default_val_transforms())
+
+    # We adopt a sequential sampler so that the evaluation is reproducible
+    sampler = torch.utils.data.SequentialSampler(data_set)
+
+    data_loader = torch.utils.data.DataLoader(
+        data_set,
+        batch_size,
+        sampler=sampler,
+        drop_last=False,
+        collate_fn=collate_fn,
+        num_workers=num_workers,
+    )
+
+    return data_set, data_loader
+
+
+def eval_metric(args):
+    if args.num_gpus == 0:
+        device = torch.device('cpu')
+        print('Set CPU mode.')
+    elif args.num_gpus == 1:
+        device = torch.device('cuda')
+        print('Set GPU mode.')
+    else:
+        raise NotImplementedError('Multi-GPU mode is currently not supported')
+
+    # Prepare the dataset and dataloader for evaluation
+    image_path = Path(args.image_path)
+    annotation_path = Path(args.annotation_path)
+
+    print('Loading annotations into memory...')
+    with contextlib.redirect_stdout(io.StringIO()):
+        data_set, data_loader = prepare_data(
+            image_path,
+            annotation_path,
+            args.batch_size,
+            args.num_workers,
+        )
+
+    coco_gt = data_helper.get_coco_api_from_dataset(data_set)
+    coco_evaluator = COCOEvaluator(coco_gt, eval_type=args.eval_type)
+
+    # Model Definition and Initialization
+    if args.eval_type == 'yolov5':
+        model = yolort.models.__dict__[args.arch](
+            pretrained=True,
+            num_classes=args.num_classes,
+        )
+    elif args.eval_type == 'torchvision':
+        model = torchvision.models.detection.__dict__[args.arch](
+            pretrained=True,
+            num_classes=args.num_classes,
+        )
+    else:
+        raise NotImplementedError(f'Eval type {args.eval_type} is currently not supported')
+
+    model = model.eval()
+    model = model.to(device)
+
+    # COCO evaluation
+    print('Computing the mAP...')
+    with torch.no_grad():
+        for images, targets in data_loader:
+            images = [image.to(device) for image in images]
+            preds = model(images)
+            coco_evaluator.update(preds, targets)
+
+    results = coco_evaluator.compute()
+
+    # Format the results
+    # coco_evaluator.derive_coco_results()
+
+    # mAP results
+    print(f"The evaluated mAP 0.5:0.95 is {results['AP']:0.3f}, "
+          f"and mAP 0.5 is {results['AP50']:0.3f}.")
+
+
+def cli_main():
+    parser = get_parser()
+    args = parser.parse_args()
+    print(f'Command Line Args: {args}')
+    eval_metric(args)
+
+
+if __name__ == "__main__":
+    cli_main()
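For reference, the new tool is invoked as `python tools/eval_metric.py --num_gpus 1 --batch_size 32 ...`, or can be driven programmatically through `get_parser()`. A minimal sketch (not part of the patch), assuming the script runs with `tools/` on the import path; the dataset paths below are placeholders for a local COCO-style layout:

```python
# Minimal sketch of driving the new evaluation CLI programmatically.
# Assumptions: tools/ is on the import path; the image and annotation
# paths are placeholders for your local dataset layout.
from eval_metric import eval_metric, get_parser

args = get_parser().parse_args([
    '--num_gpus', '0',  # CPU mode; pass '1' to evaluate on a single GPU
    '--arch', 'yolov5s',
    '--eval_type', 'yolov5',
    '--image_path', './data-bin/coco128/images/train2017',
    '--annotation_path', './data-bin/coco128/annotations/instances_train2017.json',
    '--batch_size', '32',
])
eval_metric(args)
```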
diff --git a/yolort/train.py b/tools/train_net.py
similarity index 96%
rename from yolort/train.py
rename to tools/train_net.py
index 6d78bd15..49a3be7f 100644
--- a/yolort/train.py
+++ b/tools/train_net.py
@@ -3,12 +3,12 @@

 import pytorch_lightning as pl

-from .data import COCODetectionDataModule
-from . import models
+from yolort import models
+from yolort.data import COCODetectionDataModule


 def get_args_parser():
-    parser = argparse.ArgumentParser('You only look once detector', add_help=False)
+    parser = argparse.ArgumentParser('You only look once detector', add_help=True)

     parser.add_argument('--arch', default='yolov5s',
                         help='model structure to train')
diff --git a/yolort/data/coco_eval.py b/yolort/data/coco_eval.py
index 95f5f875..fb751439 100644
--- a/yolort/data/coco_eval.py
+++ b/yolort/data/coco_eval.py
@@ -37,6 +37,7 @@ def __init__(
         self,
         coco_gt: Union[str, PosixPath, COCO],
         iou_type: str = 'bbox',
+        eval_type: str = 'yolov5',
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
@@ -49,6 +50,8 @@
                 - PosixPath: a json file in COCO's result format, and is wrapped with Path.
                 - COCO: COCO api
             iou_type (str): iou type to compute.
+            eval_type (str): The categories predicted by yolov5 are contiguous in [1, 80],
+                which differs from torchvision's discrete 91 categories. Default: yolov5.
         """
         super().__init__(
             compute_on_step=compute_on_step,
@@ -63,10 +66,15 @@
         elif isinstance(coco_gt, COCO):
             coco_gt = copy.deepcopy(coco_gt)
         else:
-            raise NotImplementedError(f"Currently not support type {type(coco_gt)}")
+            raise NotImplementedError(f'Type {type(coco_gt)} is currently not supported')

         self.coco_gt = coco_gt
-        self.contiguous_to_json_category = coco_gt.getCatIds()
+        if eval_type == 'yolov5':
+            self.category_id_maps = coco_gt.getCatIds()
+        elif eval_type == 'torchvision':
+            self.category_id_maps = list(range(coco_gt.getCatIds()[-1] + 1))
+        else:
+            raise NotImplementedError(f'Eval type {eval_type} is currently not supported')

         self.iou_type = iou_type
         self.coco_eval = COCOeval(coco_gt, iouType=iou_type)
@@ -198,7 +206,7 @@ def prepare_for_coco_detection(self, predictions):
             [
                 {
                     "image_id": original_id,
-                    "category_id": self.category_id_maps[labels[k]],
+                    "category_id": self.category_id_maps[labels[k]],
                     "bbox": box,
                     "score": scores[k],
                 }
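To make the new `eval_type` switch concrete, here is a small sketch (not part of the patch) of the two lookup tables `COCOEvaluator.__init__` now builds; the annotation file name is a placeholder:

```python
# Sketch of the two category-id mappings behind eval_type. yolov5 heads emit
# contiguous labels that index into the 80-entry list from getCatIds();
# torchvision detectors already emit sparse COCO ids in [1, 90], so an
# identity table over [0, 90] is enough.
from pycocotools.coco import COCO

coco_gt = COCO('instances_val2017.json')  # placeholder annotation file

yolov5_map = coco_gt.getCatIds()                            # 80 sparse ids, e.g. [1, 2, ..., 90]
torchvision_map = list(range(coco_gt.getCatIds()[-1] + 1))  # identity over [0, 90]

print(yolov5_map[0])       # -> 1, the COCO category id for 'person'
print(torchvision_map[1])  # -> 1, torchvision ids pass through unchanged
```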
diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py
index 56716cf1..e5410042 100644
--- a/yolort/models/yolo.py
+++ b/yolort/models/yolo.py
@@ -60,8 +60,8 @@ def __init__(
         iou_thresh: float = 0.5,
         criterion: Optional[Callable[..., Dict[str, Tensor]]] = None,
         # Post Process parameter
-        score_thresh: float = 0.05,
-        nms_thresh: float = 0.5,
+        score_thresh: float = 0.005,
+        nms_thresh: float = 0.45,
         detections_per_img: int = 300,
         post_process: Optional[nn.Module] = None,
     ):
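The lowered `score_thresh` default (0.05 to 0.005) keeps many more low-confidence detections, which is what COCO mAP rewards; the tests above now rely on this default instead of passing `score_thresh=0.001` explicitly, hence the slightly lower AP floors. A rough sketch of overriding it for deployment, where the tests show the `yolov5s` factory forwards `score_thresh` (0.25 below is only an illustrative value):

```python
# Rough sketch: keep the new low default for mAP evaluation, but raise the
# score threshold for deployment, where fewer candidate boxes are wanted.
from yolort.models import yolov5s

eval_model = yolov5s(pretrained=True)  # new defaults: score_thresh=0.005, nms_thresh=0.45
eval_model.eval()

infer_model = yolov5s(pretrained=True, score_thresh=0.25)  # fewer, higher-confidence boxes
infer_model.eval()
```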