Skip to content

Commit

Permalink
Fix a bug related to MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Sep 4, 2024
1 parent 4b6afdd commit d0e82b3
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,11 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
if self._device == torch.device("mps") and scores.dtype == torch.double:
scores = scores.to(dtype=torch.float32)
self._scores.append(scores.to(self._device))
self._y_pred_labels.append(pred_labels.to(device=self._device))
self._y_pred_labels.append(pred_labels.to(dtype=torch.int, device=self._device))

@sync_all_reduce("_y_true_count")
def _compute(self) -> torch.Tensor:
pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.long, self._device)
pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.int, self._device)
TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device)
FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device)
fp_precision = torch.double if self._device != torch.device("mps") else torch.float32
Expand Down

0 comments on commit d0e82b3

Please sign in to comment.