diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 01fdf800fc9..f76428e79d6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -265,7 +265,7 @@ jobs: - name: Build and install run: pip install -e . - name: Run unittests - run: coverage run --branch --source mmdet -m pytest tests -sv + run: coverage run --branch --source mmdet -m pytest tests - name: Generate coverage report run: | coverage xml diff --git a/configs/openimages/README.md b/configs/openimages/README.md index 6d954217280..fc8c18190b0 100644 --- a/configs/openimages/README.md +++ b/configs/openimages/README.md @@ -1,6 +1,9 @@ # Open Images Dataset - +> [Open Images Dataset](https://arxiv.org/abs/1811.00982) + + + ## Abstract @@ -90,14 +93,14 @@ training/testing by using `tools/misc/get_image_metas.py`. │ │ │ ├── class-descriptions-boxable.csv │ │ │ ├── oidv6-train-annotations-bbox.scv │ │ │ ├── validation-annotations-bbox.csv - │ │ │ ├── validation-annotations-human-imagelabels-boxable.csv # is not necessary + │ │ │ ├── validation-annotations-human-imagelabels-boxable.csv # should set `load_image_level_labels=False` if not use │ │ │ ├── validation-image-metas.pkl # get from script │ │ ├── challenge2019 │ │ │ ├── challenge-2019-train-detection-bbox.txt │ │ │ ├── challenge-2019-validation-detection-bbox.txt │ │ │ ├── class_label_tree.np │ │ │ ├── class_sample_train.pkl - │ │ │ ├── challenge-2019-validation-detection-human-imagelabels.csv # download from official website, not necessary + │ │ │ ├── challenge-2019-validation-detection-human-imagelabels.csv # download from official website │ │ │ ├── challenge-2019-validation-metas.pkl # get from script │ │ ├── OpenImages │ │ │ ├── train # training images @@ -120,6 +123,21 @@ users can should set `data.val.load_image_level_labels=False` and `data.test.loa | Architecture | Backbone | Style | Lr schd | Sampler | Mem (GB) | Inf time (fps) | box AP | Config | Download | |:------------:|:---------:|:-------:|:-------:|:-------:|:--------:|:--------------:|:------:|:------:|:--------:| | Faster R-CNN | R-50 | pytorch | 1x | Group Sampler | 7.7 | - | 51.6 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_20211130_231159-e87ab7ce.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_20211130_231159.log.json) | -| Faster R-CNN (Challenge 2019) | R-50 | pytorch | 1x | Group Sampler | 7.7 | - | 54.5 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20211229_071252-46380cde.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20211229_071252.log.json) | +| Faster R-CNN | R-50 | pytorch | 1x | Class Aware Sampler | 7.7 | - | 60.0 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_20220306_202424-98c630e5.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_20220306_202424-98c630e5.log.json) | +| Faster R-CNN (Challenge 2019) | R-50 | pytorch | 1x | Group Sampler | 7.7 | - | 54.9 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20220114_045100-0e79e5df.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20220114_045100.log.json) | +| Faster R-CNN (Challenge 2019) | R-50 | pytorch | 1x | Class Aware Sampler | 7.1 | - | 65.0 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge_20220221_192021-34c402d9.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge_20220221_192021.log.json) | | Retinanet | R-50 | pytorch | 1x | Group Sampler | 6.6 | - | 61.5 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/retinanet_r50_fpn_32x2_1x_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/retinanet_r50_fpn_32x2_1x_openimages/retinanet_r50_fpn_32x2_1x_openimages_20211223_071954-d2ae5462.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/openimages/retinanet_r50_fpn_32x2_1x_openimages/retinanet_r50_fpn_32x2_1x_openimages_20211223_071954.log.json) | | SSD | VGG16 | pytorch | 36e | Group Sampler | 10.8 | - | 35.4 |[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/openimages/ssd300_32x8_36e_openimages.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/openimages/ssd300_32x8_36e_openimages/ssd300_32x8_36e_openimages_20211224_000232-dce93846.pth) | [log](ttps://download.openmmlab.com/mmdetection/v2.0/openimages/ssd300_32x8_36e_openimages/ssd300_32x8_36e_openimages_20211224_000232.log.json) | + +**Notes:** + +- 'cas' is short for 'Class Aware Sampler' + +### Results of consider image level labels + +| Architecture | Sampler | Consider Image Level Labels | box AP| +|:------------:|:-------:|:---------------------------:|:-----:| +|Faster R-CNN r50 (Challenge 2019)| Group Sampler| w/o | 62.19 | +|Faster R-CNN r50 (Challenge 2019)| Group Sampler| w/ | 54.87 | +|Faster R-CNN r50 (Challenge 2019)| Class Aware Sampler| w/o | 71.77 | +|Faster R-CNN r50 (Challenge 2019)| Class Aware Sampler| w/ | 64.98 | diff --git a/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages.py b/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages.py new file mode 100644 index 00000000000..6056c58a468 --- /dev/null +++ b/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages.py @@ -0,0 +1,4 @@ +_base_ = ['faster_rcnn_r50_fpn_32x2_1x_openimages.py'] + +# Use ClassAwareSampler +data = dict(use_class_aware_sampler=True) diff --git a/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge.py b/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge.py new file mode 100644 index 00000000000..3a502e824f1 --- /dev/null +++ b/configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge.py @@ -0,0 +1,4 @@ +_base_ = ['faster_rcnn_r50_fpn_32x2_1x_openimages_challenge.py'] + +# Use ClassAwareSampler +data = dict(use_class_aware_sampler=True) diff --git a/configs/openimages/metafile.yml b/configs/openimages/metafile.yml index a3e7a8acfbe..9be17261278 100644 --- a/configs/openimages/metafile.yml +++ b/configs/openimages/metafile.yml @@ -1,20 +1,14 @@ -Collections: - - Name: Open Images Dataset - Paper: - URL: https://arxiv.org/abs/1811.00982 - Title: 'The Open Images Dataset V4: Unified image classification, object detection, and visual relationship detection at scale' - README: configs/openimages/README.md - Code: - URL: https://github.com/open-mmlab/mmdetection/blob/v2.20.0/mmdet/datasets/openimages.py#L21 - Version: v2.20.0 - Models: - Name: faster_rcnn_r50_fpn_32x2_1x_openimages - In Collection: Open Images Dataset + In Collection: Faster R-CNN Config: configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages.py Metadata: Training Memory (GB): 7.7 Epochs: 12 + Training Data: Open Images v6 + Training Techniques: + - SGD with Momentum + - Weight Decay Results: - Task: Object Detection Dataset: Open Images v6 @@ -23,11 +17,15 @@ Models: Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_20211130_231159-e87ab7ce.pth - Name: retinanet_r50_fpn_32x2_1x_openimages - In Collection: Open Images Dataset + In Collection: RetinaNet Config: configs/openimages/retinanet_r50_fpn_32x2_1x_openimages.py Metadata: Training Memory (GB): 6.6 Epochs: 12 + Training Data: Open Images v6 + Training Techniques: + - SGD with Momentum + - Weight Decay Results: - Task: Object Detection Dataset: Open Images v6 @@ -36,11 +34,15 @@ Models: Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/retinanet_r50_fpn_32x2_1x_openimages/retinanet_r50_fpn_32x2_1x_openimages_20211223_071954-d2ae5462.pth - Name: ssd300_32x8_36e_openimages - In Collection: Open Images Dataset + In Collection: SSD Config: configs/openimages/ssd300_32x8_36e_openimages Metadata: Training Memory (GB): 10.8 - Epochs: 12 + Epochs: 36 + Training Data: Open Images v6 + Training Techniques: + - SGD with Momentum + - Weight Decay Results: - Task: Object Detection Dataset: Open Images v6 @@ -49,14 +51,52 @@ Models: Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/ssd300_32x8_36e_openimages/ssd300_32x8_36e_openimages_20211224_000232-dce93846.pth - Name: faster_rcnn_r50_fpn_32x2_1x_openimages_challenge - In Collection: Open Images Dataset + In Collection: Faster R-CNN Config: configs/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge.py Metadata: Training Memory (GB): 7.7 Epochs: 12 + Training Data: Open Images Challenge 2019 + Training Techniques: + - SGD with Momentum + - Weight Decay + Results: + - Task: Object Detection + Dataset: Open Images Challenge 2019 + Metrics: + box AP: 54.9 + Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20220114_045100-0e79e5df.pth + + - Name: faster_rcnn_r50_fpn_32x2_cas_1x_openimages + In Collection: Faster R-CNN + Config: configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages.py + Metadata: + Training Memory (GB): 7.7 + Epochs: 12 + Training Data: Open Images Challenge 2019 + Training Techniques: + - SGD with Momentum + - Weight Decay + Results: + - Task: Object Detection + Dataset: Open Images Challenge 2019 + Metrics: + box AP: 60.0 + Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_20220306_202424-98c630e5.pth + + - Name: faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge + In Collection: Faster R-CNN + Config: configs/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge.py + Metadata: + Training Memory (GB): 7.1 + Epochs: 12 + Training Data: Open Images Challenge 2019 + Training Techniques: + - SGD with Momentum + - Weight Decay Results: - Task: Object Detection - Dataset: Open Images Challenge 2019W + Dataset: Open Images Challenge 2019 Metrics: - box AP: 54.5 - Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_1x_openimages_challenge_20211229_071252-46380cde.pth + box AP: 65.0 + Weights: https://download.openmmlab.com/mmdetection/v2.0/openimages/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge/faster_rcnn_r50_fpn_32x2_cas_1x_openimages_challenge_20220221_192021-34c402d9.pth diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index f2c14e9f10f..9b99e08d8f0 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -105,8 +105,9 @@ def train_detector(model, dist=distributed, seed=cfg.seed, runner_type=runner_type, - persistent_workers=cfg.data.get('persistent_workers', False)) - for ds in dataset + persistent_workers=cfg.data.get('persistent_workers', False), + use_class_aware_sampler=cfg.data.get('use_class_aware_sampler', + False)) for ds in dataset ] # put model on gpus diff --git a/mmdet/datasets/builder.py b/mmdet/datasets/builder.py index 30e1ee91a05..ca061c512e7 100644 --- a/mmdet/datasets/builder.py +++ b/mmdet/datasets/builder.py @@ -11,8 +11,8 @@ from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version from torch.utils.data import DataLoader -from .samplers import (DistributedGroupSampler, DistributedSampler, - GroupSampler, InfiniteBatchSampler, +from .samplers import (ClassAwareSampler, DistributedGroupSampler, + DistributedSampler, GroupSampler, InfiniteBatchSampler, InfiniteGroupBatchSampler) if platform.system() != 'Windows': @@ -92,6 +92,7 @@ def build_dataloader(dataset, seed=None, runner_type='EpochBasedRunner', persistent_workers=False, + use_class_aware_sampler=False, **kwargs): """Build PyTorch DataLoader. @@ -114,6 +115,8 @@ def build_dataloader(dataset, the worker processes after a dataset has been consumed once. This allows to maintain the workers `Dataset` instances alive. This argument is only valid when PyTorch>=1.7.0. Default: False. + use_class_aware_sampler (bool): Whether to use `ClassAwareSampler` + during training. Default: False. kwargs: any keyword argument to be used to initialize DataLoader Returns: @@ -152,7 +155,12 @@ def build_dataloader(dataset, batch_size = 1 sampler = None else: - if dist: + if use_class_aware_sampler: + # ClassAwareSampler can be used in both distributed and + # non-distributed training. + sampler = ClassAwareSampler( + dataset, samples_per_gpu, world_size, rank, seed=seed) + elif dist: # DistributedGroupSampler will definitely shuffle the data to # satisfy that images on each GPU are in the same group if shuffle: diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index e449150abce..45dd7869cc0 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -74,8 +74,8 @@ def __init__(self, self.proposal_file = proposal_file self.test_mode = test_mode self.filter_empty_gt = filter_empty_gt - self.CLASSES = self.get_classes(classes) self.file_client = mmcv.FileClient(**file_client_args) + self.CLASSES = self.get_classes(classes) # join paths if data_root is specified if self.data_root is not None: @@ -285,6 +285,28 @@ def get_classes(cls, classes=None): return class_names + def get_label_dict(self): + """Get per-label image list in the current dataset, which will be used + in :class:`ClassAwareSampler`. + + Returns: + dict[list]: A ddict of per-label image list, + the item of the dict indicates a label index, + corresponds to the image index that contains the label. + """ + label_dict = dict() + if self.CLASSES is None: + raise ValueError('CLASSES can not be None') + # sort the label index + for i in range(len(self.CLASSES)): + label_dict[i] = [] + data_infos = [self.get_ann_info(idx) for idx in range(len(self))] + for i, ann in enumerate(data_infos): + labels = np.unique(ann['labels']) + for label in labels: + label_dict[label].append(i) + return label_dict + def format_results(self, results, **kwargs): """Place holder to format result to dataset specific output.""" diff --git a/mmdet/datasets/openimages.py b/mmdet/datasets/openimages.py index d601b482403..099f0df1438 100644 --- a/mmdet/datasets/openimages.py +++ b/mmdet/datasets/openimages.py @@ -22,38 +22,41 @@ class OpenImagesDataset(CustomDataset): """Open Images dataset for detection. Args: - label_file (str): File path of the label description file that - maps the classes names in MID format to their short - descriptions. - image_level_ann_file (str): Image level annotation, which is used - in evaluation. - get_supercategory (bool): Whether to get parent class of the - current class. Default: True. - hierarchy_file (str): The file path of the class hierarchy. - Default: None. - get_metas (bool): Whether to get image metas in testing or - validation time. This should be `True` during evaluation. - Default: True. The OpenImages annotations do not have image - metas (width and height of the image), which will be used - during evaluation. We provide two ways to get image metas - in `OpenImagesDataset`: - - - 1. `load from file`: Load image metas from pkl file, which - is suggested to use. We provided a script to get image metas: - `tools/misc/get_image_metas.py`, which need to run - this script before training/testing. Please refer to - `config/openimages/README.md` for more details. - - - 2. `load from pipeline`, which will get image metas during - test time. However, this may reduce the inference speed, - especially when using distribution. - - load_from_file (bool): Whether to get image metas from pkl file. - meta_file (str): File path to get image metas. - filter_labels (bool): Whether filter unannotated classes. - Default: True. - load_image_level_labels (bool): Whether load and consider image - level labels during evaluation. Default: True. + label_file (str): File path of the label description file that + maps the classes names in MID format to their short + descriptions. + image_level_ann_file (str): Image level annotation, which is used + in evaluation. + get_supercategory (bool): Whether to get parent class of the + current class. Default: True. + hierarchy_file (str): The file path of the class hierarchy. + Default: None. + get_metas (bool): Whether to get image metas in testing or + validation time. This should be `True` during evaluation. + Default: True. The OpenImages annotations do not have image + metas (width and height of the image), which will be used + during evaluation. We provide two ways to get image metas + in `OpenImagesDataset`: + + - 1. `load from file`: Load image metas from pkl file, which + is suggested to use. We provided a script to get image metas: + `tools/misc/get_image_metas.py`, which need to run + this script before training/testing. Please refer to + `config/openimages/README.md` for more details. + + - 2. `load from pipeline`, which will get image metas during + test time. However, this may reduce the inference speed, + especially when using distribution. + + load_from_file (bool): Whether to get image metas from pkl file. + meta_file (str): File path to get image metas. + filter_labels (bool): Whether filter unannotated classes. + Default: True. + load_image_level_labels (bool): Whether to load and consider image + level labels during evaluation. Default: True. + num_sample_class (int): The number of samples taken from each + per-label list, which is used in :class:`ClassAwareSampler`. + Default: 1 """ def __init__(self, @@ -66,6 +69,7 @@ def __init__(self, meta_file='', filter_labels=True, load_image_level_labels=True, + num_sample_class=1, **kwargs): self.cat2label = defaultdict(str) self.index_dict = {} @@ -91,6 +95,7 @@ def __init__(self, self.test_img_metas = [] self.test_img_shapes = [] self.load_from_pipeline = False if load_from_file else True + self.num_sample_class = num_sample_class def get_classes_from_csv(self, label_file): """Get classes name from file. diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py index fc68fc3d22f..eaad75eb124 100644 --- a/mmdet/datasets/pipelines/loading.py +++ b/mmdet/datasets/pipelines/loading.py @@ -252,16 +252,21 @@ def _load_bboxes(self, results): results['gt_bboxes'] = ann_info['bboxes'].copy() if self.denorm_bbox: - h, w = results['img_shape'][:2] bbox_num = results['gt_bboxes'].shape[0] if bbox_num != 0: + h, w = results['img_shape'][:2] results['gt_bboxes'][:, 0::2] *= w results['gt_bboxes'][:, 1::2] *= h - results['gt_bboxes'] = results['gt_bboxes'].astype(np.float32) gt_bboxes_ignore = ann_info.get('bboxes_ignore', None) if gt_bboxes_ignore is not None: results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy() + if self.denorm_bbox: + ignore_bbox_num = results['gt_bboxes_ignore'].shape[0] + if ignore_bbox_num != 0: + h, w = results['img_shape'][:2] + results['gt_bboxes_ignore'][:, 0::2] *= w + results['gt_bboxes_ignore'][:, 1::2] *= h results['bbox_fields'].append('gt_bboxes_ignore') results['bbox_fields'].append('gt_bboxes') diff --git a/mmdet/datasets/samplers/__init__.py b/mmdet/datasets/samplers/__init__.py index c3daa678ee3..a4c7ea135af 100644 --- a/mmdet/datasets/samplers/__init__.py +++ b/mmdet/datasets/samplers/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .class_aware_sampler import ClassAwareSampler from .distributed_sampler import DistributedSampler from .group_sampler import DistributedGroupSampler, GroupSampler from .infinite_sampler import InfiniteBatchSampler, InfiniteGroupBatchSampler __all__ = [ 'DistributedSampler', 'DistributedGroupSampler', 'GroupSampler', - 'InfiniteGroupBatchSampler', 'InfiniteBatchSampler' + 'InfiniteGroupBatchSampler', 'InfiniteBatchSampler', 'ClassAwareSampler' ] diff --git a/mmdet/datasets/samplers/class_aware_sampler.py b/mmdet/datasets/samplers/class_aware_sampler.py new file mode 100644 index 00000000000..cb06cd2fad1 --- /dev/null +++ b/mmdet/datasets/samplers/class_aware_sampler.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from collections import defaultdict + +import torch +from mmcv.runner import get_dist_info +from torch.utils.data import Sampler + + +class ClassAwareSampler(Sampler): + r"""A class-aware sampling strategy to effectively tackle the + non-uniform class distribution. The length of the training data is + consistent with source data. + + Simple improvements based on `Relay Backpropagation for Effective + Learning of Deep Convolutional Neural Networks + `_ + + The implementation logic is referred to + https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py + + Args: + dataset: Dataset used for sampling. + samples_per_gpu (int): When model is :obj:`DistributedDataParallel`, + it is the number of training samples on each GPU. + When model is :obj:`DataParallel`, it is + `num_gpus * samples_per_gpu`. + Default : 1. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + seed (int, optional): random seed used to shuffle the sampler if + ``shuffle=True``. This number should be identical across all + processes in the distributed group. Default: 0. + """ + + def __init__(self, + dataset, + samples_per_gpu=1, + num_replicas=None, + rank=None, + seed=0): + _rank, _num_replicas = get_dist_info() + if num_replicas is None: + num_replicas = _num_replicas + if rank is None: + rank = _rank + + self.dataset = dataset + self.num_replicas = num_replicas + self.samples_per_gpu = samples_per_gpu + self.rank = rank + self.epoch = 0 + self.seed = seed if seed is not None else 0 + + # The number of samples taken from each per-label list + self.num_sample_class = dataset.num_sample_class \ + if hasattr(dataset, 'num_sample_class') else 1 + # Get per-label image list from dataset + assert hasattr(dataset, 'get_label_dict'), \ + 'dataset must have `get_label_dict` function' + self.class_dict = dataset.get_label_dict() + + self.num_samples = int( + math.ceil( + len(self.dataset) * 1.0 / self.num_replicas / + self.samples_per_gpu)) * self.samples_per_gpu + self.total_size = self.num_samples * self.num_replicas + self.class_num = len(self.class_dict.keys()) + self.class_num_list = [ + len(self.class_dict[i]) for i in range(self.class_num) + ] + # filter labels without images + self.class_index = [ + i for i, length in enumerate(self.class_num_list) if length != 0 + ] + self.class_num = len(self.class_index) + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch + self.seed) + + # initialize label list + label_iter_list = RandomCycleIter(self.class_index, generator=g) + # initialize each per-label image list + data_iter_dict = defaultdict(int) + for i in self.class_index: + data_iter_dict[i] = RandomCycleIter( + self.class_dict[i], generator=g) + + def gen_class_num_indices(cls_list, data_dict, num_sample_cls): + id_indices = [] + for _ in range(len(cls_list)): + cls_idx = next(cls_list) + for _ in range(num_sample_cls): + id = next(data_dict[cls_idx]) + id_indices.append(id) + return id_indices + + # deterministically shuffle based on epoch + num_bins = int( + math.ceil(self.total_size * 1.0 / self.class_num / + self.num_sample_class)) + indices = [] + for i in range(num_bins): + indices += gen_class_num_indices(label_iter_list, data_iter_dict, + self.num_sample_class) + + # fix extra samples to make it evenly divisible + if len(indices) >= self.total_size: + indices = indices[:self.total_size] + else: + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +class RandomCycleIter: + """Shuffle the list and do it again after the list have traversed. + + The implementation logic is referred to + https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py + + Example: + >>> label_list = [0, 1, 2, 4, 5] + >>> g = torch.Generator() + >>> g.manual_seed(0) + >>> label_iter_list = RandomCycleIter(label_list, generator=g) + >>> index = next(label_iter_list) + Args: + data (list or ndarray): The data that needs to be shuffled. + generator: An torch.Generator object, which is used in setting the seed + for generating random numbers. + """ # noqa: W605 + + def __init__(self, data, generator=None): + self.data = data + self.length = len(data) + self.index = torch.randperm(self.length, generator=generator).numpy() + self.i = 0 + self.generator = generator + + def __iter__(self): + return self + + def __len__(self): + return len(self.data) + + def __next__(self): + if self.i < self.length: + idx = self.data[self.index[self.i]] + self.i += 1 + return idx + else: + self.index = torch.randperm( + self.length, generator=self.generator).numpy() + self.i = 1 + idx = self.data[self.index[self.i - 1]] + return idx diff --git a/model-index.yml b/model-index.yml index e05ab8d2964..9f45cdf36ce 100644 --- a/model-index.yml +++ b/model-index.yml @@ -40,6 +40,7 @@ Import: - configs/ms_rcnn/metafile.yml - configs/nas_fcos/metafile.yml - configs/nas_fpn/metafile.yml + - configs/openimages/metafile.yml - configs/paa/metafile.yml - configs/pafpn/metafile.yml - configs/panoptic_fpn/metafile.yml diff --git a/tests/test_data/test_datasets/test_openimages_dataset.py b/tests/test_data/test_datasets/test_openimages_dataset.py index d1e17c23b04..f908e199fd1 100644 --- a/tests/test_data/test_datasets/test_openimages_dataset.py +++ b/tests/test_data/test_datasets/test_openimages_dataset.py @@ -45,6 +45,8 @@ def _create_oid_style_ann(label_file, csv_file, label_level_file): label_description = [['/m/000000', 'Sports equipment'], ['/m/000001', 'Ball'], ['/m/000002', 'Football'], ['/m/000004', 'Bicycle']] + # `newline=''` is used to avoid index error of out of bounds + # in Windows system with open(label_file, 'w', newline='') as f: f_csv = csv.writer(f) f_csv.writerows(label_description)