diff --git a/projects/easydeploy/nms/ort_nms.py b/projects/easydeploy/nms/ort_nms.py index aad93cf05..cc8145b36 100644 --- a/projects/easydeploy/nms/ort_nms.py +++ b/projects/easydeploy/nms/ort_nms.py @@ -29,7 +29,7 @@ def select_nms_index(scores: Tensor, batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1) batched_labels = batched_labels.where( (batch_inds == batch_template.unsqueeze(1)), - batched_labels.new_ones(1) * -1) + batched_dets.new_ones(1, dtype=torch.int64) * -1) N = batched_dets.shape[0]