From a9555987207737b0c2de504fd55e9c137201e219 Mon Sep 17 00:00:00 2001 From: ahatamizadeh Date: Fri, 21 Jul 2023 10:39:27 -0700 Subject: [PATCH] add detection code Signed-off-by: ahatamizadeh --- README.md | 4 +- detection/README.md | 23 + .../configs/_base_/datasets/coco_detection.py | 50 ++ .../configs/_base_/datasets/coco_instance.py | 50 ++ .../_base_/datasets/coco_instance_semantic.py | 56 ++ detection/configs/_base_/default_runtime.py | 17 + .../_base_/models/cascade_mask_rcnn_gcvit.py | 205 +++++ .../configs/_base_/models/mask_rcnn_gcvit.py | 125 +++ .../configs/_base_/schedules/schedule_3x.py | 15 + .../cascade_mask_rcnn_gcvit_tiny_3x_coco.py | 129 +++ .../gcvit/mask_rcnn_gcvit_tiny_3x_coco.py | 79 ++ detection/models/gc_vit.py | 754 ++++++++++++++++++ detection/requirements.txt | 4 + detection/test.py | 223 ++++++ detection/train.py | 191 +++++ 15 files changed, 1923 insertions(+), 2 deletions(-) create mode 100644 detection/README.md create mode 100644 detection/configs/_base_/datasets/coco_detection.py create mode 100644 detection/configs/_base_/datasets/coco_instance.py create mode 100644 detection/configs/_base_/datasets/coco_instance_semantic.py create mode 100644 detection/configs/_base_/default_runtime.py create mode 100644 detection/configs/_base_/models/cascade_mask_rcnn_gcvit.py create mode 100644 detection/configs/_base_/models/mask_rcnn_gcvit.py create mode 100644 detection/configs/_base_/schedules/schedule_3x.py create mode 100644 detection/configs/gcvit/cascade_mask_rcnn_gcvit_tiny_3x_coco.py create mode 100644 detection/configs/gcvit/mask_rcnn_gcvit_tiny_3x_coco.py create mode 100644 detection/models/gc_vit.py create mode 100644 detection/requirements.txt create mode 100644 detection/test.py create mode 100644 detection/train.py diff --git a/README.md b/README.md index 4b8976c..e84ce39 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,8 @@ The architecture of GC ViT is demonstrated in the following: ![teaser](./assets/gc_vit.png) ## 💥 News 💥 - -- **[05.21.2023]** 🔥🔥 We have released ImageNet-21K fine-tuned GC ViT model weights for 224x224 and 384x384. +- **[07.21.2023]** 🔥 We have released the object detection/instance segmentation ![code](./detection/README.md) ! +- **[05.21.2023]** 🔥 We have released ImageNet-21K fine-tuned GC ViT model weights for 224x224 and 384x384. - **[05.21.2023]** 🔥🔥 We have released new ImageNet-1K GC ViT model weights with **better performance** ! - **[04.24.2023]** 🔥🔥🔥 GC ViT has been accepted to **ICML 2023** ! diff --git a/detection/README.md b/detection/README.md new file mode 100644 index 0000000..2da98c6 --- /dev/null +++ b/detection/README.md @@ -0,0 +1,23 @@ +# GC ViT - Object Detection +This repository is the official PyTorch implementation of Global Context Vision Transformers for object detection using MS COCO dataset. 
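+
+Training and evaluation use the `train.py` and `test.py` entry points in this folder (standard MMDetection-style scripts). COCO is expected under `/dataset/` as configured in `configs/_base_/datasets/`, and the configs reference an ImageNet-pretrained backbone checkpoint (e.g. `./gcvit_tiny_best_1k.pth.tar` for the tiny model). For example, GC ViT-T with Mask R-CNN can be trained and evaluated on a single node as follows (the GPU count is illustrative; adjust it to your setup):
+
+```bash
+# multi-GPU training
+python -m torch.distributed.launch --nproc_per_node=8 train.py \
+  --config configs/gcvit/mask_rcnn_gcvit_tiny_3x_coco.py --launcher pytorch
+
+# evaluation of a trained checkpoint
+python -m torch.distributed.launch --nproc_per_node=8 test.py \
+  configs/gcvit/mask_rcnn_gcvit_tiny_3x_coco.py <path/to/checkpoint.pth> \
+  --eval bbox segm --launcher pytorch
+```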
+ +## Requirements +The dependencies can be installed by running: + +```bash +pip install -r requirements.txt +``` + + +## Benchmarks + +The expected performance of models that use GC ViT as a backbone is listed below: + +| Backbone | Head | #Params(M) | FLOPs(G) | mAP | Mask mAP| +|---|---|---|---|---|---| +| GC ViT-T | Mask R-CNN | 48 | 291 | 47.9 | 43.2 | +| GC ViT-T | Cascade Mask R-CNN | 85 | 770 | 51.6 | 44.6 | +| GC ViT-S | Cascade Mask R-CNN | 108 | 866 | 52.4 | 45.4 | +| GC ViT-B | Cascade Mask R-CNN | 146 | 1018 | 52.9 | 45.8 | + + diff --git a/detection/configs/_base_/datasets/coco_detection.py b/detection/configs/_base_/datasets/coco_detection.py new file mode 100644 index 0000000..a5c02fc --- /dev/null +++ b/detection/configs/_base_/datasets/coco_detection.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +dataset_type = 'CocoDataset' +data_root = '/dataset/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +evaluation = dict(interval=1, metric='bbox') diff --git a/detection/configs/_base_/datasets/coco_instance.py b/detection/configs/_base_/datasets/coco_instance.py new file mode 100644 index 0000000..c7bf518 --- /dev/null +++ b/detection/configs/_base_/datasets/coco_instance.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
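+# COCO instance segmentation pipeline: loads boxes and masks, resizes to a
+# 1333x800 canvas with random horizontal flipping, and expects `data_root` to
+# contain annotations/instances_{train,val}2017.json plus the train2017/ and
+# val2017/ image folders.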
+ +dataset_type = 'CocoDataset' +data_root = '/dataset/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +evaluation = dict(metric=['bbox', 'segm']) diff --git a/detection/configs/_base_/datasets/coco_instance_semantic.py b/detection/configs/_base_/datasets/coco_instance_semantic.py new file mode 100644 index 0000000..7402627 --- /dev/null +++ b/detection/configs/_base_/datasets/coco_instance_semantic.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
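+# Same COCO instance segmentation setup as coco_instance.py, but annotations are
+# loaded with `with_seg=True`, so stuff-thing semantic maps are read from
+# stuffthingmaps/train2017/ and rescaled by 1/8 for the semantic branch.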
+ +dataset_type = 'CocoDataset' +data_root = '/dataset/' + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True), + dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='SegRescale', scale_factor=1 / 8), + dict(type='DefaultFormatBundle'), + dict( + type='Collect', + keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + seg_prefix=data_root + 'stuffthingmaps/train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +evaluation = dict(metric=['bbox', 'segm']) diff --git a/detection/configs/_base_/default_runtime.py b/detection/configs/_base_/default_runtime.py new file mode 100644 index 0000000..f1a75c1 --- /dev/null +++ b/detection/configs/_base_/default_runtime.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
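+# Runtime defaults: a checkpoint every epoch, text logging every 50 iterations,
+# the NumClassCheckHook sanity check, the NCCL backend for distributed training,
+# and a train-only workflow.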
+ +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook') + ]) +# yapf:enable +custom_hooks = [dict(type='NumClassCheckHook')] + +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/detection/configs/_base_/models/cascade_mask_rcnn_gcvit.py b/detection/configs/_base_/models/cascade_mask_rcnn_gcvit.py new file mode 100644 index 0000000..fd87718 --- /dev/null +++ b/detection/configs/_base_/models/cascade_mask_rcnn_gcvit.py @@ -0,0 +1,205 @@ +# model settings +model = dict( + type='CascadeRCNN', + pretrained=None, + backbone=dict( + type='GCViT', + dim=128, + mlp_ratio=3.0, + depths=[3, 4, 19, 5], + num_heads=[4, 8, 16, 32], + drop_path_rate=0.2, + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + frozen_stages=-1, + ), + neck=dict( + type='FPN', + in_channels=[64, 128, 256, 512], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)), + roi_head=dict( + type='CascadeRoIHead', + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=[ + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) + ], + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg = dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + 
neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=[ + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False) + ]), + test_cfg = dict( + rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) \ No newline at end of file diff --git a/detection/configs/_base_/models/mask_rcnn_gcvit.py b/detection/configs/_base_/models/mask_rcnn_gcvit.py new file mode 100644 index 0000000..4ac8227 --- /dev/null +++ b/detection/configs/_base_/models/mask_rcnn_gcvit.py @@ -0,0 +1,125 @@ +# model settings +model = dict( + type='MaskRCNN', + pretrained=None, + backbone=dict( + type='GCViT', + dim=128, + mlp_ratio=3.0, + depths=[3, 4, 19, 5], + num_heads=[4, 8, 16, 32], + drop_path_rate=0.2, + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + frozen_stages=-1, + ), + neck=dict( + type='FPN', + in_channels=[64, 128, 256, 512], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + mask_roi_extractor=dict( + type='SingleRoIExtractor', + 
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=80, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100, + mask_thr_binary=0.5))) \ No newline at end of file diff --git a/detection/configs/_base_/schedules/schedule_3x.py b/detection/configs/_base_/schedules/schedule_3x.py new file mode 100644 index 0000000..c007105 --- /dev/null +++ b/detection/configs/_base_/schedules/schedule_3x.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +# optimizer +optimizer = dict(type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05) +optimizer_config = dict(grad_clip=None) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=2500, + warmup_ratio=0.001, + step=[27, 33]) + +runner = dict(type='EpochBasedRunner', max_epochs=36) diff --git a/detection/configs/gcvit/cascade_mask_rcnn_gcvit_tiny_3x_coco.py b/detection/configs/gcvit/cascade_mask_rcnn_gcvit_tiny_3x_coco.py new file mode 100644 index 0000000..b140493 --- /dev/null +++ b/detection/configs/gcvit/cascade_mask_rcnn_gcvit_tiny_3x_coco.py @@ -0,0 +1,129 @@ +_base_ = [ + '../_base_/models/cascade_mask_rcnn_gcvit.py', + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py' +] + +model = dict( + backbone=dict( + type='GCViT', + dim=64, + mlp_ratio=3.0, + depths=[3, 4, 19, 5], + num_heads=[2, 4, 8, 16], + drop_path_rate=0.2, + pretrained='./gcvit_tiny_best_1k.pth.tar' + ), + neck=dict(in_channels=[64, 128, 256, 512]), + roi_head=dict( + bbox_head=[ + dict( + type='ConvFCBBoxHead', + num_shared_convs=4, + num_shared_fcs=1, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + reg_decoded_bbox=True, + norm_cfg=dict(type='SyncBN', requires_grad=True), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=10.0)), + dict( + type='ConvFCBBoxHead', + num_shared_convs=4, + num_shared_fcs=1, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + 
type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=False, + reg_decoded_bbox=True, + norm_cfg=dict(type='SyncBN', requires_grad=True), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=10.0)), + dict( + type='ConvFCBBoxHead', + num_shared_convs=4, + num_shared_fcs=1, + in_channels=256, + conv_out_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=80, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=False, + reg_decoded_bbox=True, + norm_cfg=dict(type='SyncBN', requires_grad=True), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=10.0)) + ])) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +# augmentation strategy originates from DETR / Sparse RCNN +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='AutoAugment', + policies=[ + [ + dict(type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + multiscale_mode='value', + keep_ratio=True) + ], + [ + dict(type='Resize', + img_scale=[(400, 1333), (500, 1333), (600, 1333)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict(type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), + (576, 1333), (608, 1333), (640, 1333), + (672, 1333), (704, 1333), (736, 1333), + (768, 1333), (800, 1333)], + multiscale_mode='value', + override=True, + keep_ratio=True) + ] + ]), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict(train=dict(pipeline=train_pipeline)) + +optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05, + paramwise_cfg=dict(custom_keys={ + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.), + }), + ) +lr_config = dict(step=[27, 33]) \ No newline at end of file diff --git a/detection/configs/gcvit/mask_rcnn_gcvit_tiny_3x_coco.py b/detection/configs/gcvit/mask_rcnn_gcvit_tiny_3x_coco.py new file mode 100644 index 0000000..9d6f3f5 --- /dev/null +++ b/detection/configs/gcvit/mask_rcnn_gcvit_tiny_3x_coco.py @@ -0,0 +1,79 @@ +_base_ = [ + '../_base_/models/mask_rcnn_gcvit.py', + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py' +] + +model = dict( + backbone=dict( + type='GCViT', + dim=64, + mlp_ratio=3.0, + depths=[3, 4, 19, 5], + num_heads=[2, 4, 8, 16], + drop_path_rate=0.2, + pretrained='./gcvit_tiny_best_1k.pth.tar' + ), + neck=dict(in_channels=[64, 128, 256, 512])) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +# augmentation strategy originates from DETR / Sparse RCNN +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='AutoAugment', + 
policies=[ + [ + dict(type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + multiscale_mode='value', + keep_ratio=True) + ], + [ + dict(type='Resize', + img_scale=[(400, 1333), (500, 1333), (600, 1333)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict(type='Resize', + img_scale=[(480, 1333), (512, 1333), (544, 1333), + (576, 1333), (608, 1333), (640, 1333), + (672, 1333), (704, 1333), (736, 1333), + (768, 1333), (800, 1333)], + multiscale_mode='value', + override=True, + keep_ratio=True) + ] + ]), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict(train=dict(pipeline=train_pipeline)) + +optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05, + paramwise_cfg=dict(custom_keys={ + 'rpb': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.), + }), + ) +lr_config = dict(step=[27, 33]) +runner = dict(type='EpochBasedRunner', max_epochs=36) + +# Mixed precision +fp16 = None +optimizer_config = dict( + type="Fp16OptimizerHook", + grad_clip=None, + coalesce=True, + bucket_size_mb=-1, +) diff --git a/detection/models/gc_vit.py b/detection/models/gc_vit.py new file mode 100644 index 0000000..d9d0ce1 --- /dev/null +++ b/detection/models/gc_vit.py @@ -0,0 +1,754 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + + +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_, DropPath, to_2tuple +from mmcv.runner import load_checkpoint +from mmdet.models.builder import BACKBONES +from mmdet.utils import get_root_logger +from torch.nn.functional import interpolate as interpolate +import torch.nn.functional as F + + +def _to_channel_last(x): + """ + Args: + x: (B, C, H, W) + + Returns: + x: (B, H, W, C) + """ + return x.permute(0, 2, 3, 1) + + +def _to_channel_first(x): + """ + Args: + x: (B, H, W, C) + + Returns: + x: (B, C, H, W) + """ + return x.permute(0, 3, 1, 2) + + +class SE(nn.Module): + """ + Squeeze and excitation block + """ + + def __init__(self, + inp, + oup, + expansion=0.25): + """ + Args: + inp: input features dimension. + oup: output features dimension. + expansion: expansion ratio. + """ + + super().__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(oup, int(inp * expansion), bias=False), + nn.GELU(), + nn.Linear(int(inp * expansion), oup, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class ReduceSize(nn.Module): + """ + Down-sampling block based on: "Hatamizadeh et al., + Global Context Vision Transformers " + """ + + def __init__(self, + dim, + norm_layer=nn.LayerNorm, + keep_dim=False): + """ + Args: + dim: feature size dimension. + norm_layer: normalization layer. 
+ keep_dim: bool argument for maintaining the resolution. + """ + + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(dim, dim, 3, 1, 1, + groups=dim, bias=False), + nn.GELU(), + SE(dim, dim), + nn.Conv2d(dim, dim, 1, 1, 0, bias=False), + ) + if keep_dim: + dim_out = dim + else: + dim_out = 2 * dim + self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False) + self.norm2 = norm_layer(dim_out) + self.norm1 = norm_layer(dim) + + def forward(self, x): + x = x.contiguous() + x = self.norm1(x) + x = _to_channel_first(x) + x = x + self.conv(x) + x = self.reduction(x) + x = _to_channel_last(x) + x = self.norm2(x) + return x + + +class PatchEmbed(nn.Module): + """ + Patch embedding block based on: "Hatamizadeh et al., + Global Context Vision Transformers " + """ + + def __init__(self, in_chans=3, dim=96): + """ + Args: + in_chans: number of input channels. + dim: feature size dimension. + """ + + super().__init__() + self.proj = nn.Conv2d(in_chans, dim, 3, 2, 1) + self.conv_down = ReduceSize(dim=dim, keep_dim=True) + + def forward(self, x): + x = self.proj(x) + x = _to_channel_last(x) + x = self.conv_down(x) + return x + + +class FeatExtract(nn.Module): + """ + Feature extraction block based on: "Hatamizadeh et al., + Global Context Vision Transformers " + """ + + def __init__(self, dim, keep_dim=False): + """ + Args: + dim: feature size dimension. + keep_dim: bool argument for maintaining the resolution. + """ + + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(dim, dim, 3, 1, 1, + groups=dim, bias=False), + nn.GELU(), + SE(dim, dim), + nn.Conv2d(dim, dim, 1, 1, 0, bias=False), + ) + if not keep_dim: + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.keep_dim = keep_dim + + def forward(self, x): + x = x.contiguous() + x = x + self.conv(x) + if not self.keep_dim: + x = self.pool(x) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class WindowAttention(nn.Module): + """ + Local window attention based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + """ + + def __init__(self, + dim, + num_heads, + window_size, + window_size_pre, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + use_rel_pos_bias=False + ): + """ + Args: + dim: feature size dimension. + num_heads: number of attention head. + window_size: window size. + qkv_bias: bool argument for query, key, value learnable bias. + qk_scale: bool argument to scaling query, key. + attn_drop: attention dropout rate. + proj_drop: output dropout rate. 
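+            window_size_pre: window size of the pre-trained (classification) backbone; its
+                relative position bias table is interpolated to the detection window size.
+            use_rel_pos_bias: bool argument to add a relative position bias to the attention scores.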
+ """ + + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.use_rel_pos_bias = use_rel_pos_bias + self.window_size_pre = window_size_pre + self.window_size = (window_size, window_size) # Wh, Ww + window_size = (window_size_pre, window_size_pre) + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, q_global): + B_, N, C = x.shape + head_dim = C // self.num_heads + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + if self.use_rel_pos_bias: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class WindowAttentionGlobal(nn.Module): + """ + Global window attention based on: "Hatamizadeh et al., + Global Context Vision Transformers " + """ + + def __init__(self, + dim, + num_heads, + window_size, + window_size_pre, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0., + use_rel_pos_bias=False + ): + """ + Args: + dim: feature size dimension. + num_heads: number of attention head. + window_size: window size. + qkv_bias: bool argument for query, key, value learnable bias. + qk_scale: bool argument to scaling query, key. + attn_drop: attention dropout rate. + proj_drop: output dropout rate. 
+ """ + + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.use_rel_pos_bias = use_rel_pos_bias + self.window_size_pre = window_size_pre + self.window_size = (window_size, window_size) # Wh, Ww + window_size = (window_size_pre, window_size_pre) + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, q_global): + B_, N, C = x.shape + B = q_global.shape[0] + head_dim = C // self.num_heads + B_dim = B_//B + kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + q_global = q_global.repeat(1, B_dim, 1, 1, 1) + q = q_global.reshape(B_, self.num_heads, N, head_dim) + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + if self.use_rel_pos_bias: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class GCViTBlock(nn.Module): + """ + GCViT block based on: "Hatamizadeh et al., + Global Context Vision Transformers " + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size_pre, + window_size=7, + 
mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + attention=WindowAttentionGlobal, + norm_layer=nn.LayerNorm, + layer_scale=None, + use_rel_pos_bias=False + ): + """ + Args: + dim: feature size dimension. + input_resolution: input image resolution. + num_heads: number of attention head. + window_size: window size. + mlp_ratio: MLP ratio. + qkv_bias: bool argument for query, key, value learnable bias. + qk_scale: bool argument to scaling query, key. + drop: dropout rate. + attn_drop: attention dropout rate. + drop_path: drop path rate. + act_layer: activation function. + attention: attention block type. + norm_layer: normalization layer. + layer_scale: layer scaling coefficient. + """ + + super().__init__() + self.window_size = window_size + self.norm1 = norm_layer(dim) + + self.attn = attention(dim, + num_heads=num_heads, + window_size=window_size, + window_size_pre=window_size_pre, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + use_rel_pos_bias=use_rel_pos_bias + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.layer_scale = False + if layer_scale is not None and type(layer_scale) in [int, float]: + self.layer_scale = True + self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True) + self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True) + else: + self.gamma1 = 1.0 + self.gamma2 = 1.0 + + def forward(self, x, q_global): + B, H, W, C = x.shape + shortcut = x + x = self.norm1(x) + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + shifted_x = x + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + _, h, w = x_windows.shape + attn_windows = self.attn(x_windows, q_global) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + x = shifted_x + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = shortcut + self.drop_path(self.gamma1 * x) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + return x + + +class GlobalQueryGen(nn.Module): + """ + Global query generator based on: "Hatamizadeh et al., + Global Context Vision Transformers " + """ + + def __init__(self, + dim, + input_resolution, + image_resolution, + window_size, + num_heads): + """ + Args: + dim: feature size dimension. + input_resolution: input image resolution. + window_size: window size. + num_heads: number of heads. + + For instance, repeating log(56/7) = 3 blocks, with input window dimension 56 and output window dimension 7 at + down-sampling ratio 2. Please check Fig.5 of GC ViT paper for details. 
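+        image_resolution here is the full input image size; the ratio of image_resolution to
+        input_resolution selects how many down-sampling blocks are stacked before the query
+        is interpolated to the window size.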
+ """ + + super().__init__() + if input_resolution == image_resolution//4: + self.to_q_global = nn.Sequential( + FeatExtract(dim, keep_dim=False), + FeatExtract(dim, keep_dim=False), + FeatExtract(dim, keep_dim=False), + ) + + elif input_resolution == image_resolution//8: + self.to_q_global = nn.Sequential( + FeatExtract(dim, keep_dim=False), + FeatExtract(dim, keep_dim=False), + ) + + elif input_resolution == image_resolution//16: + + if window_size == input_resolution: + self.to_q_global = nn.Sequential( + FeatExtract(dim, keep_dim=True) + ) + + else: + self.to_q_global = nn.Sequential( + FeatExtract(dim, keep_dim=False) + ) + + elif input_resolution == image_resolution//32: + self.to_q_global = nn.Sequential( + FeatExtract(dim, keep_dim=True) + ) + + self.num_heads = num_heads + self.N = window_size * window_size + self.dim_head = dim // self.num_heads + self.window_size = window_size + + def forward(self, x): + x = self.to_q_global(x) + B, C, H, W = x.shape + if self.window_size != H or self.window_size !=W: + x = interpolate(x, size=(self.window_size, self.window_size), mode='bicubic') + x = _to_channel_last(x) + x = x.reshape(B, 1, self.N, self.num_heads, self.dim_head).permute(0, 1, 3, 2, 4) + return x + + +class GCViTLayer(nn.Module): + """ + GCViT layer based on: "Hatamizadeh et al., + Global Context Vision Transformers " + """ + + def __init__(self, + dim, + depth, + input_resolution, + image_resolution, + num_heads, + window_size, + window_size_pre, + downsample=True, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + layer_scale=None, + use_rel_pos_bias=False): + """ + Args: + dim: feature size dimension. + depth: number of layers in each stage. + input_resolution: input image resolution. + window_size: window size in each stage. + downsample: bool argument for down-sampling. + mlp_ratio: MLP ratio. + num_heads: number of heads in each stage. + qkv_bias: bool argument for query, key, value learnable bias. + qk_scale: bool argument to scaling query, key. + drop: dropout rate. + attn_drop: attention dropout rate. + drop_path: drop path rate. + norm_layer: normalization layer. + layer_scale: layer scaling coefficient. 
+ """ + + super().__init__() + self.blocks = nn.ModuleList([ + GCViTBlock(dim=dim, + num_heads=num_heads, + window_size=window_size, + window_size_pre=window_size_pre, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attention=WindowAttention if (i % 2 == 0) else WindowAttentionGlobal, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + layer_scale=layer_scale, + input_resolution=input_resolution, + use_rel_pos_bias=use_rel_pos_bias) + for i in range(depth)]) + self.downsample = None if not downsample else ReduceSize(dim=dim, norm_layer=norm_layer) + self.q_global_gen = GlobalQueryGen(dim, input_resolution, image_resolution, window_size, num_heads) + + def forward(self, x): + q_global = self.q_global_gen(_to_channel_first(x)) + for blk in self.blocks: + x = blk(x, q_global) + + if self.downsample is None: + return x, x + return self.downsample(x), x + + + +@BACKBONES.register_module() +class GCViT(nn.Module): + def __init__(self, + dim, + depths, + mlp_ratio, + num_heads, + window_size=(24, 24, 24, 24), + window_size_pre=(7, 7, 14, 7), + resolution=224, + drop_path_rate=0.2, + in_chans=3, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + norm_layer=nn.LayerNorm, + layer_scale=None, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + pretrained=None, + use_rel_pos_bias=True, + **kwargs): + super().__init__() + + self.num_levels = len(depths) + self.embed_dim = dim + self.num_features = [int(dim * 2 ** i) for i in range(self.num_levels)] + self.mlp_ratio = mlp_ratio + self.pos_drop = nn.Dropout(p=drop_rate) + self.patch_embed = PatchEmbed(in_chans=in_chans, dim=dim) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.levels = nn.ModuleList() + for i in range(len(depths)): + level = GCViTLayer(dim=int(dim * 2 ** i), + depth=depths[i], + num_heads=num_heads[i], + window_size=window_size[i], + window_size_pre=window_size_pre[i], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], + norm_layer=norm_layer, + downsample=(i < len(depths) - 1), + layer_scale=layer_scale, + input_resolution=int(2 ** (-2 - i) * resolution), + image_resolution=resolution, + use_rel_pos_bias=use_rel_pos_bias) + self.levels.append(level) + + # add a norm layer for each output + self.out_indices = out_indices + for i_layer in self.out_indices: + layer = norm_layer(self.num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self.frozen_stages = frozen_stages + if pretrained is not None: + self.init_weights(pretrained) + + for level in self.levels: + for block in level.blocks: + w_ = block.attn.window_size[0] + relative_position_bias_table_pre = block.attn.relative_position_bias_table + L1, nH1 = relative_position_bias_table_pre.shape + L2 = (2 * w_ - 1) * (2 * w_ - 1) + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( + relative_position_bias_table_pre.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), + mode='bicubic') + relative_position_bias_table_pretrained_resized = relative_position_bias_table_pretrained_resized.view(nH1, L2).permute(1, 0) + block.attn.relative_position_bias_table = torch.nn.Parameter(relative_position_bias_table_pretrained_resized) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for 
param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 2: + for i in range(0, self.frozen_stages - 1): + m = self.network[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(GCViT, self).train(mode) + self._freeze_stages() + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + pass + else: + raise TypeError('pretrained must be a str or None') + + def forward_embeddings(self, x): + x = self.patch_embed(x) + return x + + def forward_tokens(self, x): + outs = [] + for idx, level in enumerate(self.levels): + x, xo = level(x) + if idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(xo) + outs.append(x_out.permute(0, 3, 1, 2).contiguous()) + return outs + + def forward(self, x): + x = self.forward_embeddings(x) + return self.forward_tokens(x) + + def forward_features(self, x): + x = self.forward_embeddings(x) + return self.forward_tokens(x) diff --git a/detection/requirements.txt b/detection/requirements.txt new file mode 100644 index 0000000..28ca64d --- /dev/null +++ b/detection/requirements.txt @@ -0,0 +1,4 @@ +timm==0.5.4 +pyyaml==6.0 +mmcv-full==1.4.8 +mmdet==2.19.0 diff --git a/detection/test.py b/detection/test.py new file mode 100644 index 0000000..5d4af13 --- /dev/null +++ b/detection/test.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import warnings + +import mmcv +import torch +from mmcv import Config, DictAction +from mmcv.cnn import fuse_conv_bn +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, + wrap_fp16_model) + +from mmdet.apis import multi_gpu_test, single_gpu_test +from mmdet.datasets import (build_dataloader, build_dataset, + replace_ImageToTensor) +from mmdet.models import build_detector, build_backbone + +from nat import * + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMDet test (and eval) a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--out', help='output result file in pickle format') + parser.add_argument( + '--fuse-conv-bn', + action='store_true', + help='Whether to fuse conv and bn, this will slightly increase' + 'the inference speed') + parser.add_argument( + '--format-only', + action='store_true', + help='Format the output results without perform evaluation. 
It is' + 'useful when you want to format the result to a specific format and ' + 'submit it to the test server') + parser.add_argument( + '--eval', + type=str, + nargs='+', + help='evaluation metrics, which depends on the dataset, e.g., "bbox",' + ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') + parser.add_argument('--show', action='store_true', help='show results') + parser.add_argument( + '--show-dir', help='directory where painted images will be saved') + parser.add_argument( + '--show-score-thr', + type=float, + default=0.3, + help='score threshold (default: 0.3)') + parser.add_argument( + '--gpu-collect', + action='store_true', + help='whether to use gpu to collect results.') + parser.add_argument( + '--tmpdir', + help='tmp directory used for collecting results from multiple ' + 'workers, available when gpu-collect is not specified') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help='custom options for evaluation, the key-value pair in xxx=yyy ' + 'format will be kwargs for dataset.evaluate() function (deprecate), ' + 'change to --eval-options instead.') + parser.add_argument( + '--eval-options', + nargs='+', + action=DictAction, + help='custom options for evaluation, the key-value pair in xxx=yyy ' + 'format will be kwargs for dataset.evaluate() function') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if args.options and args.eval_options: + raise ValueError( + '--options and --eval-options cannot be both ' + 'specified, --options is deprecated in favor of --eval-options') + if args.options: + warnings.warn('--options is deprecated in favor of --eval-options') + args.eval_options = args.options + return args + + +def main(): + args = parse_args() + + assert args.out or args.eval or args.format_only or args.show \ + or args.show_dir, \ + ('Please specify at least one operation (save/eval/format/show the ' + 'results / save the results) with the argument "--out", "--eval"' + ', "--format-only", "--show" or "--show-dir"') + + if args.eval and args.format_only: + raise ValueError('--eval and --format_only cannot be both specified') + + if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): + raise ValueError('The output file must be a pkl file.') + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # import modules from string list. 
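+    # e.g. setting `custom_imports = dict(imports=['models.gc_vit'])` in a config
+    # registers the GCViT backbone without editing this script (assuming the
+    # detection/ folder is on PYTHONPATH).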
+ if cfg.get('custom_imports', None): + from mmcv.utils import import_modules_from_strings + import_modules_from_strings(**cfg['custom_imports']) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + cfg.model.pretrained = None + if cfg.model.get('neck'): + if isinstance(cfg.model.neck, list): + for neck_cfg in cfg.model.neck: + if neck_cfg.get('rfp_backbone'): + if neck_cfg.rfp_backbone.get('pretrained'): + neck_cfg.rfp_backbone.pretrained = None + elif cfg.model.neck.get('rfp_backbone'): + if cfg.model.neck.rfp_backbone.get('pretrained'): + cfg.model.neck.rfp_backbone.pretrained = None + + # in case the test dataset is concatenated + samples_per_gpu = 1 + if isinstance(cfg.data.test, dict): + cfg.data.test.test_mode = True + samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1) + if samples_per_gpu > 1: + # Replace 'ImageToTensor' to 'DefaultFormatBundle' + cfg.data.test.pipeline = replace_ImageToTensor( + cfg.data.test.pipeline) + elif isinstance(cfg.data.test, list): + for ds_cfg in cfg.data.test: + ds_cfg.test_mode = True + samples_per_gpu = max( + [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test]) + if samples_per_gpu > 1: + for ds_cfg in cfg.data.test: + ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + + # build the dataloader + dataset = build_dataset(cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False) + + # build the model and load checkpoint + cfg.model.train_cfg = None + model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + wrap_fp16_model(model) + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + if args.fuse_conv_bn: + model = fuse_conv_bn(model) + # old versions did not save class info in checkpoints, this walkaround is + # for backward compatibility + if 'CLASSES' in checkpoint.get('meta', {}): + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + model.CLASSES = dataset.CLASSES + + if not distributed: + model = MMDataParallel(model, device_ids=[0]) + outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, + args.show_score_thr) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False) + outputs = multi_gpu_test(model, data_loader, args.tmpdir, + args.gpu_collect) + + rank, _ = get_dist_info() + if rank == 0: + if args.out: + print(f'\nwriting results to {args.out}') + mmcv.dump(outputs, args.out) + kwargs = {} if args.eval_options is None else args.eval_options + if args.format_only: + dataset.format_results(outputs, **kwargs) + if args.eval: + eval_kwargs = cfg.get('evaluation', {}).copy() + # hard-code way to remove EvalHook args + for key in [ + 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', + 'rule' + ]: + eval_kwargs.pop(key, None) + eval_kwargs.update(dict(metric=args.eval, **kwargs)) + print(dataset.evaluate(outputs, **eval_kwargs)) + + +if __name__ == '__main__': + main() diff --git a/detection/train.py b/detection/train.py new file mode 100644 index 0000000..6ee1235 --- /dev/null +++ b/detection/train.py @@ -0,0 +1,191 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import argparse +import copy +import os +import os.path as osp +import time +import warnings + +import mmcv +import torch +from mmcv import Config, DictAction +from mmcv.runner import get_dist_info, init_dist +from mmcv.utils import get_git_hash + +from mmdet import __version__ +from mmdet.apis import set_random_seed, train_detector +from mmdet.datasets import build_dataset +from mmdet.models import build_detector +from mmdet.utils import collect_env, get_root_logger +from models.gc_vit import * + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a detector') + parser.add_argument('--config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume-from', help='the checkpoint file to resume from') + parser.add_argument( + '--no-validate', + action='store_true', + help='whether not to evaluate the checkpoint during training') + group_gpus = parser.add_mutually_exclusive_group() + group_gpus.add_argument( + '--gpus', + type=int, + help='number of gpus to use ' + '(only applicable to non-distributed training)') + group_gpus.add_argument( + '--gpu-ids', + type=int, + nargs='+', + help='ids of gpus to use ' + '(only applicable to non-distributed training)') + parser.add_argument('--seed', type=int, default=None, help='random seed') + parser.add_argument( + '--deterministic', + action='store_true', + help='whether to set deterministic options for CUDNN backend.') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file (deprecate), ' + 'change to --cfg-options instead.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if args.options and args.cfg_options: + raise ValueError( + '--options and --cfg-options cannot be both ' + 'specified, --options is deprecated in favor of --cfg-options') + if args.options: + warnings.warn('--options is deprecated in favor of --cfg-options') + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # import modules from string list. 
+ if cfg.get('custom_imports', None): + from mmcv.utils import import_modules_from_strings + import_modules_from_strings(**cfg['custom_imports']) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + if args.resume_from is not None: + cfg.resume_from = args.resume_from + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids + else: + cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + + # create work_dir + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # dump config + cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + meta['config'] = cfg.pretty_text + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + if args.seed is not None: + logger.info(f'Set random seed to {args.seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(args.seed, deterministic=args.deterministic) + cfg.seed = args.seed + meta['seed'] = args.seed + meta['exp_name'] = osp.basename(args.config) + cfg.device = 'cuda' + + model = build_detector( + cfg.model, + train_cfg=cfg.get('train_cfg'), + test_cfg=cfg.get('test_cfg')) + + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + val_dataset.pipeline = cfg.data.train.pipeline + datasets.append(build_dataset(val_dataset)) + if cfg.checkpoint_config is not None: + # save mmdet version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmdet_version=__version__ + get_git_hash()[:7], + CLASSES=datasets[0].CLASSES) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + nparams = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f'Number of parameters: {nparams}') + train_detector( + model, + datasets, + cfg, + distributed=distributed, + validate=(not args.no_validate), + timestamp=timestamp, + meta=meta) + + +if __name__ == '__main__': + main()
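As a quick sanity check of the GCViT backbone added above (a minimal sketch, assuming the `requirements.txt` packages are installed and `detection/` is the working directory), the snippet below builds the tiny variant used by the configs and prints the four feature maps consumed by the FPN neck:

```python
import torch

from models.gc_vit import GCViT  # importing also registers the backbone with mmdet

# Tiny variant, matching configs/gcvit/mask_rcnn_gcvit_tiny_3x_coco.py
backbone = GCViT(dim=64,
                 depths=[3, 4, 19, 5],
                 mlp_ratio=3.0,
                 num_heads=[2, 4, 8, 16],
                 drop_path_rate=0.2)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))

# Expected: channels 64/128/256/512 at strides 4/8/16/32,
# i.e. the `in_channels=[64, 128, 256, 512]` fed to the FPN neck in the configs.
for f in feats:
    print(tuple(f.shape))
```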