Fix task chain for Det -> Cls / Seg #4105

Merged · 10 commits · Nov 8, 2024
src/otx/algo/utils/xai_utils.py (2 changes: 1 addition & 1 deletion)

@@ -225,7 +225,7 @@ def _get_image_data_name(
subset = datamodule.subsets[subset_name]
item = subset.dm_subset[img_id]
img = item.media_as(Image)
- img_data, _ = subset._get_img_data_and_shape(img) # noqa: SLF001
+ img_data, _, _ = subset._get_img_data_and_shape(img) # noqa: SLF001
image_save_name = "".join([char if char.isalnum() else "_" for char in item.id])
return img_data, image_save_name

src/otx/core/data/dataset/anomaly.py (2 changes: 1 addition & 1 deletion)

@@ -79,7 +79,7 @@ def _get_item_impl(
datumaro_item = self.dm_subset[index]
img = datumaro_item.media_as(Image)
# returns image in RGB format if self.image_color_channel is RGB
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

label = self._get_label(datumaro_item)

src/otx/core/data/dataset/base.py (58 changes: 48 additions & 10 deletions)

@@ -8,7 +8,7 @@
from abc import abstractmethod
from collections.abc import Iterable
from contextlib import contextmanager
- from typing import TYPE_CHECKING, Callable, Generic, Iterator, List, Union
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, List, Union

import cv2
import numpy as np
@@ -92,6 +92,7 @@
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.to_tv_image = to_tv_image

if self.dm_subset.categories():
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
else:
@@ -141,11 +142,31 @@
msg = f"Reach the maximum refetch number ({self.max_refetch})"
raise RuntimeError(msg)

- def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, int]]:
+ def _get_img_data_and_shape(
+ self,
+ img: Image,
+ roi: dict[str, Any] | None = None,
+ ) -> tuple[np.ndarray, tuple[int, int], dict[str, Any] | None]:
+ """Get image data and shape.
+
+ This method is used to get image data and shape from Datumaro image object.
+ If ROI is provided, the image data is extracted from the ROI.
+
+ Args:
+ img (Image): Image object from Datumaro.
+ roi (dict[str, Any] | None, Optional): Region of interest.
+ Represented by dict with coordinates and some meta information.
+
+ Returns:
+ The image data, shape, and ROI meta information
+ """
key = img.path if isinstance(img, ImageFromFile) else id(img)
+ roi_meta = None

- if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None:
- return img_data, img_data.shape[:2]
+ # check if the image is already in the cache
+ img_data, roi_meta = self.mem_cache_handler.get(key=key)
+ if img_data is not None:
+ return img_data, img_data.shape[:2], roi_meta

with image_decode_context():
img_data = (
@@ -158,11 +179,28 @@
msg = "Cannot get image data"
raise RuntimeError(msg)

- img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8))
+ if roi:
+ # extract ROI from image
+ shape = roi["shape"]
+ h, w = img_data.shape[:2]
+ x1, y1, x2, y2 = (
+ int(np.clip(np.trunc(shape["x1"] * w), 0, w)),
+ int(np.clip(np.trunc(shape["y1"] * h), 0, h)),
+ int(np.clip(np.ceil(shape["x2"] * w), 0, w)),
+ int(np.clip(np.ceil(shape["y2"] * h), 0, h)),
+ )
+ if (x2 - x1) * (y2 - y1) <= 0:
+ msg = f"ROI has zero or negative area. ROI coordinates: {x1}, {y1}, {x2}, {y2}"
+ raise ValueError(msg)
+ img_data = img_data[y1:y2, x1:x2]
+ roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)}

+ img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8), meta=roi_meta)

- return img_data, img_data.shape[:2]
+ return img_data, img_data.shape[:2], roi_meta

- def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
+ def _cache_img(self, key: str | int, img_data: np.ndarray, meta: dict[str, Any] | None = None) -> np.ndarray:
"""Cache an image after resizing.

If there is available space in the memory pool, the input image is cached.
@@ -182,14 +220,14 @@
return img_data

if self.mem_cache_img_max_size is None:
- self.mem_cache_handler.put(key=key, data=img_data, meta=None)
+ self.mem_cache_handler.put(key=key, data=img_data, meta=meta)
return img_data

height, width = img_data.shape[:2]
max_height, max_width = self.mem_cache_img_max_size

if height <= max_height and width <= max_width:
- self.mem_cache_handler.put(key=key, data=img_data, meta=None)
+ self.mem_cache_handler.put(key=key, data=img_data, meta=meta)

return img_data

# Preserve the image size ratio and fit to max_height or max_width
@@ -206,7 +244,7 @@
self.mem_cache_handler.put(
key=key,
data=resized_img,
- meta=None,
+ meta=meta,
)
return resized_img

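Note: the ROI crop added to `_get_img_data_and_shape` above boils down to the following standalone sketch. The `crop_roi` helper name and the example values are mine, not part of the PR. Normalized corner coordinates are floored (top-left) and ceiled (bottom-right) to pixel indices, clipped to the image bounds, and the resulting `roi_meta` is stored as the cache entry's meta so that a cache hit returns the already-cropped image together with its meta.

from typing import Any

import numpy as np


def crop_roi(img_data: np.ndarray, roi: dict[str, Any]) -> tuple[np.ndarray, dict[str, Any]]:
    """Crop an image to a normalized ROI and return the crop plus its meta (illustrative helper)."""
    shape = roi["shape"]  # normalized coordinates in [0, 1]
    h, w = img_data.shape[:2]
    # Floor the top-left corner, ceil the bottom-right corner, then clip to the image bounds.
    x1 = int(np.clip(np.trunc(shape["x1"] * w), 0, w))
    y1 = int(np.clip(np.trunc(shape["y1"] * h), 0, h))
    x2 = int(np.clip(np.ceil(shape["x2"] * w), 0, w))
    y2 = int(np.clip(np.ceil(shape["y2"] * h), 0, h))
    if (x2 - x1) * (y2 - y1) <= 0:
        raise ValueError(f"ROI has zero or negative area. ROI coordinates: {x1}, {y1}, {x2}, {y2}")
    roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)}
    return img_data[y1:y2, x1:x2], roi_meta


# A 10x10 image with a centered ROI yields an 8x8 crop (matches the new unit test below).
image = np.zeros((10, 10, 3), dtype=np.uint8)
crop, meta = crop_roi(image, {"shape": {"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}})
assert crop.shape == (8, 8, 3)
assert meta["orig_image_shape"] == (10, 10)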
src/otx/core/data/dataset/classification.py (28 changes: 14 additions & 14 deletions)

@@ -32,18 +32,18 @@
def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ roi = item.attributes.get("roi", None)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img, roi)
+ if roi:
+ # extract labels from ROI
+ labels_ids = [
+ label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION"
+ ]
+ label_anns = [self.label_info.label_names.index(label_id) for label_id in labels_ids]
+ else:
+ # extract labels from annotations
+ label_anns = [ann.label for ann in item.annotations if isinstance(ann, Label)]

- label_anns = []
- for ann in item.annotations:
- if isinstance(ann, Label):
- label_anns.append(ann)
- else:
- # If the annotation is not Label, it should be converted to Label.
- # For Chained Task: Detection (Bbox) -> Classification (Label)
- label = Label(label=ann.label)
- if label not in label_anns:
- label_anns.append(label)
if len(label_anns) > 1:
msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}"
raise ValueError(msg)
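For the Det -> Cls chain, each detection ROI now carries its own classification labels, so the earlier Bbox -> Label conversion is dropped and label indices are looked up directly. A small illustrative sketch follows; the `roi` payload is fabricated (its structure, including the per-label `_id` and `domain` fields, is inferred from this diff), and it assumes `label_info.label_names` holds the same ids the ROI refers to.

label_names = ["label_id_a", "label_id_b"]  # stands in for self.label_info.label_names

roi = {
    "shape": {"x1": 0.2, "y1": 0.3, "x2": 0.6, "y2": 0.8},
    "labels": [
        {"label": {"_id": "label_id_b", "domain": "CLASSIFICATION"}},
        {"label": {"_id": "label_id_a", "domain": "DETECTION"}},  # non-classification entries are skipped
    ],
}

labels_ids = [
    label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION"
]
label_anns = [label_names.index(label_id) for label_id in labels_ids]
assert label_anns == [1]  # plain integer indices, passed straight to torch.as_tensor(label_anns)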
@@ -56,7 +56,7 @@
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
),
- labels=torch.as_tensor([ann.label for ann in label_anns]),
+ labels=torch.as_tensor(label_anns),
)

return self._apply_transforms(entity)
@@ -78,7 +78,7 @@
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = [] # This should be assigned form item
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

label_anns = []
for ann in item.annotations:
@@ -195,7 +195,7 @@
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = [] # This should be assigned form item
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

label_anns = []
for ann in item.annotations:
src/otx/core/data/dataset/detection.py (2 changes: 1 addition & 1 deletion)

@@ -26,7 +26,7 @@ def _get_item_impl(self, index: int) -> DetDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = [] # This should be assigned form item
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]

src/otx/core/data/dataset/instance_segmentation.py (2 changes: 1 addition & 1 deletion)

@@ -40,7 +40,7 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = []
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []

src/otx/core/data/dataset/keypoint_detection.py (2 changes: 1 addition & 1 deletion)

@@ -86,7 +86,7 @@ def _get_item_impl(self, index: int) -> KeypointDetDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = [] # This should be assigned form item
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]
bboxes = (
src/otx/core/data/dataset/segmentation.py (9 changes: 7 additions & 2 deletions)

@@ -202,9 +202,14 @@
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = []
- img_data, img_shape = self._get_img_data_and_shape(img)
+ roi = item.attributes.get("roi", None)
+ img_data, img_shape, roi_meta = self._get_img_data_and_shape(img, roi)
if item.annotations:
- extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)
+ ori_shape = roi_meta["orig_image_shape"] if roi_meta else img_shape
+ extracted_mask = _extract_class_mask(item=item, img_shape=ori_shape, ignore_index=self.ignore_index)
+ if roi_meta:
+ extracted_mask = extracted_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]]

masks = tv_tensors.Mask(extracted_mask[None])
else:
# semi-supervised learning, unlabeled dataset
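Sketch of the new mask handling, using the same hypothetical 10x10 image and centered ROI as in the earlier crop example: the class mask is built at the original image shape taken from `roi_meta`, then cropped with the same pixel coordinates that were applied to the image, so mask and image crop stay aligned. `full_mask` stands in for the output of `_extract_class_mask`.

import numpy as np

roi_meta = {"x1": 1, "y1": 1, "x2": 9, "y2": 9, "orig_image_shape": (10, 10)}
img_shape = (8, 8)  # shape of the already-cropped image returned by _get_img_data_and_shape

ori_shape = roi_meta["orig_image_shape"] if roi_meta else img_shape
full_mask = np.zeros(ori_shape, dtype=np.uint8)  # placeholder for _extract_class_mask(...)
if roi_meta:
    full_mask = full_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]]

assert full_mask.shape == (8, 8)  # matches the cropped image from the same ROI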
src/otx/core/data/dataset/tile.py (4 changes: 2 additions & 2 deletions)

@@ -370,7 +370,7 @@
"""
item = self.dm_subset[index]
img = item.media_as(Image)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)


bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]

@@ -461,7 +461,7 @@
"""
item = self.dm_subset[index]
img = item.media_as(Image)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)


gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []

src/otx/core/data/dataset/visual_prompting.py (4 changes: 2 additions & 2 deletions)

@@ -79,7 +79,7 @@ def __init__(
def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(dmImage)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

gt_bboxes, gt_points = [], []
gt_masks = defaultdict(list)
@@ -214,7 +214,7 @@ def __init__(
def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(dmImage)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

gt_prompts: list[tvBoundingBoxes | Points] = []
gt_masks: list[tvMask] = []
tests/unit/core/data/dataset/test_base.py (104 changes: 104 additions & 0 deletions, new file)
@@ -0,0 +1,104 @@
from unittest import mock

import numpy as np
import pytest
from datumaro.components.media import Image
from otx.core.data.dataset.base import OTXDataset


class TestOTXDataset:
    @pytest.fixture()
    def mock_image(self) -> Image:
        img = mock.Mock(spec=Image)
        img.data = np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)
        img.path = "test_path"
        return img

    @pytest.fixture()
    def mock_mem_cache_handler(self):
        mem_cache_handler = mock.MagicMock()
        mem_cache_handler.frozen = False
        return mem_cache_handler

    @pytest.fixture()
    def otx_dataset(self, mock_mem_cache_handler):
        class MockOTXDataset(OTXDataset):
            def _get_item_impl(self, idx: int) -> None:
                return None

            @property
            def collate_fn(self) -> None:
                return None

        dm_subset = mock.Mock()
        dm_subset.categories = mock.MagicMock()
        dm_subset.categories.return_value = None

        return MockOTXDataset(
            dm_subset=dm_subset,
            transforms=None,
            mem_cache_handler=mock_mem_cache_handler,
            mem_cache_img_max_size=None,
        )

    def test_get_img_data_and_shape_no_cache(self, otx_dataset, mock_image, mock_mem_cache_handler):
        mock_mem_cache_handler.get.return_value = (None, None)
        img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image)
        assert img_data.shape == (10, 10, 3)
        assert img_shape == (10, 10)
        assert roi_meta is None

    def test_get_img_data_and_shape_with_cache(self, otx_dataset, mock_image, mock_mem_cache_handler):
        mock_mem_cache_handler.get.return_value = (np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8), None)
        img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image)
        assert img_data.shape == (10, 10, 3)
        assert img_shape == (10, 10)
        assert roi_meta is None

    def test_get_img_data_and_shape_with_roi(self, otx_dataset, mock_image, mock_mem_cache_handler):
        roi = {"shape": {"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}}
        mock_mem_cache_handler.get.return_value = (None, None)
        img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image, roi)
        assert img_data.shape == (8, 8, 3)
        assert img_shape == (8, 8)
        assert roi_meta == {"x1": 1, "y1": 1, "x2": 9, "y2": 9, "orig_image_shape": (10, 10)}

    def test_cache_img_no_resize(self, otx_dataset):
        img_data = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8)
        key = "test_key"

        cached_img = otx_dataset._cache_img(key, img_data)

        assert np.array_equal(cached_img, img_data)
        otx_dataset.mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None)

    def test_cache_img_with_resize(self, otx_dataset, mock_mem_cache_handler):
        otx_dataset.mem_cache_img_max_size = (100, 100)
        img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
        key = "test_key"

        cached_img = otx_dataset._cache_img(key, img_data)

        assert cached_img.shape == (100, 100, 3)
        mock_mem_cache_handler.put.assert_called_once()
        assert mock_mem_cache_handler.put.call_args[1]["data"].shape == (100, 100, 3)

    def test_cache_img_no_max_size(self, otx_dataset, mock_mem_cache_handler):
        otx_dataset.mem_cache_img_max_size = None
        img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
        key = "test_key"

        cached_img = otx_dataset._cache_img(key, img_data)

        assert np.array_equal(cached_img, img_data)
        mock_mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None)

    def test_cache_img_frozen_handler(self, otx_dataset, mock_mem_cache_handler):
        mock_mem_cache_handler.frozen = True
        img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
        key = "test_key"

        cached_img = otx_dataset._cache_img(key, img_data)

        assert np.array_equal(cached_img, img_data)
        mock_mem_cache_handler.put.assert_not_called()