-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
anchor_head_single.py
75 lines (60 loc) · 2.86 KB
/
anchor_head_single.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import torch.nn as nn
from .anchor_head_template import AnchorHeadTemplate
class AnchorHeadSingle(AnchorHeadTemplate):
    """Anchor-based dense head built from 1x1 convolutions over one 2D feature map.

    Three parallel ``nn.Conv2d`` layers (kernel size 1) read
    ``data_dict['spatial_features_2d']``:

    * ``conv_cls``     -> ``num_anchors_per_location * num_class`` channels
    * ``conv_box``     -> ``num_anchors_per_location * box_coder.code_size`` channels
    * ``conv_dir_cls`` -> ``num_anchors_per_location * NUM_DIR_BINS`` channels
      (only when ``model_cfg.USE_DIRECTION_CLASSIFIER`` is truthy, else ``None``)

    Target assignment, ``forward_ret_dict``, ``box_coder`` and box decoding come
    from ``AnchorHeadTemplate``, which is defined outside this file.
    """

    def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
                 predict_boxes_when_training=True, **kwargs):
        """Build the prediction convolutions and initialise their weights.

        Args:
            model_cfg: head configuration; must support ``.get(key, default)`` and
                attribute access (``NUM_DIR_BINS`` is read when the direction
                classifier is enabled).
            input_channels: channel count of the incoming 2D feature map.
            num_class: forwarded to the base class; also read back as
                ``self.num_class`` to size the classification conv.
            class_names: forwarded to the base class.
            grid_size: forwarded to the base class.
            point_cloud_range: forwarded to the base class.
            predict_boxes_when_training: forwarded to the base class; ``forward``
                reads it back as ``self.predict_boxes_when_training``.
            **kwargs: accepted and ignored so callers may pass extra arguments.
        """
        super().__init__(
            model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size,
            point_cloud_range=point_cloud_range,
            predict_boxes_when_training=predict_boxes_when_training
        )

        # The base class exposes an iterable of anchor counts (presumably one entry
        # per anchor configuration -- confirm in AnchorHeadTemplate); this head only
        # needs the total per feature-map location.
        self.num_anchors_per_location = sum(self.num_anchors_per_location)

        self.conv_cls = nn.Conv2d(
            input_channels, self.num_anchors_per_location * self.num_class,
            kernel_size=1
        )
        self.conv_box = nn.Conv2d(
            input_channels, self.num_anchors_per_location * self.box_coder.code_size,
            kernel_size=1
        )

        # Test the flag's truthiness rather than `is not None`: an explicit
        # `USE_DIRECTION_CLASSIFIER: False` must disable the direction branch
        # instead of silently building (and later running) it.
        if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
            self.conv_dir_cls = nn.Conv2d(
                input_channels,
                self.num_anchors_per_location * self.model_cfg.NUM_DIR_BINS,
                kernel_size=1
            )
        else:
            self.conv_dir_cls = None

        self.init_weights()

    def init_weights(self):
        """Initialise the classification bias and the box-regression weights.

        The classification bias is set to ``-log((1 - pi) / pi)`` with ``pi = 0.01``,
        i.e. the logit whose sigmoid equals ``pi``. The box-regression weights are
        drawn from a normal distribution with mean 0 and std 0.001. Every other
        parameter keeps PyTorch's default initialisation.
        """
        pi = 0.01
        nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))
        nn.init.normal_(self.conv_box.weight, mean=0, std=0.001)

    def forward(self, data_dict):
        """Run the head on ``data_dict['spatial_features_2d']``.

        Side effects:
            * ``self.forward_ret_dict`` receives ``'cls_preds'``, ``'box_preds'`` and,
              when the direction classifier exists, ``'dir_cls_preds'``. In training
              mode it is additionally updated with the dict returned by
              ``self.assign_targets`` (which needs ``data_dict['gt_boxes']``).
            * When not training, or when ``self.predict_boxes_when_training`` is set,
              ``data_dict`` gains ``'batch_cls_preds'``, ``'batch_box_preds'`` and
              ``'cls_preds_normalized' = False`` (needs ``data_dict['batch_size']``).

        Args:
            data_dict: batch dictionary; see the keys listed above.

        Returns:
            The same ``data_dict`` object, possibly with the extra keys above.
        """
        spatial_features_2d = data_dict['spatial_features_2d']

        cls_preds = self.conv_cls(spatial_features_2d)
        box_preds = self.conv_box(spatial_features_2d)

        # Move the channel axis last so the per-anchor values of a location are
        # adjacent in the final dimension.
        cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()  # [N, H, W, C]
        box_preds = box_preds.permute(0, 2, 3, 1).contiguous()  # [N, H, W, C]

        self.forward_ret_dict['cls_preds'] = cls_preds
        self.forward_ret_dict['box_preds'] = box_preds

        if self.conv_dir_cls is not None:
            dir_cls_preds = self.conv_dir_cls(spatial_features_2d)
            dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
            self.forward_ret_dict['dir_cls_preds'] = dir_cls_preds
        else:
            dir_cls_preds = None

        if self.training:
            targets_dict = self.assign_targets(
                gt_boxes=data_dict['gt_boxes']
            )
            self.forward_ret_dict.update(targets_dict)

        if not self.training or self.predict_boxes_when_training:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=data_dict['batch_size'],
                cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
            )
            data_dict['batch_cls_preds'] = batch_cls_preds
            data_dict['batch_box_preds'] = batch_box_preds
            # NOTE(review): presumably tells downstream code that the class scores
            # are raw outputs that still need normalising -- confirm with consumers.
            data_dict['cls_preds_normalized'] = False

        return data_dict