Skip to content

Commit

Permalink
support fp16 and batchsize>1
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqi-li committed Jun 27, 2022
1 parent 68e6af5 commit 5a25c8b
Show file tree
Hide file tree
Showing 15 changed files with 801 additions and 12 deletions.
8 changes: 8 additions & 0 deletions docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,11 @@ Eval BEVFormer with 8 GPUs
Note: using 1 GPU to eval can obtain slightly higher performance because continuous video may be truncated with multiple GPUs. By default we report the score evaled with 8 GPUs.



# Using FP16 to train the model.

We provide another script to train BEVFormer with FP16.

```
./tools/fp16/dist_train.sh ./projects/configs/bevformer_fp16/bevformer_tiny_fp16.py 8
```
272 changes: 272 additions & 0 deletions projects/configs/bevformer_fp16/bevformer_tiny_fp16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
# BEvFormer-tiny consumes at lease 6700M GPU memory
# compared to bevformer_base, bevformer_tiny has
# smaller backbone: R101-DCN -> R50
# smaller BEV: 200*200 -> 50*50
# less encoder layers: 6 -> 3
# smaller input size: 1600*900 -> 800*450
# multi-scale feautres -> single scale features (C5)


_base_ = [
'../datasets/custom_nus-3d.py',
'../_base_/default_runtime.py'
]
#
plugin = True
plugin_dir = 'projects/mmdet3d_plugin/'

# If point cloud range is changed, the models should also change their point
# cloud range accordingly
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]




img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# For nuScenes we usually do 10-class detection
class_names = [
'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

input_modality = dict(
use_lidar=False,
use_camera=True,
use_radar=False,
use_map=False,
use_external=True)

_dim_ = 256
_pos_dim_ = _dim_//2
_ffn_dim_ = _dim_*2
_num_levels_ = 1
bev_h_ = 50
bev_w_ = 50
queue_length = 3 # each sequence contains `queue_length` frames.

model = dict(
type='BEVFormer_fp16',
use_grid_mask=True,
video_test_mode=True,
pretrained=dict(img='torchvision://resnet50'),
img_backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(3,),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch'),
img_neck=dict(
type='FPN',
in_channels=[2048],
out_channels=_dim_,
start_level=0,
add_extra_convs='on_output',
num_outs=_num_levels_,
relu_before_extra_convs=True),
pts_bbox_head=dict(
type='BEVFormerHead',
bev_h=bev_h_,
bev_w=bev_w_,
num_query=900,
num_classes=10,
in_channels=_dim_,
sync_cls_avg_factor=True,
with_box_refine=True,
as_two_stage=False,
transformer=dict(
type='PerceptionTransformer',
rotate_prev_bev=True,
use_shift=True,
use_can_bus=True,
embed_dims=_dim_,
encoder=dict(
type='BEVFormerEncoder',
num_layers=3,
pc_range=point_cloud_range,
num_points_in_pillar=4,
return_intermediate=False,
transformerlayers=dict(
type='BEVFormerLayer',
attn_cfgs=[
dict(
type='TemporalSelfAttention',
embed_dims=_dim_,
num_levels=1),
dict(
type='SpatialCrossAttention',
pc_range=point_cloud_range,
deformable_attention=dict(
type='MSDeformableAttention3D',
embed_dims=_dim_,
num_points=8,
num_levels=_num_levels_),
embed_dims=_dim_,
)
],
feedforward_channels=_ffn_dim_,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm'))),
decoder=dict(
type='DetectionTransformerDecoder',
num_layers=6,
return_intermediate=True,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=_dim_,
num_heads=8,
dropout=0.1),
dict(
type='CustomMSDeformableAttention',
embed_dims=_dim_,
num_levels=1),
],

feedforward_channels=_ffn_dim_,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')))),
bbox_coder=dict(
type='NMSFreeCoder',
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
pc_range=point_cloud_range,
max_num=300,
voxel_size=voxel_size,
num_classes=10),
positional_encoding=dict(
type='LearnedPositionalEncoding',
num_feats=_pos_dim_,
row_num_embed=bev_h_,
col_num_embed=bev_w_,
),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=2.0),
loss_bbox=dict(type='L1Loss', loss_weight=0.25),
loss_iou=dict(type='GIoULoss', loss_weight=0.0)),
# model training and testing settings
train_cfg=dict(pts=dict(
grid_size=[512, 512, 1],
voxel_size=voxel_size,
point_cloud_range=point_cloud_range,
out_size_factor=4,
assigner=dict(
type='HungarianAssigner3D',
cls_cost=dict(type='FocalLossCost', weight=2.0),
reg_cost=dict(type='BBox3DL1Cost', weight=0.25),
iou_cost=dict(type='IoUCost', weight=0.0), # Fake cost. This is just to make it compatible with DETR head.
pc_range=point_cloud_range))))

dataset_type = 'CustomNuScenesDataset'
data_root = 'data/nuscenes/'
file_client_args = dict(backend='disk')


train_pipeline = [
dict(type='LoadMultiViewImageFromFiles', to_float32=True),
dict(type='PhotoMetricDistortionMultiViewImage'),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
dict(type='PadMultiViewImage', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='CustomCollect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'])
]

test_pipeline = [
dict(type='LoadMultiViewImageFromFiles', to_float32=True),
dict(type='NormalizeMultiviewImage', **img_norm_cfg),

dict(
type='MultiScaleFlipAug3D',
img_scale=(1600, 900),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
dict(type='PadMultiViewImage', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='CustomCollect3D', keys=['img'])
])
]

