diff --git a/src/otx/algo/utils/xai_utils.py b/src/otx/algo/utils/xai_utils.py index 434d2612cf1..210d6aad0dd 100644 --- a/src/otx/algo/utils/xai_utils.py +++ b/src/otx/algo/utils/xai_utils.py @@ -225,7 +225,7 @@ def _get_image_data_name( subset = datamodule.subsets[subset_name] item = subset.dm_subset[img_id] img = item.media_as(Image) - img_data, _ = subset._get_img_data_and_shape(img) # noqa: SLF001 + img_data, _, _ = subset._get_img_data_and_shape(img) # noqa: SLF001 image_save_name = "".join([char if char.isalnum() else "_" for char in item.id]) return img_data, image_save_name diff --git a/src/otx/core/data/dataset/anomaly.py b/src/otx/core/data/dataset/anomaly.py index ec9b59ce499..0f855f5b3d6 100644 --- a/src/otx/core/data/dataset/anomaly.py +++ b/src/otx/core/data/dataset/anomaly.py @@ -79,7 +79,7 @@ def _get_item_impl( datumaro_item = self.dm_subset[index] img = datumaro_item.media_as(Image) # returns image in RGB format if self.image_color_channel is RGB - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) label = self._get_label(datumaro_item) diff --git a/src/otx/core/data/dataset/base.py b/src/otx/core/data/dataset/base.py index a98f7c6083b..21e8d349a9c 100644 --- a/src/otx/core/data/dataset/base.py +++ b/src/otx/core/data/dataset/base.py @@ -8,7 +8,7 @@ from abc import abstractmethod from collections.abc import Iterable from contextlib import contextmanager -from typing import TYPE_CHECKING, Callable, Generic, Iterator, List, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, List, Union import cv2 import numpy as np @@ -92,6 +92,7 @@ def __init__( self.image_color_channel = image_color_channel self.stack_images = stack_images self.to_tv_image = to_tv_image + if self.dm_subset.categories(): self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label]) else: @@ -141,11 +142,31 @@ def __getitem__(self, index: int) -> T_OTXDataEntity: msg = f"Reach the maximum refetch number ({self.max_refetch})" raise RuntimeError(msg) - def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, int]]: + def _get_img_data_and_shape( + self, + img: Image, + roi: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, tuple[int, int], dict[str, Any] | None]: + """Get image data and shape. + + This method is used to get image data and shape from Datumaro image object. + If ROI is provided, the image data is extracted from the ROI. + + Args: + img (Image): Image object from Datumaro. + roi (dict[str, Any] | None, Optional): Region of interest. + Represented by dict with coordinates and some meta information. + + Returns: + The image data, shape, and ROI meta information + """ key = img.path if isinstance(img, ImageFromFile) else id(img) + roi_meta = None - if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None: - return img_data, img_data.shape[:2] + # check if the image is already in the cache + img_data, roi_meta = self.mem_cache_handler.get(key=key) + if img_data is not None: + return img_data, img_data.shape[:2], roi_meta with image_decode_context(): img_data = ( @@ -158,11 +179,28 @@ def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, in msg = "Cannot get image data" raise RuntimeError(msg) - img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8)) + if roi: + # extract ROI from image + shape = roi["shape"] + h, w = img_data.shape[:2] + x1, y1, x2, y2 = ( + int(np.clip(np.trunc(shape["x1"] * w), 0, w)), + int(np.clip(np.trunc(shape["y1"] * h), 0, h)), + int(np.clip(np.ceil(shape["x2"] * w), 0, w)), + int(np.clip(np.ceil(shape["y2"] * h), 0, h)), + ) + if (x2 - x1) * (y2 - y1) <= 0: + msg = f"ROI has zero or negative area. ROI coordinates: {x1}, {y1}, {x2}, {y2}" + raise ValueError(msg) + + img_data = img_data[y1:y2, x1:x2] + roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)} + + img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8), meta=roi_meta) - return img_data, img_data.shape[:2] + return img_data, img_data.shape[:2], roi_meta - def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray: + def _cache_img(self, key: str | int, img_data: np.ndarray, meta: dict[str, Any] | None = None) -> np.ndarray: """Cache an image after resizing. If there is available space in the memory pool, the input image is cached. @@ -182,14 +220,14 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray: return img_data if self.mem_cache_img_max_size is None: - self.mem_cache_handler.put(key=key, data=img_data, meta=None) + self.mem_cache_handler.put(key=key, data=img_data, meta=meta) return img_data height, width = img_data.shape[:2] max_height, max_width = self.mem_cache_img_max_size if height <= max_height and width <= max_width: - self.mem_cache_handler.put(key=key, data=img_data, meta=None) + self.mem_cache_handler.put(key=key, data=img_data, meta=meta) return img_data # Preserve the image size ratio and fit to max_height or max_width @@ -206,7 +244,7 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray: self.mem_cache_handler.put( key=key, data=resized_img, - meta=None, + meta=meta, ) return resized_img diff --git a/src/otx/core/data/dataset/classification.py b/src/otx/core/data/dataset/classification.py index c5048dd7987..8f4f5ffc241 100644 --- a/src/otx/core/data/dataset/classification.py +++ b/src/otx/core/data/dataset/classification.py @@ -32,18 +32,18 @@ class OTXMulticlassClsDataset(OTXDataset[MulticlassClsDataEntity]): def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) - img_data, img_shape = self._get_img_data_and_shape(img) + roi = item.attributes.get("roi", None) + img_data, img_shape, _ = self._get_img_data_and_shape(img, roi) + if roi: + # extract labels from ROI + labels_ids = [ + label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION" + ] + label_anns = [self.label_info.label_names.index(label_id) for label_id in labels_ids] + else: + # extract labels from annotations + label_anns = [ann.label for ann in item.annotations if isinstance(ann, Label)] - label_anns = [] - for ann in item.annotations: - if isinstance(ann, Label): - label_anns.append(ann) - else: - # If the annotation is not Label, it should be converted to Label. - # For Chained Task: Detection (Bbox) -> Classification (Label) - label = Label(label=ann.label) - if label not in label_anns: - label_anns.append(label) if len(label_anns) > 1: msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}" raise ValueError(msg) @@ -56,7 +56,7 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None: ori_shape=img_shape, image_color_channel=self.image_color_channel, ), - labels=torch.as_tensor([ann.label for ann in label_anns]), + labels=torch.as_tensor(label_anns), ) return self._apply_transforms(entity) @@ -78,7 +78,7 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) label_anns = [] for ann in item.annotations: @@ -195,7 +195,7 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) label_anns = [] for ann in item.annotations: diff --git a/src/otx/core/data/dataset/detection.py b/src/otx/core/data/dataset/detection.py index 8094638b457..6783fce7207 100644 --- a/src/otx/core/data/dataset/detection.py +++ b/src/otx/core/data/dataset/detection.py @@ -26,7 +26,7 @@ def _get_item_impl(self, index: int) -> DetDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] diff --git a/src/otx/core/data/dataset/instance_segmentation.py b/src/otx/core/data/dataset/instance_segmentation.py index 0a3abaeb877..2457e129344 100644 --- a/src/otx/core/data/dataset/instance_segmentation.py +++ b/src/otx/core/data/dataset/instance_segmentation.py @@ -40,7 +40,7 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], [] diff --git a/src/otx/core/data/dataset/keypoint_detection.py b/src/otx/core/data/dataset/keypoint_detection.py index f0e0d30c372..c74b77c9319 100644 --- a/src/otx/core/data/dataset/keypoint_detection.py +++ b/src/otx/core/data/dataset/keypoint_detection.py @@ -86,7 +86,7 @@ def _get_item_impl(self, index: int) -> KeypointDetDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] # This should be assigned form item - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] bboxes = ( diff --git a/src/otx/core/data/dataset/segmentation.py b/src/otx/core/data/dataset/segmentation.py index 53975456b67..a690dde42ad 100644 --- a/src/otx/core/data/dataset/segmentation.py +++ b/src/otx/core/data/dataset/segmentation.py @@ -202,9 +202,14 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None: item = self.dm_subset[index] img = item.media_as(Image) ignored_labels: list[int] = [] - img_data, img_shape = self._get_img_data_and_shape(img) + roi = item.attributes.get("roi", None) + img_data, img_shape, roi_meta = self._get_img_data_and_shape(img, roi) if item.annotations: - extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index) + ori_shape = roi_meta["orig_image_shape"] if roi_meta else img_shape + extracted_mask = _extract_class_mask(item=item, img_shape=ori_shape, ignore_index=self.ignore_index) + if roi_meta: + extracted_mask = extracted_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]] + masks = tv_tensors.Mask(extracted_mask[None]) else: # semi-supervised learning, unlabeled dataset diff --git a/src/otx/core/data/dataset/tile.py b/src/otx/core/data/dataset/tile.py index a39ea5aa90d..a729ddc1869 100644 --- a/src/otx/core/data/dataset/tile.py +++ b/src/otx/core/data/dataset/tile.py @@ -370,7 +370,7 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr """ item = self.dm_subset[index] img = item.media_as(Image) - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)] @@ -461,7 +461,7 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o """ item = self.dm_subset[index] img = item.media_as(Image) - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], [] diff --git a/src/otx/core/data/dataset/visual_prompting.py b/src/otx/core/data/dataset/visual_prompting.py index 0047e9350fe..8f2ccb620d9 100644 --- a/src/otx/core/data/dataset/visual_prompting.py +++ b/src/otx/core/data/dataset/visual_prompting.py @@ -79,7 +79,7 @@ def __init__( def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None: item = self.dm_subset[index] img = item.media_as(dmImage) - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) gt_bboxes, gt_points = [], [] gt_masks = defaultdict(list) @@ -214,7 +214,7 @@ def __init__( def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None: item = self.dm_subset[index] img = item.media_as(dmImage) - img_data, img_shape = self._get_img_data_and_shape(img) + img_data, img_shape, _ = self._get_img_data_and_shape(img) gt_prompts: list[tvBoundingBoxes | Points] = [] gt_masks: list[tvMask] = [] diff --git a/tests/unit/core/data/dataset/test_base.py b/tests/unit/core/data/dataset/test_base.py new file mode 100644 index 00000000000..47afdf9cf55 --- /dev/null +++ b/tests/unit/core/data/dataset/test_base.py @@ -0,0 +1,104 @@ +from unittest import mock + +import numpy as np +import pytest +from datumaro.components.media import Image +from otx.core.data.dataset.base import OTXDataset + + +class TestOTXDataset: + @pytest.fixture() + def mock_image(self) -> Image: + img = mock.Mock(spec=Image) + img.data = np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8) + img.path = "test_path" + return img + + @pytest.fixture() + def mock_mem_cache_handler(self): + mem_cache_handler = mock.MagicMock() + mem_cache_handler.frozen = False + return mem_cache_handler + + @pytest.fixture() + def otx_dataset(self, mock_mem_cache_handler): + class MockOTXDataset(OTXDataset): + def _get_item_impl(self, idx: int) -> None: + return None + + @property + def collate_fn(self) -> None: + return None + + dm_subset = mock.Mock() + dm_subset.categories = mock.MagicMock() + dm_subset.categories.return_value = None + + return MockOTXDataset( + dm_subset=dm_subset, + transforms=None, + mem_cache_handler=mock_mem_cache_handler, + mem_cache_img_max_size=None, + ) + + def test_get_img_data_and_shape_no_cache(self, otx_dataset, mock_image, mock_mem_cache_handler): + mock_mem_cache_handler.get.return_value = (None, None) + img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image) + assert img_data.shape == (10, 10, 3) + assert img_shape == (10, 10) + assert roi_meta is None + + def test_get_img_data_and_shape_with_cache(self, otx_dataset, mock_image, mock_mem_cache_handler): + mock_mem_cache_handler.get.return_value = (np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8), None) + img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image) + assert img_data.shape == (10, 10, 3) + assert img_shape == (10, 10) + assert roi_meta is None + + def test_get_img_data_and_shape_with_roi(self, otx_dataset, mock_image, mock_mem_cache_handler): + roi = {"shape": {"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}} + mock_mem_cache_handler.get.return_value = (None, None) + img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image, roi) + assert img_data.shape == (8, 8, 3) + assert img_shape == (8, 8) + assert roi_meta == {"x1": 1, "y1": 1, "x2": 9, "y2": 9, "orig_image_shape": (10, 10)} + + def test_cache_img_no_resize(self, otx_dataset): + img_data = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8) + key = "test_key" + + cached_img = otx_dataset._cache_img(key, img_data) + + assert np.array_equal(cached_img, img_data) + otx_dataset.mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None) + + def test_cache_img_with_resize(self, otx_dataset, mock_mem_cache_handler): + otx_dataset.mem_cache_img_max_size = (100, 100) + img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8) + key = "test_key" + + cached_img = otx_dataset._cache_img(key, img_data) + + assert cached_img.shape == (100, 100, 3) + mock_mem_cache_handler.put.assert_called_once() + assert mock_mem_cache_handler.put.call_args[1]["data"].shape == (100, 100, 3) + + def test_cache_img_no_max_size(self, otx_dataset, mock_mem_cache_handler): + otx_dataset.mem_cache_img_max_size = None + img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8) + key = "test_key" + + cached_img = otx_dataset._cache_img(key, img_data) + + assert np.array_equal(cached_img, img_data) + mock_mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None) + + def test_cache_img_frozen_handler(self, otx_dataset, mock_mem_cache_handler): + mock_mem_cache_handler.frozen = True + img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8) + key = "test_key" + + cached_img = otx_dataset._cache_img(key, img_data) + + assert np.array_equal(cached_img, img_data) + mock_mem_cache_handler.put.assert_not_called()