diff --git a/CHANGELOG.md b/CHANGELOG.md index dd78dee5bc..d177cf58a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for CIFAR-10/100 dataset format () - Support COCO panoptic and stuff format () - Documentation file and integration tests for Pascal VOC format () +- Support for MNIST and MNIST in CSV dataset formats () ### Changed - LabelMe format saves dataset items with their relative paths by subsets without changing names () diff --git a/README.md b/README.md index 00697b4bbe..7bd9d6252b 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,8 @@ CVAT annotations ---> Publication, statistics etc. - [MOTS PNG](https://www.vision.rwth-aachen.de/page/mots) - [ImageNet](http://image-net.org/) - [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html) (`classification`) + - [MNIST](http://yann.lecun.com/exdb/mnist/) (`classification`) + - [MNIST in CSV](https://pjreddie.com/projects/mnist-in-csv/) (`classification`) - [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) - [CVAT](https://github.com/opencv/cvat/blob/develop/cvat/apps/documentation/xml_format.md) - [LabelMe](http://labelme.csail.mit.edu/Release3.0) diff --git a/datumaro/plugins/mnist_csv_format.py b/datumaro/plugins/mnist_csv_format.py new file mode 100644 index 0000000000..ae0fa8bf8c --- /dev/null +++ b/datumaro/plugins/mnist_csv_format.py @@ -0,0 +1,170 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import os +import os.path as osp + +import numpy as np +from datumaro.components.converter import Converter +from datumaro.components.extractor import (AnnotationType, DatasetItem, + Importer, Label, LabelCategories, SourceExtractor) + + +class MnistCsvPath: + IMAGE_SIZE = 28 + NONE_LABEL = -1 + +class MnistCsvExtractor(SourceExtractor): + def __init__(self, path, subset=None): + if not osp.isfile(path): + raise FileNotFoundError("Can't read annotation file '%s'" % path) + + if not subset: + file_name = osp.splitext(osp.basename(path))[0] + subset = file_name.rsplit('_', maxsplit=1)[-1] + + super().__init__(subset=subset) + self._dataset_dir = osp.dirname(path) + + self._categories = self._load_categories() + + self._items = list(self._load_items(path).values()) + + def _load_categories(self): + label_cat = LabelCategories() + + labels_file = osp.join(self._dataset_dir, 'labels.txt') + if osp.isfile(labels_file): + with open(labels_file, encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + label_cat.add(line) + else: + for i in range(10): + label_cat.add(str(i)) + + return { AnnotationType.label: label_cat } + + def _load_items(self, path): + items = {} + with open(path, 'r', encoding='utf-8') as f: + annotation_table = f.readlines() + + metafile = osp.join(self._dataset_dir, 'meta_%s.csv' % self._subset) + meta = [] + if osp.isfile(metafile): + with open(metafile, 'r', encoding='utf-8') as f: + meta = f.readlines() + + for i, data in enumerate(annotation_table): + data = data.split(',') + item_anno = [] + label = int(data[0]) + if label != MnistCsvPath.NONE_LABEL: + item_anno.append(Label(label)) + + if 0 < len(meta): + meta[i] = meta[i].strip().split(',') + + # support for single-channel image only + image = None + if 1 < len(data): + if 0 < len(meta) and 1 < len(meta[i]): + image = np.array([int(pix) for pix in data[1:]], + dtype='uint8').reshape(int(meta[i][-2]), int(meta[i][-1])) + else: + image = np.array([int(pix) for pix in data[1:]], + dtype='uint8').reshape(28, 28) + + if 0 < len(meta) and len(meta[i]) in [1, 3]: + i = meta[i][0] + + items[i] = DatasetItem(id=i, subset=self._subset, + image=image, annotations=item_anno) + return items + +class MnistCsvImporter(Importer): + @classmethod + def find_sources(cls, path): + return cls._find_sources_recursive(path, '.csv', 'mnist_csv', + file_filter=lambda p: not osp.basename(p).startswith('meta')) + +class MnistCsvConverter(Converter): + DEFAULT_IMAGE_EXT = '.png' + + def apply(self): + os.makedirs(self._save_dir, exist_ok=True) + for subset_name, subset in self._extractor.subsets().items(): + data = [] + item_ids = {} + image_sizes = {} + for item in subset: + anns = [a.label for a in item.annotations + if a.type == AnnotationType.label] + label = MnistCsvPath.NONE_LABEL + if anns: + label = anns[0] + + if item.has_image and self._save_images: + image = item.image + if not image.has_data: + data.append([label, None]) + else: + if image.data.shape[0] != MnistCsvPath.IMAGE_SIZE or \ + image.data.shape[1] != MnistCsvPath.IMAGE_SIZE: + image_sizes[len(data)] = [image.data.shape[0], + image.data.shape[1]] + image = image.data.reshape(-1).astype(np.uint8).tolist() + image.insert(0, label) + data.append(image) + else: + data.append([label]) + + if item.id != str(len(data) - 1): + item_ids[len(data) - 1] = item.id + + anno_file = osp.join(self._save_dir, 'mnist_%s.csv' % subset_name) + self.save_in_csv(anno_file, data) + + # it is't in the original format, + # this is for storng other names and sizes of images + if len(item_ids) or len(image_sizes): + meta = [] + if len(item_ids) and len(image_sizes): + # other names and sizes of images + size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE] + for i in range(len(data)): + w, h = image_sizes.get(i, size) + meta.append([item_ids.get(i, i), w, h]) + + elif len(item_ids): + # other names of images + for i in range(len(data)): + meta.append([item_ids.get(i, i)]) + + elif len(image_sizes): + # other sizes of images + size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE] + for i in range(len(data)): + meta.append(image_sizes.get(i, size)) + + metafile = osp.join(self._save_dir, 'meta_%s.csv' % subset_name) + self.save_in_csv(metafile, meta) + + self.save_labels() + + def save_in_csv(self, path, data): + with open(path, 'w', encoding='utf-8') as f: + for row in data: + f.write(','.join([str(p) for p in row]) + "\n") + + def save_labels(self): + labels_file = osp.join(self._save_dir, 'labels.txt') + with open(labels_file, 'w', encoding='utf-8') as f: + f.writelines(l.name + '\n' + for l in self._extractor.categories().get( + AnnotationType.label, LabelCategories()) + ) diff --git a/datumaro/plugins/mnist_format.py b/datumaro/plugins/mnist_format.py new file mode 100644 index 0000000000..0cd97b06df --- /dev/null +++ b/datumaro/plugins/mnist_format.py @@ -0,0 +1,209 @@ +# Copyright (C) 2021 Intel Corporation +# +# SPDX-License-Identifier: MIT + +import gzip +import os +import os.path as osp + +import numpy as np +from datumaro.components.converter import Converter +from datumaro.components.extractor import (AnnotationType, DatasetItem, + Importer, Label, LabelCategories, SourceExtractor) + + +class MnistPath: + TEST_LABELS_FILE = 't10k-labels-idx1-ubyte.gz' + TEST_IMAGES_FILE = 't10k-images-idx3-ubyte.gz' + LABELS_FILE = '-labels-idx1-ubyte.gz' + IMAGES_FILE = '-images-idx3-ubyte.gz' + IMAGE_SIZE = 28 + NONE_LABEL = 255 + +class MnistExtractor(SourceExtractor): + def __init__(self, path, subset=None): + if not osp.isfile(path): + raise FileNotFoundError("Can't read annotation file '%s'" % path) + + if not subset: + file_name = osp.splitext(osp.basename(path))[0] + if file_name.startswith('t10k'): + subset = 'test' + else: + subset = file_name.split('-', maxsplit=1)[0] + + super().__init__(subset=subset) + self._dataset_dir = osp.dirname(path) + + self._categories = self._load_categories() + + self._items = list(self._load_items(path).values()) + + def _load_categories(self): + label_cat = LabelCategories() + + labels_file = osp.join(self._dataset_dir, 'labels.txt') + if osp.isfile(labels_file): + with open(labels_file, encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + label_cat.add(line) + else: + for i in range(10): + label_cat.add(str(i)) + + return { AnnotationType.label: label_cat } + + def _load_items(self, path): + items = {} + with gzip.open(path, 'rb') as lbpath: + labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8) + + meta = [] + metafile = osp.join(self._dataset_dir, self._subset + '-meta.gz') + if osp.isfile(metafile): + with gzip.open(metafile, 'rb') as f: + meta = np.frombuffer(f.read(), dtype=' +# or +datum create +datum add path -f mnist +``` + +There are two ways to create Datumaro project and add MNIST in CSV dataset to it: + +``` bash +datum import --format mnist_csv --input-path +# or +datum create +datum add path -f mnist_csv +``` + +It is possible to specify project name and project directory run +`datum create --help` for more information. + +MNIST dataset directory should have the following structure: + + +``` +└─ Dataset/ + ├── labels.txt # list of non-digit labels (optional) + ├── t10k-images-idx3-ubyte.gz + ├── t10k-labels-idx1-ubyte.gz + ├── train-images-idx3-ubyte.gz + └── train-labels-idx1-ubyte.gz +``` +MNIST in CSV dataset directory should have the following structure: + + +``` +└─ Dataset/ + ├── labels.txt # list of non-digit labels (optional) + ├── mnist_test.csv + └── mnist_train.csv +``` +If the dataset needs non-digit labels, you need to add the labels.txt +to the dataset folder. +For example, labels.txt for Fashion MNIST labels contains the following: + +``` +T-shirt/top +Trouser +Pullover +Dress +Coat +Sandal +Shirt +Sneaker +Bag +Ankle boot +``` + +MNIST format only supports single channel 28 x 28 images. + +## Export to other formats + +Datumaro can convert MNIST dataset into any other format [Datumaro supports](../docs/user_manual.md#supported-formats). +To get the expected result, the dataset needs to be converted to formats +that support the classification task (e.g. CIFAR-10/100, ImageNet, PascalVOC, etc.) +There are few ways to convert MNIST dataset to other dataset format: + +``` bash +datum project import -f mnist -i +datum export -f imagenet -o +# or +datum convert -if mnist -i -f imagenet -o +``` + +These commands also work for MNIST in CSV if you use `mnist_csv` instead of `mnist`. + +## Export to MNIST + +There are few ways to convert dataset to MNIST format: + +``` bash +# export dataset into MNIST format from existing project +datum export -p -f mnist -o \ + -- --save-images +# converting to MNIST format from other format +datum convert -if imagenet -i \ + -f mnist -o -- --save-images +``` + +Extra options for export to MNIST format: + +- `--save-images` allow to export dataset with saving images +(by default `False`); +- `--image-ext ` allow to specify image extension +for exporting dataset (by default `.png`). + +These commands also work for MNIST in CSV if you use `mnist_csv` instead of `mnist`. + +## Particular use cases + +Datumaro supports filtering, transformation, merging etc. for all formats +and for the MNIST format in particular. Follow [user manual](../docs/user_manual.md) +to get more information about these operations. + +There are few examples of using Datumaro operations to solve +particular problems with MNIST dataset: + +### Example 1. How to create custom MNIST-like dataset + +```python +from datumaro.components.dataset import Dataset +from datumaro.components.extractor import Label, DatasetItem + +dataset = Dataset.from_iterable([ + DatasetItem(id=0, image=np.ones((28, 28)), + annotations=[Label(2)] + ), + DatasetItem(id=1, image=np.ones((28, 28)), + annotations=[Label(7)] + ) +], categories=[str(label) for label in range(10)]) + +dataset.export('./dataset', format='mnist') +``` + +### Example 2. How to filter and convert MNIST dataset to ImageNet + +Convert MNIST dataset to ImageNet format, keep only images with `3` class presented: + +``` bash +# Download MNIST dataset: +# https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz +# https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz +datum convert --input-format mnist --input-path \ + --output-format imagenet \ + --filter '/item[annotation/label="3"]' +``` diff --git a/docs/user_manual.md b/docs/user_manual.md index 0f1f57b336..550f5e88f5 100644 --- a/docs/user_manual.md +++ b/docs/user_manual.md @@ -118,6 +118,14 @@ List of supported formats: - CIFAR-10/100 (`classification` (python version)) - [Format specification](https://www.cs.toronto.edu/~kriz/cifar.html) - [Dataset example](../tests/assets/cifar_dataset) +- MNIST (`classification`) + - [Format specification](http://yann.lecun.com/exdb/mnist/) + - [Dataset example](../tests/assets/mnist_dataset) + - [Format documentation](./mnist_user_manual.md) +- MNIST in CSV (`classification`) + - [Format specification](https://pjreddie.com/projects/mnist-in-csv/) + - [Dataset example](../tests/assets/mnist_csv_dataset) + - [Format documentation](./mnist_user_manual.md) - CamVid (`segmentation`) - [Format specification](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) - [Dataset example](../tests/assets/camvid_dataset) diff --git a/tests/assets/mnist_csv_dataset/mnist_test.csv b/tests/assets/mnist_csv_dataset/mnist_test.csv new file mode 100644 index 0000000000..d07be3847f --- /dev/null +++ b/tests/assets/mnist_csv_dataset/mnist_test.csv @@ -0,0 +1,3 @@ +0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 +2,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 +1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 diff --git a/tests/assets/mnist_csv_dataset/mnist_train.csv b/tests/assets/mnist_csv_dataset/mnist_train.csv new file mode 100644 index 0000000000..c93c7fca93 --- /dev/null +++ b/tests/assets/mnist_csv_dataset/mnist_train.csv @@ -0,0 +1,2 @@ +5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 +7,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 diff --git a/tests/assets/mnist_dataset/t10k-images-idx3-ubyte.gz b/tests/assets/mnist_dataset/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000000..811cde6b46 Binary files /dev/null and b/tests/assets/mnist_dataset/t10k-images-idx3-ubyte.gz differ diff --git a/tests/assets/mnist_dataset/t10k-labels-idx1-ubyte.gz b/tests/assets/mnist_dataset/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000000..a06f7a317c Binary files /dev/null and b/tests/assets/mnist_dataset/t10k-labels-idx1-ubyte.gz differ diff --git a/tests/assets/mnist_dataset/train-images-idx3-ubyte.gz b/tests/assets/mnist_dataset/train-images-idx3-ubyte.gz new file mode 100644 index 0000000000..10331a2c6c Binary files /dev/null and b/tests/assets/mnist_dataset/train-images-idx3-ubyte.gz differ diff --git a/tests/assets/mnist_dataset/train-labels-idx1-ubyte.gz b/tests/assets/mnist_dataset/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000000..8b193716ee Binary files /dev/null and b/tests/assets/mnist_dataset/train-labels-idx1-ubyte.gz differ diff --git a/tests/test_mnist_csv_format.py b/tests/test_mnist_csv_format.py new file mode 100644 index 0000000000..17286b90fe --- /dev/null +++ b/tests/test_mnist_csv_format.py @@ -0,0 +1,185 @@ +import os.path as osp +from unittest import TestCase + +import numpy as np +from datumaro.components.dataset import Dataset +from datumaro.components.extractor import (AnnotationType, DatasetItem, Label, + LabelCategories) +from datumaro.plugins.mnist_csv_format import (MnistCsvConverter, + MnistCsvImporter) +from datumaro.util.image import Image +from datumaro.util.test_utils import TestDir, compare_datasets + + +class MnistCsvFormatTest(TestCase): + def test_can_save_and_load(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id=0, subset='test', + image=np.ones((28, 28)), + annotations=[Label(0)] + ), + DatasetItem(id=1, subset='test', + image=np.ones((28, 28)) + ), + DatasetItem(id=2, subset='test', + image=np.ones((28, 28)), + annotations=[Label(1)] + ) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistCsvConverter.convert(source_dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist_csv') + + compare_datasets(self, source_dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_without_saving_images(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id=0, subset='train', + annotations=[Label(0)] + ), + DatasetItem(id=1, subset='train', + annotations=[Label(1)] + ), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistCsvConverter.convert(source_dataset, test_dir, save_images=False) + parsed_dataset = Dataset.import_from(test_dir, 'mnist_csv') + + compare_datasets(self, source_dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_with_different_image_size(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id=0, image=np.ones((10, 8)), + annotations=[Label(0)] + ), + DatasetItem(id=1, image=np.ones((4, 3)), + annotations=[Label(1)] + ), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistCsvConverter.convert(source_dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist_csv') + + compare_datasets(self, source_dataset, parsed_dataset, + require_images=True) + + def test_can_save_dataset_with_cyrillic_and_spaces_in_filename(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id="кириллица с пробелом", + image=np.ones((28, 28)), + annotations=[Label(0)] + ), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistCsvConverter.convert(source_dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist_csv') + + compare_datasets(self, source_dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_image_with_arbitrary_extension(self): + dataset = Dataset.from_iterable([ + DatasetItem(id='q/1', image=Image(path='q/1.JPEG', + data=np.zeros((28, 28)))), + DatasetItem(id='a/b/c/2', image=Image(path='a/b/c/2.bmp', + data=np.zeros((28, 28)))), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistCsvConverter.convert(dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist_csv') + + compare_datasets(self, dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_empty_image(self): + dataset = Dataset.from_iterable([ + DatasetItem(id=0, annotations=[Label(0)]), + DatasetItem(id=1) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistCsvConverter.convert(dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist_csv') + + compare_datasets(self, dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_with_other_labels(self): + dataset = Dataset.from_iterable([ + DatasetItem(id=0, image=np.ones((28, 28)), + annotations=[Label(0)]), + DatasetItem(id=1, image=np.ones((28, 28)), + annotations=[Label(1)]) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + 'label_%s' % label for label in range(2)), + }) + + with TestDir() as test_dir: + MnistCsvConverter.convert(dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist_csv') + + compare_datasets(self, dataset, parsed_dataset, + require_images=True) + +DUMMY_DATASET_DIR = osp.join(osp.dirname(__file__), 'assets', 'mnist_csv_dataset') + +class MnistCsvImporterTest(TestCase): + def test_can_import(self): + expected_dataset = Dataset.from_iterable([ + DatasetItem(id=0, subset='test', + image=np.ones((28, 28)), + annotations=[Label(0)] + ), + DatasetItem(id=1, subset='test', + image=np.ones((28, 28)), + annotations=[Label(2)] + ), + DatasetItem(id=2, subset='test', + image=np.ones((28, 28)), + annotations=[Label(1)] + ), + DatasetItem(id=0, subset='train', + image=np.ones((28, 28)), + annotations=[Label(5)] + ), + DatasetItem(id=1, subset='train', + image=np.ones((28, 28)), + annotations=[Label(7)] + ) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + dataset = Dataset.import_from(DUMMY_DATASET_DIR, 'mnist_csv') + + compare_datasets(self, expected_dataset, dataset) + + def test_can_detect(self): + self.assertTrue(MnistCsvImporter.detect(DUMMY_DATASET_DIR)) diff --git a/tests/test_mnist_format.py b/tests/test_mnist_format.py new file mode 100644 index 0000000000..eb5f5299b1 --- /dev/null +++ b/tests/test_mnist_format.py @@ -0,0 +1,184 @@ +import os.path as osp +from unittest import TestCase + +import numpy as np +from datumaro.components.dataset import Dataset +from datumaro.components.extractor import (AnnotationType, DatasetItem, Label, + LabelCategories) +from datumaro.plugins.mnist_format import MnistConverter, MnistImporter +from datumaro.util.image import Image +from datumaro.util.test_utils import TestDir, compare_datasets + + +class MnistFormatTest(TestCase): + def test_can_save_and_load(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id=0, subset='test', + image=np.ones((28, 28)), + annotations=[Label(0)] + ), + DatasetItem(id=1, subset='test', + image=np.ones((28, 28)) + ), + DatasetItem(id=2, subset='test', + image=np.ones((28, 28)), + annotations=[Label(1)] + ) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistConverter.convert(source_dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist') + + compare_datasets(self, source_dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_without_saving_images(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id=0, subset='train', + annotations=[Label(0)] + ), + DatasetItem(id=1, subset='train', + annotations=[Label(1)] + ), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistConverter.convert(source_dataset, test_dir, save_images=False) + parsed_dataset = Dataset.import_from(test_dir, 'mnist') + + compare_datasets(self, source_dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_with_different_image_size(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id=0, image=np.ones((3, 4)), + annotations=[Label(0)] + ), + DatasetItem(id=1, image=np.ones((2, 2)), + annotations=[Label(1)] + ), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistConverter.convert(source_dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist') + + compare_datasets(self, source_dataset, parsed_dataset, + require_images=True) + + def test_can_save_dataset_with_cyrillic_and_spaces_in_filename(self): + source_dataset = Dataset.from_iterable([ + DatasetItem(id="кириллица с пробелом", + image=np.ones((28, 28)), + annotations=[Label(0)] + ), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistConverter.convert(source_dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist') + + compare_datasets(self, source_dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_image_with_arbitrary_extension(self): + dataset = Dataset.from_iterable([ + DatasetItem(id='q/1', image=Image(path='q/1.JPEG', + data=np.zeros((28, 28)))), + DatasetItem(id='a/b/c/2', image=Image(path='a/b/c/2.bmp', + data=np.zeros((28, 28)))), + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistConverter.convert(dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist') + + compare_datasets(self, dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_empty_image(self): + dataset = Dataset.from_iterable([ + DatasetItem(id=0, annotations=[Label(0)]), + DatasetItem(id=1) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + with TestDir() as test_dir: + MnistConverter.convert(dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist') + + compare_datasets(self, dataset, parsed_dataset, + require_images=True) + + def test_can_save_and_load_with_other_labels(self): + dataset = Dataset.from_iterable([ + DatasetItem(id=0, image=np.ones((28, 28)), + annotations=[Label(0)]), + DatasetItem(id=1, image=np.ones((28, 28)), + annotations=[Label(1)]) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + 'label_%s' % label for label in range(2)), + }) + + with TestDir() as test_dir: + MnistConverter.convert(dataset, test_dir, save_images=True) + parsed_dataset = Dataset.import_from(test_dir, 'mnist') + + compare_datasets(self, dataset, parsed_dataset, + require_images=True) + +DUMMY_DATASET_DIR = osp.join(osp.dirname(__file__), 'assets', 'mnist_dataset') + +class MnistImporterTest(TestCase): + def test_can_import(self): + expected_dataset = Dataset.from_iterable([ + DatasetItem(id=0, subset='test', + image=np.ones((28, 28)), + annotations=[Label(0)] + ), + DatasetItem(id=1, subset='test', + image=np.ones((28, 28)), + annotations=[Label(2)] + ), + DatasetItem(id=2, subset='test', + image=np.ones((28, 28)), + annotations=[Label(1)] + ), + DatasetItem(id=0, subset='train', + image=np.ones((28, 28)), + annotations=[Label(5)] + ), + DatasetItem(id=1, subset='train', + image=np.ones((28, 28)), + annotations=[Label(7)] + ) + ], categories={ + AnnotationType.label: LabelCategories.from_iterable( + str(label) for label in range(10)), + }) + + dataset = Dataset.import_from(DUMMY_DATASET_DIR, 'mnist') + + compare_datasets(self, expected_dataset, dataset) + + def test_can_detect(self): + self.assertTrue(MnistImporter.detect(DUMMY_DATASET_DIR))