Skip to content

Commit

Permalink
Add error for media type
Browse files Browse the repository at this point in the history
  • Loading branch information
sooahleex committed May 3, 2024
1 parent 742cbfd commit ba98d27
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 19 deletions.
6 changes: 4 additions & 2 deletions docs/source/docs/command-reference/context_free/transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,11 @@ Examples:
#### `astype_annotations`

Enables the conversion of annotation types for the categories and individual items within a dataset.
This transform only supports tabular datasets. If you want to change annotation types in datasets of other types, please use a different transform.

Based on a specified mapping, it transforms the annotation types, changing them to 'Label' if they are categorical,
By default, it transforms the annotation types, changing them to 'Label' if they are categorical,
and to 'Caption' if they are of type string, float, or integer.
If you explicitly specify a mapping, the annotation types are changed according to that mapping instead.

Usage:
```console
Expand All @@ -519,7 +521,7 @@ Examples:
- Convert type of `title` and `rating` annotation
```console
datum transform -t astype_annotations -- \
--mapping 'title:Caption,rating:int'
--mapping 'title:text,rating:int'
```

#### `random_split`
Expand Down
37 changes: 22 additions & 15 deletions src/datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
)
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME, DatasetInfo, DatasetItem, IDataset
from datumaro.components.errors import DatumaroError
from datumaro.components.media import Image
from datumaro.components.errors import DatumaroError, MediaTypeError
from datumaro.components.media import Image, TableRow
from datumaro.components.transformer import ItemTransform, Transform
from datumaro.util import NOTSET, filter_dict, parse_json_file, parse_str_enum_value, take_by
from datumaro.util.annotation_util import find_group_leader, find_instances
Expand Down Expand Up @@ -1495,9 +1495,13 @@ def __init__(
):
super().__init__(extractor)

# Turn off for default setting
# assert isinstance(mapping, (dict, list))
if extractor.media_type() and not issubclass(extractor.media_type(), TableRow):
raise MediaTypeError(
"Media type is not table. This transform only support tabular media"
)

# Turn off for default setting
assert mapping is None or isinstance(mapping, (dict, list)), "Mapping must be dict, or list"
if isinstance(mapping, list):
mapping = dict(mapping)

Check warning on line 1506 in src/datumaro/plugins/transforms.py

View check run for this annotation

Codecov / codecov/patch

src/datumaro/plugins/transforms.py#L1506

Added line #L1506 was not covered by tests

Expand All @@ -1510,17 +1514,20 @@ def __init__(
# Make LabelCategories
self._id_mapping = {}
dst_label_cat = LabelCategories()
if src_tabular_cat is not None:
for src_cat in src_tabular_cat:
if src_cat.dtype == CategoricalDtype():
dst_parent = src_cat.name
dst_labels = sorted(src_cat.labels)
for dst_label in dst_labels:
dst_index = dst_label_cat.add(dst_label, parent=dst_parent, attributes={})
self._id_mapping[dst_label] = dst_index
dst_label_cat.add_label_group(src_cat.name, src_cat.labels, group_type=0)
self._tabular_cat_types[src_cat.name] = src_cat.dtype
self._categories[AnnotationType.label] = dst_label_cat

if src_tabular_cat is None:
return

for src_cat in src_tabular_cat:
if src_cat.dtype == CategoricalDtype():
dst_parent = src_cat.name
dst_labels = sorted(src_cat.labels)
for dst_label in dst_labels:
dst_index = dst_label_cat.add(dst_label, parent=dst_parent, attributes={})
self._id_mapping[dst_label] = dst_index
dst_label_cat.add_label_group(src_cat.name, src_cat.labels, group_type=0)
self._tabular_cat_types[src_cat.name] = src_cat.dtype
self._categories[AnnotationType.label] = dst_label_cat

def categories(self):
return self._categories
Expand Down
23 changes: 21 additions & 2 deletions tests/unit/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.errors import MediaTypeError
from datumaro.components.media import Image, Table, TableRow

from ..requirements import Requirements, mark_bug, mark_requirement
Expand Down Expand Up @@ -1277,15 +1278,33 @@ def test_split_arg_valid(self):
assert transforms.AstypeAnnotations._split_arg("date:label") == [("date", "label")]

# Test valid input with multiple colons
assert transforms.AstypeAnnotations._split_arg("date:label,title:caption") == [
assert transforms.AstypeAnnotations._split_arg("date:label,title:text") == [
("date", "label"),
("title", "caption"),
("title", "text"),
]

# Test invalid input with no colon
with pytest.raises(argparse.ArgumentTypeError):
transforms.AstypeAnnotations._split_arg("datelabel")

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_media_type(self):
    """Constructing ``AstypeAnnotations`` on this dataset must raise ``MediaTypeError``."""
    # The single item carries one label annotation; no media is attached to it.
    labeled_item = DatasetItem(id="1", subset="train", annotations=[Label(0, id=0)])
    source = Dataset.from_iterable([labeled_item], categories={})

    with self.assertRaises(MediaTypeError):
        transforms.AstypeAnnotations(source)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_transform_annotation_type_label(self):
table = self.table
Expand Down

0 comments on commit ba98d27

Please sign in to comment.