[Feature] Support FCAF3D on S3DIS dataset in dev-1.x branch #1984

Merged (22 commits, Nov 23, 2022)
Changes from 20 commits
114 changes: 114 additions & 0 deletions configs/_base_/datasets/s3dis-3d.py
@@ -0,0 +1,114 @@
# dataset settings
dataset_type = 'S3DISDataset'
data_root = 'data/s3dis/'

metainfo = dict(classes=('table', 'chair', 'sofa', 'bookcase', 'board'))
train_area = [1, 2, 3, 4, 6]
test_area = 5

train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='PointSample', num_points=100000),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.087266, 0.087266],
scale_ratio_range=[0.9, 1.1],
translation_std=[.1, .1, .1],
shift_height=False),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(type='PointSample', num_points=100000),
dict(type='NormalizePointsColor', color_mean=None),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]

train_dataloader = dict(
batch_size=8,
num_workers=4,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=13,
dataset=dict(
type='ConcatDataset',
datasets=[
dict(
type=dataset_type,
data_root=data_root,
ann_file=f's3dis_infos_Area_{i}.pkl',
pipeline=train_pipeline,
filter_empty_gt=True,
metainfo=metainfo,
box_type_3d='Depth') for i in train_area
])))

val_dataloader = dict(
batch_size=1,
num_workers=1,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=f's3dis_infos_Area_{test_area}.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
test_mode=True,
box_type_3d='Depth'))
test_dataloader = dict(
batch_size=1,
num_workers=1,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=f's3dis_infos_Area_{test_area}.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
test_mode=True,
box_type_3d='Depth'))
val_evaluator = dict(type='IndoorMetric')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
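
For a quick sanity check of the new dataset config, the snippet below (not part of the PR) builds the training set with mmengine's config loader and the MMDetection3D dataset registry; the registry import and register_all_modules call follow the usual dev-1.x workflow and assume data/s3dis has already been preprocessed.

# Minimal sketch, assuming an mmdetection3d dev-1.x checkout with data/s3dis prepared.
from mmengine.config import Config
from mmdet3d.registry import DATASETS
from mmdet3d.utils import register_all_modules

register_all_modules()  # register datasets and transforms before building
cfg = Config.fromfile('configs/_base_/datasets/s3dis-3d.py')

# train_dataloader.dataset wraps RepeatDataset -> ConcatDataset -> S3DISDataset
train_set = DATASETS.build(cfg.train_dataloader.dataset)
print(len(train_set))  # 13 x the number of scenes in Areas 1-4 and 6
print(train_set.dataset.datasets[0].metainfo['classes'])  # the five S3DIS classes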
27 changes: 27 additions & 0 deletions configs/fcaf3d/fcaf3d_2xb8_s3dis-3d-5class.py
@@ -0,0 +1,27 @@
_base_ = [
'../_base_/models/fcaf3d.py', '../_base_/default_runtime.py',
'../_base_/datasets/s3dis-3d.py'
]

model = dict(bbox_head=dict(num_classes=5))

optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.001, weight_decay=0.0001),
clip_grad=dict(max_norm=10, norm_type=2))

# learning rate
param_scheduler = dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)

custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)]

# training schedule for 1x
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=12)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
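
As a usage sketch (not part of the PR), training with this config can be launched through mmengine's Runner; the 2xb8 in the file name denotes the intended setup of 2 GPUs with 8 samples each, FCAF3D additionally requires MinkowskiEngine to be installed, and the work_dir below is an arbitrary example rather than a value fixed by the PR. In practice the repository's tools/train.py entry point would normally be used instead.

# Hedged sketch: build a Runner directly from the new FCAF3D S3DIS config.
from mmengine.config import Config
from mmengine.runner import Runner
from mmdet3d.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/fcaf3d/fcaf3d_2xb8_s3dis-3d-5class.py')
cfg.work_dir = 'work_dirs/fcaf3d_s3dis'  # example output directory, not from the PR

runner = Runner.from_cfg(cfg)
runner.train()  # 12 epochs with validation at epoch 12, per train_cfg above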
43 changes: 22 additions & 21 deletions mmdet3d/datasets/convert_utils.py
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import List, Optional, Tuple, Union

import numpy as np
from nuscenes import NuScenes
from nuscenes.utils.geometry_utils import view_points
from pyquaternion import Quaternion
from shapely.geometry import MultiPoint, box
@@ -53,19 +53,20 @@
}


def get_nuscenes_2d_boxes(nusc, sample_data_token: str,
visibilities: List[str]):
"""Get the 2d / mono3d annotation records for a given `sample_data_token of
nuscenes dataset.
def get_nuscenes_2d_boxes(nusc: NuScenes, sample_data_token: str,
visibilities: List[str]) -> List[dict]:
"""Get the 2d / mono3d annotation records for a given `sample_data_token`
of nuscenes dataset.

Args:
nusc (:obj:`NuScenes`): NuScenes class.
sample_data_token (str): Sample data token belonging to a camera
keyframe.
visibilities (list[str]): Visibility filter.
visibilities (List[str]): Visibility filter.

Return:
list[dict]: List of 2d annotation record that belongs to the input
`sample_data_token`.
List[dict]: List of 2d annotation record that belongs to the input
`sample_data_token`.
"""

# Get the sample data and the sample corresponding to that sample data.
@@ -190,7 +191,7 @@ def get_kitti_style_2d_boxes(info: dict,
occluded: Tuple[int] = (0, 1, 2, 3),
annos: Optional[dict] = None,
mono3d: bool = True,
dataset: str = 'kitti'):
dataset: str = 'kitti') -> List[dict]:
"""Get the 2d / mono3d annotation records for a given info.

This function is used to get 2D/Mono3D annotations when loading annotations
@@ -202,7 +203,7 @@ def get_kitti_style_2d_boxes(info: dict,
belong to. In KITTI, typically only CAM 2 will be used,
and in Waymo, multi cameras could be used.
Defaults to 2.
occluded (tuple[int]): Integer (0, 1, 2, 3) indicating occlusion state:
occluded (Tuple[int]): Integer (0, 1, 2, 3) indicating occlusion state:
0 = fully visible, 1 = partly occluded, 2 = largely occluded,
3 = unknown, -1 = DontCare.
Defaults to (0, 1, 2, 3).
@@ -213,8 +214,8 @@ def get_kitti_style_2d_boxes(info: dict,
Defaults to 'kitti'.

Return:
list[dict]: List of 2d / mono3d annotation record that
belongs to the input camera id.
List[dict]: List of 2d / mono3d annotation record that
belongs to the input camera id.
"""
# Get calibration information
camera_intrinsic = info['calib'][f'P{cam_idx}']
@@ -336,20 +337,20 @@ def convert_annos(info: dict, cam_idx: int) -> dict:


def post_process_coords(
corner_coords: List, imsize: Tuple[int, int] = (1600, 900)
) -> Union[Tuple[float, float, float, float], None]:
corner_coords: List[int], imsize: Tuple[int] = (1600, 900)
) -> Union[Tuple[float], None]:
"""Get the intersection of the convex hull of the reprojected bbox corners
and the image canvas, return None if no intersection.

Args:
corner_coords (list[int]): Corner coordinates of reprojected
corner_coords (List[int]): Corner coordinates of reprojected
bounding box.
imsize (tuple[int]): Size of the image canvas.
imsize (Tuple[int]): Size of the image canvas.
Defaults to (1600, 900).

Return:
tuple[float]: Intersection of the convex hull of the 2D box
corners and the image canvas.
Tuple[float] or None: Intersection of the convex hull of the 2D box
corners and the image canvas.
"""
polygon_from_2d_box = MultiPoint(corner_coords).convex_hull
img_canvas = box(0, 0, imsize[0], imsize[1])
@@ -370,7 +371,7 @@ def post_process_coords(


def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
dataset: str) -> OrderedDict:
dataset: str) -> Union[dict, None]:
"""Generate one 2D annotation record given various information on top of
the 2D bounding box coordinates.

@@ -383,11 +384,11 @@ def generate_record(ann_rec: dict, x1: float, y1: float, x2: float, y2: float,
dataset (str): Name of dataset.

Returns:
dict: A sample 2d annotation record.
dict or None: A sample 2d annotation record.

- bbox_label (int): 2d box label id
- bbox_label_3d (int): 3d box label id
- bbox (list[float]): left x, top y, right x, bottom y of 2d box
- bbox (List[float]): left x, top y, right x, bottom y of 2d box
- bbox_3d_isvalid (bool): whether the box is valid
"""

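The post_process_coords docstring above describes clipping the convex hull of reprojected box corners to the image canvas. A small standalone illustration of that idea with shapely (independent of the PR; the corner coordinates are invented) could look like this:

# Illustration only: intersect a reprojected box hull with a 1600x900 canvas.
from shapely.geometry import MultiPoint, box

corner_coords = [(1550, 100), (1700, 120), (1650, 400), (1500, 380)]  # partly off-image
imsize = (1600, 900)

polygon_from_2d_box = MultiPoint(corner_coords).convex_hull
img_canvas = box(0, 0, imsize[0], imsize[1])

if polygon_from_2d_box.intersects(img_canvas):
    intersection = polygon_from_2d_box.intersection(img_canvas)
    min_x, min_y, max_x, max_y = intersection.bounds  # axis-aligned clipped 2D box
    print(min_x, min_y, max_x, max_y)
else:
    print(None)  # the box falls completely outside the image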
21 changes: 11 additions & 10 deletions mmdet3d/datasets/det3d_dataset.py
@@ -30,14 +30,14 @@ class Det3DDataset(BaseDataset):
information. Defaults to None.
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='velodyne', img='').
pipeline (list[dict]): Pipeline used for data processing.
pipeline (List[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input,
it usually has following keys:

- use_camera: bool
- use_lidar: bool
Defaults to `dict(use_lidar=True, use_camera=False)`
Defaults to dict(use_lidar=True, use_camera=False).
default_cam_key (str, optional): The default camera name adopted.
Defaults to None.
box_type_3d (str): Type of 3D box of this dataset.
@@ -78,7 +78,7 @@ def __init__(self,
box_type_3d: dict = 'LiDAR',
filter_empty_gt: bool = True,
test_mode: bool = False,
load_eval_anns=True,
load_eval_anns: bool = True,
file_client_args: dict = dict(backend='disk'),
show_ins_var: bool = False,
**kwargs) -> None:
@@ -158,7 +158,7 @@ def __init__(self,
def _remove_dontcare(self, ann_info: dict) -> dict:
"""Remove annotations that do not need to be cared.

-1 indicate dontcare in MMDet3d.
-1 indicates dontcare in MMDet3d.

Args:
ann_info (dict): Dict of annotation infos. The
@@ -186,7 +186,7 @@ def get_ann_info(self, index: int) -> dict:
index (int): Index of the annotation data to get.

Returns:
dict: annotation information.
dict: Annotation information.
"""
data_info = self.get_data_info(index)
# test model
@@ -197,7 +197,7 @@ def get_ann_info(self, index: int) -> dict:

return ann_info

def parse_ann_info(self, info: dict) -> Optional[dict]:
def parse_ann_info(self, info: dict) -> Union[dict, None]:
"""Process the `instances` in data info to `ann_info`.

In `Custom3DDataset`, we simply concatenate all the field
@@ -209,7 +209,7 @@ def parse_ann_info(self, info: dict) -> Optional[dict]:
info (dict): Info dict.

Returns:
dict | None: Processed `ann_info`
dict or None: Processed `ann_info`.
"""
# add s or gt prefix for most keys after concat
# we only process 3d annotations here, the corresponding
@@ -327,7 +327,7 @@ def parse_data_info(self, info: dict) -> dict:

return info

def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor):
def _show_ins_var(self, old_labels: np.ndarray,
new_labels: torch.Tensor) -> None:
"""Show variation of the number of instances before and after through
the pipeline.

@@ -356,7 +357,7 @@ def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor):
'The number of instances per category after and before '
f'through pipeline:\n{table.table}', 'current')

def prepare_data(self, index: int) -> Optional[dict]:
def prepare_data(self, index: int) -> Union[dict, None]:
"""Data preparation for both training and testing stage.

Called by `__getitem__` of dataset.
@@ -365,7 +366,7 @@ def prepare_data(self, index: int) -> Optional[dict]:
index (int): Index for accessing the target data.

Returns:
dict | None: Data dict of the corresponding index.
dict or None: Data dict of the corresponding index.
"""
ori_input_dict = self.get_data_info(index)

6 changes: 3 additions & 3 deletions mmdet3d/datasets/kitti_dataset.py
@@ -18,10 +18,10 @@ class KittiDataset(Det3DDataset):
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict]): Pipeline used for data processing.
pipeline (List[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input.
Defaults to `dict(use_lidar=True)`.
Defaults to dict(use_lidar=True).
default_cam_key (str): The default camera name adopted.
Defaults to 'CAM2'.
box_type_3d (str): Type of 3D box of this dataset.
@@ -38,7 +38,7 @@ class KittiDataset(Det3DDataset):
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
pcd_limit_range (list[float]): The range of point cloud used to filter
pcd_limit_range (List[float]): The range of point cloud used to filter
invalid predicted boxes.
Defaults to [0, -40, -3, 70.4, 40, 0.0].
"""
2 changes: 1 addition & 1 deletion mmdet3d/datasets/lyft_dataset.py
@@ -21,7 +21,7 @@ class LyftDataset(Det3DDataset):
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict]): Pipeline used for data processing.
pipeline (List[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_camera=False, use_lidar=True).