data = dict(
samples_per_gpu=2,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
pipeline=train_pipeline,
classes=class_names,
modality=input_modality,
test_mode=False,
use_valid_flag=True,
bev_size=(bev_h_, bev_w_),
queue_length=queue_length,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR'),
val=dict(type=dataset_type,
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
pipeline=test_pipeline, bev_size=(bev_h_, bev_w_),
classes=class_names, modality=input_modality, samples_per_gpu=1),
test=dict(type=dataset_type,
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
pipeline=test_pipeline, bev_size=(bev_h_, bev_w_),
classes=class_names, modality=input_modality),
shuffler_sampler=dict(type='DistributedGroupSampler'),
nonshuffler_sampler=dict(type='DistributedSampler')
)

optimizer = dict(
type='AdamW',
lr=2.8e-4,
paramwise_cfg=dict(
custom_keys={
'img_backbone': dict(lr_mult=0.1),
}),
weight_decay=0.01)

optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='CosineAnnealing',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
min_lr_ratio=1e-3)
total_epochs = 24
evaluation = dict(interval=1, pipeline=test_pipeline)

runner = dict(type='EpochBasedRunner_video', max_epochs=total_epochs)

log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')
])

fp16 = dict(loss_scale=512.)
checkpoint_config = dict(interval=1)
custom_hooks = [dict(type='TransferWeight',priority='LOWEST')]
2 changes: 2 additions & 0 deletions projects/mmdet3d_plugin/bevformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
from .dense_heads import *
from .detectors import *
from .modules import *
from .runner import *
from .hooks import *
40 changes: 30 additions & 10 deletions projects/mmdet3d_plugin/bevformer/apis/mmdet_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def custom_train_detector(model,
distributed=False,
validate=False,
timestamp=None,
eval_model=None,
meta=None):
logger = get_root_logger(cfg.log_level)

Expand Down Expand Up @@ -76,10 +77,19 @@ def custom_train_detector(model,
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)

if eval_model is not None:
eval_model = MMDistributedDataParallel(
eval_model.cuda(),
device_ids=[torch.cuda.current_device()],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
if eval_model is not None:
eval_model = MMDataParallel(
eval_model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)


# build runner
optimizer = build_optimizer(model, cfg.optimizer)
Expand All @@ -95,15 +105,25 @@ def custom_train_detector(model,
else:
if 'total_epochs' in cfg:
assert cfg.total_epochs == cfg.runner.max_epochs

runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))
if eval_model is not None:
runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
eval_model=eval_model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))
else:
runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))

# an ugly workaround to make .log and .log.json filenames the same
runner.timestamp = timestamp
Expand Down
2 changes: 2 additions & 0 deletions projects/mmdet3d_plugin/bevformer/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def custom_train_model(model,
distributed=False,
validate=False,
timestamp=None,
eval_model=None,
meta=None):
"""A function wrapper for launching model training according to cfg.
Expand All @@ -30,6 +31,7 @@ def custom_train_model(model,
distributed=distributed,
validate=validate,
timestamp=timestamp,
eval_model=eval_model,
meta=meta)


Expand Down
3 changes: 2 additions & 1 deletion projects/mmdet3d_plugin/bevformer/detectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .bevformer import BEVFormer
from .bevformer import BEVFormer
from .bevformer_fp16 import BEVFormer_fp16
2 changes: 1 addition & 1 deletion projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def obtain_history_bev(self, imgs_queue, img_metas_list):
self.train()
return prev_bev

@auto_fp16(apply_to=('img', 'prev_bev', 'points'))
@auto_fp16(apply_to=('img', 'points'))
def forward_train(self,
points=None,
img_metas=None,
Expand Down
Loading

0 comments on commit 5a25c8b

Please sign in to comment.