Skip to content

Commit

Permalink
Fix a bug related to MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Sep 4, 2024
1 parent d0e82b3 commit 085e0df
Showing 1 changed file with 4 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,13 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
except ImportError:
raise ModuleNotFoundError("This metric requires torchvision to be installed.")

precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32

if iou_thresholds is None:
iou_thresholds = torch.linspace(0.5, 0.95, 10, device=device, dtype=precision)
iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double)

self._iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds")
self._iou_thresholds = self._iou_thresholds.to(device=device, dtype=precision)

if rec_thresholds is None:
rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=precision)

self._rec_thresholds = self._setup_thresholds(rec_thresholds, "rec_thresholds")
self._rec_thresholds = self._rec_thresholds.to(device=device, dtype=precision)
rec_thresholds = torch.linspace(0, 1, 101, dtype=torch.double)

self._num_classes = num_classes
self._area_range = area_range
Expand All @@ -130,6 +124,8 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
rec_thresholds=rec_thresholds,
class_mean=None,
)
precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32
self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision)

@reinit__is_reduced
def reset(self) -> None:
Expand Down

0 comments on commit 085e0df

Please sign in to comment.