Skip to content

Commit

Permalink
Fix SegmenatationDatasetAdapter, AnomalyDatasetAdapter, and
Browse files Browse the repository at this point in the history
ActionBaseDatasetAdapter signatures

Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim committed May 31, 2023
1 parent 4be1f3c commit 1dcee61
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 8 deletions.
5 changes: 4 additions & 1 deletion otx/core/data/adapter/action_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class ActionBaseDatasetAdapter(BaseDatasetAdapter):
VIDEO_FRAME_SEP = "##"
EMPTY_FRAME_LABEL_NAME = "EmptyFrame"

def _import_dataset(
def _import_datasets(
self,
train_data_roots: Optional[str] = None,
train_ann_files: Optional[str] = None,
Expand All @@ -39,6 +39,7 @@ def _import_dataset(
test_ann_files: Optional[str] = None,
unlabeled_data_roots: Optional[str] = None,
unlabeled_file_list: Optional[str] = None,
encryption_key: Optional[str] = None,
) -> Dict[Subset, DatumDataset]:
"""Import multiple videos that have CVAT format annotation.
Expand All @@ -51,6 +52,8 @@ def _import_dataset(
test_ann_files (Optional[str]): Path for test annotation file
unlabeled_data_roots (Optional[str]): Path for unlabeled data
unlabeled_file_list (Optional[str]): Path of unlabeled file list
encryption_key (Optional[str]): Encryption key to load an encrypted dataset
(only required for DatumaroBinary format)
Returns:
DatumDataset: Datumaro Dataset
Expand Down
5 changes: 4 additions & 1 deletion otx/core/data/adapter/anomaly_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class AnomalyBaseDatasetAdapter(BaseDatasetAdapter):
"""BaseDataset Adpater for Anomaly tasks inherited from BaseDatasetAdapter."""

def _import_dataset(
def _import_datasets(
self,
train_data_roots: Optional[str] = None,
train_ann_files: Optional[str] = None,
Expand All @@ -44,6 +44,7 @@ def _import_dataset(
test_ann_files: Optional[str] = None,
unlabeled_data_roots: Optional[str] = None,
unlabeled_file_list: Optional[str] = None,
encryption_key: Optional[str] = None,
) -> Dict[Subset, DatumaroDataset]:
"""Import MVTec dataset.
Expand All @@ -56,6 +57,8 @@ def _import_dataset(
test_ann_files (Optional[str]): Path for test annotation file
unlabeled_data_roots (Optional[str]): Path for unlabeled data
unlabeled_file_list (Optional[str]): Path of unlabeled file list
encryption_key (Optional[str]): Encryption key to load an encrypted dataset
(only required for DatumaroBinary format)
Returns:
DatumaroDataset: Datumaro Dataset
Expand Down
2 changes: 1 addition & 1 deletion otx/core/data/adapter/base_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class BaseDatasetAdapter(metaclass=abc.ABCMeta):
unlabeled_data_roots (Optional[str]): Path for unlabeled data
unlabeled_file_list (Optional[str]): Path of unlabeled file list
encryption_key (Optional[str]): Encryption key to load an encrypted dataset
(DatumaroBinary format)
(only required for DatumaroBinary format)
Since all adapters can be used for training and validation,
the default value of train/val/test_data_roots was set to None.
Expand Down
5 changes: 4 additions & 1 deletion otx/core/data/adapter/segmentation_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class SelfSLSegmentationDatasetAdapter(SegmentationDatasetAdapter):
"""Self-SL for segmentation adapter inherited from SegmentationDatasetAdapter."""

# pylint: disable=protected-access
def _import_dataset(
def _import_datasets(
self,
train_data_roots: Optional[str] = None,
train_ann_files: Optional[str] = None,
Expand All @@ -166,6 +166,7 @@ def _import_dataset(
test_ann_files: Optional[str] = None,
unlabeled_data_roots: Optional[str] = None,
unlabeled_file_list: Optional[str] = None,
encryption_key: Optional[str] = None,
pseudo_mask_dir: str = "detcon_mask",
) -> Dict[Subset, DatumDataset]:
"""Import custom Self-SL dataset for using DetCon.
Expand All @@ -183,6 +184,8 @@ def _import_dataset(
test_ann_files (Optional[str]): Path for test annotation file
unlabeled_data_roots (Optional[str]): Path for unlabeled data.
unlabeled_file_list (Optional[str]): Path of unlabeled file list
encryption_key (Optional[str]): Encryption key to load an encrypted dataset
(only required for DatumaroBinary format)
pseudo_mask_dir (str): Directory to save pseudo masks. Defaults to "detcon_mask".
Returns:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/core/data/adapter/test_segmentation_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def setup_method(self, method) -> None:

@e2e_pytest_unit
def test_import_dataset_create_all_masks(self, mocker):
"""Test _import_dataset when creating all masks.
"""Test _import_datasets when creating all masks.
This test is for when all masks are not created and it is required to create masks.
"""
Expand All @@ -96,7 +96,7 @@ def test_import_dataset_create_all_masks(self, mocker):
@e2e_pytest_unit
@pytest.mark.parametrize("idx_remove", [1, 2, 3])
def test_import_dataset_create_some_uncreated_masks(self, mocker, idx_remove: int):
"""Test _import_dataset when there are both uncreated and created masks.
"""Test _import_datasets when there are both uncreated and created masks.
This test is for when there are both created and uncreated masks
and it is required to either create or just load masks.
Expand All @@ -114,7 +114,7 @@ def test_import_dataset_create_some_uncreated_masks(self, mocker, idx_remove: in
os.remove(os.path.join(self.pseudo_mask_roots, f"000{idx_remove}.png"))
spy_create_pseudo_masks = mocker.spy(SelfSLSegmentationDatasetAdapter, "create_pseudo_masks")

_ = dataset_adapter._import_dataset(
_ = dataset_adapter._import_datasets(
train_data_roots=self.train_data_roots,
)

Expand All @@ -123,7 +123,7 @@ def test_import_dataset_create_some_uncreated_masks(self, mocker, idx_remove: in

@e2e_pytest_unit
def test_import_dataset_just_load_masks(self, mocker):
"""Test _import_dataset when just loading all masks."""
"""Test _import_datasets when just loading all masks."""
spy_create_pseudo_masks = mocker.spy(SelfSLSegmentationDatasetAdapter, "create_pseudo_masks")

_ = SelfSLSegmentationDatasetAdapter(
Expand Down

0 comments on commit 1dcee61

Please sign in to comment.