Skip to content

Commit

Permalink
[Fix] Force the inputs of get_bboxes in yolox_head to float32. (#7324)
Browse files Browse the repository at this point in the history
* Fix softnms bug

* Add force_fp32 in corner_head and centripetal_head
  • Loading branch information
jbwang1997 authored and ZwwWayne committed Mar 16, 2022
1 parent e76117b commit 62feea5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 1 deletion.
3 changes: 3 additions & 0 deletions mmdet/models/dense_heads/centripetal_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, normal_init
from mmcv.ops import DeformConv2d
from mmcv.runner import force_fp32

from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
Expand Down Expand Up @@ -203,6 +204,7 @@ def forward_single(self, x, lvl_ind):
]
return result_list

@force_fp32()
def loss(self,
tl_heats,
br_heats,
Expand Down Expand Up @@ -361,6 +363,7 @@ def loss_single(self, tl_hmp, br_hmp, tl_off, br_off, tl_guiding_shift,

return det_loss, off_loss, guiding_loss, centripetal_loss

@force_fp32()
def get_bboxes(self,
tl_heats,
br_heats,
Expand Down
5 changes: 4 additions & 1 deletion mmdet/models/dense_heads/corner_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, bias_init_with_prob
from mmcv.ops import CornerPool, batched_nms
from mmcv.runner import BaseModule
from mmcv.runner import BaseModule, force_fp32

from mmdet.core import multi_apply
from ..builder import HEADS, build_loss
Expand Down Expand Up @@ -152,6 +152,7 @@ def __init__(self,
self.train_cfg = train_cfg
self.test_cfg = test_cfg

self.fp16_enabled = False
self._init_layers()

def _make_layers(self, out_channels, in_channels=256, feat_channels=256):
Expand Down Expand Up @@ -509,6 +510,7 @@ def get_targets(self,

return target_result

@force_fp32()
def loss(self,
tl_heats,
br_heats,
Expand Down Expand Up @@ -649,6 +651,7 @@ def loss_single(self, tl_hmp, br_hmp, tl_emb, br_emb, tl_off, br_off,

return det_loss, pull_loss, push_loss, off_loss

@force_fp32()
def get_bboxes(self,
tl_heats,
br_heats,
Expand Down
1 change: 1 addition & 0 deletions mmdet/models/dense_heads/yolox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def forward(self, feats):
self.multi_level_conv_reg,
self.multi_level_conv_obj)

@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'objectnesses'))
def get_bboxes(self,
cls_scores,
bbox_preds,
Expand Down

0 comments on commit 62feea5

Please sign in to comment.