diff --git a/mmrotate/core/post_processing/bbox_nms_rotated.py b/mmrotate/core/post_processing/bbox_nms_rotated.py index b8abcb93e..4affc2c9c 100644 --- a/mmrotate/core/post_processing/bbox_nms_rotated.py +++ b/mmrotate/core/post_processing/bbox_nms_rotated.py @@ -39,7 +39,7 @@ def multiclass_nms_rotated(multi_bboxes, multi_scores.size(0), num_classes, 5) scores = multi_scores[:, :-1] - labels = torch.arange(num_classes, dtype=torch.long) + labels = torch.arange(num_classes, dtype=torch.long, device=scores.device) labels = labels.view(1, -1).expand_as(scores) bboxes = bboxes.reshape(-1, 5) scores = scores.reshape(-1)