Skip to content

Commit

Permalink
Apply comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Sep 4, 2024
1 parent c433718 commit dacf407
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ Complete list of metrics
MultiLabelConfusionMatrix
MutualInformation
ObjectDetectionAvgPrecisionRecall
CommonObjDetectionMetrics
CommonObjectDetectionMetrics
vision.object_detection_average_precision_recall.coco_tensor_list_to_dict_list
precision.Precision
PSNR
Expand Down
4 changes: 2 additions & 2 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy
from ignite.metrics.vision.object_detection_average_precision_recall import (
coco_tensor_list_to_dict_list,
CommonObjDetectionMetrics,
CommonObjectDetectionMetrics,
ObjectDetectionAvgPrecisionRecall,
)

Expand Down Expand Up @@ -94,6 +94,6 @@
"ROC_AUC",
"MeanAveragePrecision",
"ObjectDetectionAvgPrecisionRecall",
"CommonObjDetectionMetrics",
"CommonObjectDetectionMetrics",
"coco_tensor_list_to_dict_list",
]
14 changes: 11 additions & 3 deletions ignite/metrics/vision/object_detection_average_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,15 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
Returns:
average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions.
"""
precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1)
if precision.device.type == "mps":
# Manual fallback to CPU if precision is on MPS due to the error:
# NotImplementedError: The operator 'aten::_cummax_helper' is not currently implemented for the MPS device
device = precision.device
precision_integrand = precision.flip(-1).cpu()
precision_integrand = precision_integrand.cummax(dim=-1).values
precision_integrand = precision_integrand.to(device=device).flip(-1)
else:
precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1)
rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1))
rec_thresh_indices = (
torch.searchsorted(recall, rec_thresholds)
Expand Down Expand Up @@ -386,9 +394,9 @@ def compute(self) -> Tuple[float, float]:
return ap, ar


class CommonObjDetectionMetrics(MetricGroup):
class CommonObjectDetectionMetrics(MetricGroup):
"""
Common Object detection metrics. Included metrics are as follows:
Common Object Detection metrics. Included metrics are as follows:
=============== ==========================================
**Metric name** **Description**
Expand Down
6 changes: 3 additions & 3 deletions tests/ignite/metrics/vision/test_object_detection_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.metrics import CommonObjDetectionMetrics, ObjectDetectionAvgPrecisionRecall
from ignite.metrics import CommonObjectDetectionMetrics, ObjectDetectionAvgPrecisionRecall
from ignite.metrics.vision.object_detection_average_precision_recall import coco_tensor_list_to_dict_list
from ignite.utils import manual_seed

Expand Down Expand Up @@ -895,7 +895,7 @@ def test_compute(sample):
print(all_res)
assert np.allclose(all_res, sample.mAP)

common_metrics = CommonObjDetectionMetrics(device=device)
common_metrics = CommonObjectDetectionMetrics(device=device)
common_metrics.update(sample.data)
res = common_metrics.compute()
common_metrics_res = [
Expand Down Expand Up @@ -1021,7 +1021,7 @@ def test_distrib_update_compute(distributed, sample):
all_res = [AP_50_95, AP_50, AP_75, AP_S, AP_M, AP_L, AR_1, AR_10, AR_100, AR_S, AR_M, AR_L]
assert np.allclose(all_res, sample.mAP)

common_metrics = CommonObjDetectionMetrics(device=device)
common_metrics = CommonObjectDetectionMetrics(device=device)
common_metrics.update((y_pred_rank, y_rank))
res = common_metrics.compute()
common_metrics_res = [
Expand Down

0 comments on commit dacf407

Please sign in to comment.