Skip to content

Commit

Permalink
polish code
Browse files Browse the repository at this point in the history
  • Loading branch information
JingweiZhang12 committed Oct 11, 2022
1 parent 9048654 commit 2d308d8
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions mmdet3d/datasets/det3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class Det3DDataset(BaseDataset):
which can be used in Evaluator. Defaults to True.
file_client_args (dict, optional): Configuration of file client.
Defaults to dict(backend='disk').
show_ins_num_var (bool, optional): Whether to show variance of the
number of instances before and after through pipeline. Defaults to
False.
show_ins_var (bool, optional): For debug purpose. Whether to show
variance of the number of instances before and after through
pipeline. Defaults to False.
"""

def __init__(self,
Expand All @@ -78,7 +78,7 @@ def __init__(self,
test_mode: bool = False,
load_eval_anns=True,
file_client_args: dict = dict(backend='disk'),
show_ins_num_var: bool = False,
show_ins_var: bool = False,
**kwargs) -> None:
# init file client
self.file_client = mmengine.FileClient(**file_client_args)
Expand Down Expand Up @@ -140,15 +140,16 @@ def __init__(self,

# used for showing variance of the number of instances before and
# after through the pipeline
self.show_ins_num_var = show_ins_num_var
self.show_ins_var = show_ins_var

# show statistics of this dataset
logger: MMLogger = MMLogger.get_current_instance()
logger.info('-' * 30)
logger.info(f'The length of the dataset: {len(self)}')
content_show = ''
for cat_name, num in self.ins_num_per_cat.items():
content_show += f'\n{cat_name}'.ljust(25) + f'{num}'
content_show += f'\n{cat_name}'.ljust(25) + '|' + f' {num}'.rjust(
10)
logger.info(
f'The number of instances per category in the dataset:{content_show}' # noqa: E501
)
Expand Down Expand Up @@ -345,7 +346,8 @@ def _show_ins_num_var(self, old_labels: np.ndarray,
content_show = ''
for cat_name, num in ori_num_per_cat.items():
new_num = new_num_per_cat.get(cat_name, 0)
content_show += f'\n{cat_name}'.ljust(25) + f'{new_num}/{num}'
content_show += f'\n{cat_name} '.ljust(
25) + '|' + f' {new_num}/{num}'.rjust(10)
logger.info('The number of instances per category after and before '
f'through pipeline: {content_show}')

Expand Down Expand Up @@ -384,7 +386,7 @@ def prepare_data(self, index: int) -> Optional[dict]:
example['data_samples'].gt_instances_3d.labels_3d) == 0:
return None

if self.show_ins_num_var:
if self.show_ins_var:
self._show_ins_num_var(
ori_input_dict['ann_info']['gt_labels_3d'],
example['data_samples'].gt_instances_3d.labels_3d)
Expand Down

0 comments on commit 2d308d8

Please sign in to comment.