Skip to content

Commit

Permalink
Revert MPS fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Sep 4, 2024
1 parent 0444933 commit c433718
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ Complete list of metrics
MutualInformation
ObjectDetectionAvgPrecisionRecall
CommonObjDetectionMetrics
vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list
precision.Precision
PSNR
recall.Recall
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
from ignite.metrics.vision.object_detection_average_precision_recall import (
coco_tensor_list_to_dict_list,
CommonObjDetectionMetrics,
ObjectDetectionAvgPrecisionRecall,
)
Expand Down Expand Up @@ -94,4 +95,5 @@
"MeanAveragePrecision",
"ObjectDetectionAvgPrecisionRecall",
"CommonObjDetectionMetrics",
"coco_tensor_list_to_dict_list",
]
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
Expand All @@ -10,7 +9,7 @@
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce


def tensor_list_to_dict_list(
def coco_tensor_list_to_dict_list(
output: Tuple[
Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]],
Union[List[torch.Tensor], List[Dict[str, torch.Tensor]]],
Expand Down Expand Up @@ -83,8 +82,8 @@ def __init__(
max_detections_per_image_per_class: maximum number of detections per class in each image to consider
for evaluation. The most confident ones are selected.
output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s
``process_function``'s output into the form expected by the metric. An already provided example
is :func:`~ignite.metrics.vision.object_detection_average_precision_recall.tensor_list_to_dict_list`
``process_function``'s output into the form expected by the metric. An already provided example is
:func:`~ignite.metrics.vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list`
which accepts `y_pred` and `y` as lists of tensors and transforms them to the expected format.
Default is the identity function.
device: specifies which device updates are accumulated on. Setting the
Expand Down Expand Up @@ -235,10 +234,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
Returns:
average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions.
"""
mps_cpu_fallback = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0")
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = mps_cpu_fallback
rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1))
rec_thresh_indices = (
torch.searchsorted(recall, rec_thresholds)
Expand Down
4 changes: 2 additions & 2 deletions tests/ignite/metrics/vision/test_object_detection_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import ignite.distributed as idist
from ignite.engine import Engine
from ignite.metrics import CommonObjDetectionMetrics, ObjectDetectionAvgPrecisionRecall
from ignite.metrics.vision.object_detection_average_precision_recall import tensor_list_to_dict_list
from ignite.metrics.vision.object_detection_average_precision_recall import coco_tensor_list_to_dict_list
from ignite.utils import manual_seed

torch.set_printoptions(linewidth=200)
Expand Down Expand Up @@ -957,7 +957,7 @@ def test_tensor_list_to_dict_list():
]
for y_pred in y_preds:
for y in ys:
y_pred_new, y_new = tensor_list_to_dict_list((y_pred, y))
y_pred_new, y_new = coco_tensor_list_to_dict_list((y_pred, y))
if isinstance(y_pred[0], dict):
assert y_pred_new is y_pred
else:
Expand Down

0 comments on commit c433718

Please sign in to comment.