Skip to content

Commit

Permalink
Fix classwise computation in IoU metric (#1924)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit eeb40e9)
  • Loading branch information
SkafteNicki authored and Borda committed Aug 28, 2023
1 parent 1525a28 commit 284205f
Show file tree
Hide file tree
Showing 17 changed files with 624 additions and 650 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)


- Fixed bug in detection intersection metrics when `class_metrics=True` resulting in wrong values ([#1924](https://github.com/Lightning-AI/torchmetrics/pull/1924))


- Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028)


## [1.1.0] - 2023-08-22

### Added
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
export FREEZE_REQUIREMENTS=1
# assume you have installed need packages
export SPHINX_MOCK_REQUIREMENTS=1
export SPHINX_FETCH_ASSETS=0

clean:
# clean all temp runs
Expand Down
14 changes: 8 additions & 6 deletions src/torchmetrics/detection/ciou.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
detection boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
for the boxes.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.
Expand All @@ -48,14 +46,14 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground
truth boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed ground truth
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``ciou_dict``: A dictionary containing the following key-values:
- ciou: (:class:`~torch.Tensor`)
- ciou: (:class:`~torch.Tensor`) with overall ciou value over all classes and samples.
- ciou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True``
Args:
Expand All @@ -65,6 +63,9 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
class_metrics:
Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels:
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
between all pairs of boxes.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Expand All @@ -86,7 +87,7 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion):
... ]
>>> metric = CompleteIntersectionOverUnion()
>>> metric(preds, target)
{'ciou': tensor(-0.5694)}
{'ciou': tensor(0.8611)}
Raises:
ModuleNotFoundError:
Expand All @@ -105,14 +106,15 @@ def __init__(
box_format: str = "xyxy",
iou_threshold: Optional[float] = None,
class_metrics: bool = False,
respect_labels: bool = True,
**kwargs: Any,
) -> None:
if not _TORCHVISION_GREATER_EQUAL_0_13:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
)
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

@staticmethod
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
Expand Down
12 changes: 7 additions & 5 deletions src/torchmetrics/detection/diou.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
detection boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
for the boxes.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.
Expand All @@ -55,7 +53,7 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
- ``diou_dict``: A dictionary containing the following key-values:
- diou: (:class:`~torch.Tensor`)
- diou: (:class:`~torch.Tensor`) with overall diou value over all classes and samples.
- diou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True``
Args:
Expand All @@ -65,6 +63,9 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
class_metrics:
Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels:
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
between all pairs of boxes.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Expand All @@ -86,7 +87,7 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion):
... ]
>>> metric = DistanceIntersectionOverUnion()
>>> metric(preds, target)
{'diou': tensor(-0.0694)}
{'diou': tensor(0.8611)}
Raises:
ModuleNotFoundError:
Expand All @@ -105,14 +106,15 @@ def __init__(
box_format: str = "xyxy",
iou_threshold: Optional[float] = None,
class_metrics: bool = False,
respect_labels: bool = True,
**kwargs: Any,
) -> None:
if not _TORCHVISION_GREATER_EQUAL_0_13:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
)
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

@staticmethod
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
Expand Down
12 changes: 7 additions & 5 deletions src/torchmetrics/detection/giou.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
- ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes``
detection boxes of the format specified in the constructor.
By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates.
- ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores
for the boxes.
- ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection
classes for the boxes.
Expand All @@ -55,7 +53,7 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
- ``giou_dict``: A dictionary containing the following key-values:
- giou: (:class:`~torch.Tensor`)
- giou: (:class:`~torch.Tensor`) with overall giou value over all classes and samples.
- giou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class metrics=True``
Args:
Expand All @@ -65,6 +63,9 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored.
class_metrics:
Option to enable per-class metrics for IoU. Has a performance impact.
respect_labels:
Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou
between all pairs of boxes.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Expand All @@ -86,7 +87,7 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion):
... ]
>>> metric = GeneralizedIntersectionOverUnion()
>>> metric(preds, target)
{'giou': tensor(-0.0694)}
{'giou': tensor(0.8613)}
Raises:
ModuleNotFoundError:
Expand All @@ -105,9 +106,10 @@ def __init__(
box_format: str = "xyxy",
iou_threshold: Optional[float] = None,
class_metrics: bool = False,
respect_labels: bool = True,
**kwargs: Any,
) -> None:
super().__init__(box_format, iou_threshold, class_metrics, **kwargs)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

@staticmethod
def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor:
Expand Down
7 changes: 5 additions & 2 deletions src/torchmetrics/detection/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _input_validator(
preds: Sequence[Dict[str, Tensor]],
targets: Sequence[Dict[str, Tensor]],
iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox",
ignore_score: bool = False,
) -> None:
"""Ensure the correct input format of `preds` and `targets`."""
if isinstance(iou_type, str):
Expand All @@ -39,7 +40,7 @@ def _input_validator(
f"Expected argument `preds` and `target` to have the same length, but got {len(preds)} and {len(targets)}"
)

for k in [*item_val_name, "scores", "labels"]:
for k in [*item_val_name, "labels"] + (["scores"] if not ignore_score else []):
if any(k not in p for p in preds):
raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key")

Expand All @@ -50,7 +51,7 @@ def _input_validator(
for ivn in item_val_name:
if any(type(pred[ivn]) is not Tensor for pred in preds):
raise ValueError(f"Expected all {ivn} in `preds` to be of type Tensor")
if any(type(pred["scores"]) is not Tensor for pred in preds):
if not ignore_score and any(type(pred["scores"]) is not Tensor for pred in preds):
raise ValueError("Expected all scores in `preds` to be of type Tensor")
if any(type(pred["labels"]) is not Tensor for pred in preds):
raise ValueError("Expected all labels in `preds` to be of type Tensor")
Expand All @@ -67,6 +68,8 @@ def _input_validator(
f"Input '{ivn}' and labels of sample {i} in targets have a"
f" different length (expected {item[ivn].size(0)} labels, got {item['labels'].size(0)})"
)
if ignore_score:
return
for i, item in enumerate(preds):
for ivn in item_val_name:
if not (item[ivn].size(0) == item["labels"].size(0) == item["scores"].size(0)):
Expand Down
Loading

0 comments on commit 284205f

Please sign in to comment.