Skip to content

Commit

Permalink
Fix label compare of distance method (#1205)
Browse files Browse the repository at this point in the history
- Fix label comparison result of distance method
Signed-off-by: Kim, Vinnam <[email protected]>
Co-authored-by: Vinnam Kim <[email protected]>
  • Loading branch information
sooahleex authored Nov 28, 2023
1 parent 27e3804 commit 4b9f55f
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 29 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(<https://github.com/openvinotoolkit/datumaro/pull/1202>, <https://github.com/openvinotoolkit/datumaro/pull/1207>)
- Update document to correct wrong `datum project import` command and add filtering example to filter out items containing annotations.
(<https://github.com/openvinotoolkit/datumaro/pull/1210>)
- Fix label comparison in the distance compare method
(<https://github.com/openvinotoolkit/datumaro/pull/1205>)

## 16/11/2023 - Release 1.5.1
### Enhancements
Expand Down
27 changes: 11 additions & 16 deletions src/datumaro/cli/util/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import warnings
from collections import Counter
from enum import Enum, auto
from itertools import zip_longest
from typing import TYPE_CHECKING, Union

import cv2
Expand Down Expand Up @@ -83,22 +82,18 @@ def save(self, a: IDataset, b: IDataset):
if len(a) != len(b):
print("Datasets have different lengths: %s vs %s" % (len(a), len(b)))

a_classes = a.categories().get(AnnotationType.label, LabelCategories())
b_classes = b.categories().get(AnnotationType.label, LabelCategories())
class_mismatch = [
(idx, a_cls, b_cls)
for idx, (a_cls, b_cls) in enumerate(zip_longest(a_classes, b_classes))
if getattr(a_cls, "name", None) != getattr(b_cls, "name", None)
]
if class_mismatch:
a_classes = set(a.get_label_cat_names())
b_classes = set(b.get_label_cat_names())

if a_classes ^ b_classes:
print("Datasets have mismatching labels:")
for idx, a_class, b_class in class_mismatch:
if a_class and b_class:
print(" #%s: %s != %s" % (idx, a_class.name, b_class.name))
elif a_class:
print(" #%s: > %s" % (idx, a_class.name))
else:
print(" #%s: < %s" % (idx, b_class.name))

for idx, diff in enumerate(a_classes - b_classes):
print(" #%s: > %s" % (idx, diff))

for idx, diff in enumerate(b_classes - a_classes, start=len((a_classes - b_classes))):
print(" #%s: < %s" % (idx, diff))

self._a_classes = a.categories().get(AnnotationType.label)
self._b_classes = b.categories().get(AnnotationType.label)

Expand Down
6 changes: 6 additions & 0 deletions src/datumaro/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,12 @@ def get_datasetitem_by_path(self, path):
path = osp.join(self._source_path, path)
return self._data.get_datasetitem_by_path(path)

def get_label_cat_names(self):
    """Return the names of this dataset's label categories as a list.

    The label categories are looked up under ``AnnotationType.label``; when
    the dataset defines none, an empty ``LabelCategories`` is used instead,
    so the result is an empty list. Names appear in the iteration order of
    the categories object.
    """
    label_categories = self._data.categories().get(AnnotationType.label, LabelCategories())
    names = []
    for label in label_categories:
        names.append(label.name)
    return names

def get_subset_info(self) -> str:
return (
f"{subset_name}: # of items={len(self.get_subset(subset_name))}, "
Expand Down
32 changes: 19 additions & 13 deletions tests/integration/cli/test_compare.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) 2021-2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import os
import os.path as osp
from unittest import TestCase

import numpy as np

Expand Down Expand Up @@ -29,10 +31,10 @@
from tests.utils.test_utils import run_datum as run


class CompareTest(TestCase):
class CompareTest:
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_compare_projects(self): # just a smoke test
label_categories1 = LabelCategories.from_iterable(["x", "a", "b", "y"])
def test_can_compare_projects(self, capsys): # just a smoke test
label_categories1 = LabelCategories.from_iterable(["x", "a", "b", "y", "z"])
mask_categories1 = MaskCategories.generate(len(label_categories1))

point_categories1 = PointsCategories()
Expand Down Expand Up @@ -101,7 +103,7 @@ def test_can_compare_projects(self): # just a smoke test
},
)

label_categories2 = LabelCategories.from_iterable(["a", "b", "x", "y"])
label_categories2 = LabelCategories.from_iterable(["a", "b", "c", "x", "y"])
mask_categories2 = MaskCategories.generate(len(label_categories2))

point_categories2 = PointsCategories()
Expand Down Expand Up @@ -177,10 +179,15 @@ def test_can_compare_projects(self): # just a smoke test
) as visualizer:
visualizer.save(dataset1, dataset2)

self.assertNotEqual(0, os.listdir(osp.join(test_dir)))
expected_output1 = "> z"
expected_output2 = "< c"
captured = capsys.readouterr()
assert expected_output1 in captured.out
assert expected_output2 in captured.out
assert 0 != os.listdir(osp.join(test_dir))

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_run_distance_diff(self):
def test_can_run_distance_diff(self, helper_tc):
dataset1 = Dataset.from_iterable(
[
DatasetItem(
Expand Down Expand Up @@ -219,7 +226,7 @@ def test_can_run_distance_diff(self):

result_dir = osp.join(test_dir, "cmp_result")
run(
self,
helper_tc,
"compare",
dataset1_url + ":coco",
dataset2_url + ":voc",
Expand All @@ -228,11 +235,10 @@ def test_can_run_distance_diff(self):
"-o",
result_dir,
)

self.assertEqual({"bbox_confusion.png", "train"}, set(os.listdir(result_dir)))
assert {"bbox_confusion.png", "train"} == set(os.listdir(result_dir))

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_run_equality_diff(self):
def test_can_run_equality_diff(self, helper_tc):
dataset1 = Dataset.from_iterable(
[
DatasetItem(
Expand Down Expand Up @@ -271,7 +277,7 @@ def test_can_run_equality_diff(self):

result_dir = osp.join(test_dir, "cmp_result")
run(
self,
helper_tc,
"compare",
dataset1_url + ":coco",
dataset2_url + ":voc",
Expand All @@ -281,4 +287,4 @@ def test_can_run_equality_diff(self):
result_dir,
)

self.assertEqual({"equality_compare.json"}, set(os.listdir(result_dir)))
assert {"equality_compare.json"} == set(os.listdir(result_dir))
17 changes: 17 additions & 0 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2017,6 +2017,23 @@ def test_can_check_media_type_on_caching(self):
with self.assertRaises(MediaTypeError):
dataset.init_cache()

@mark_requirement(Requirements.DATUM_GENERIC_MEDIA)
def test_get_label_cat_names(self):
    """``get_label_cat_names`` returns the label names the dataset was built with."""
    # A single annotated item is enough; the names come from the categories.
    item = DatasetItem(
        id=100,
        subset="train",
        media=Image.from_numpy(data=np.ones((10, 6, 3))),
        annotations=[Bbox(1, 2, 3, 4, label=1)],
    )
    dataset = Dataset.from_iterable([item], categories=["a", "b", "c"])

    actual_names = dataset.get_label_cat_names()

    self.assertEqual(actual_names, ["a", "b", "c"])


class DatasetItemTest(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
Expand Down

0 comments on commit 4b9f55f

Please sign in to comment.