-
Notifications
You must be signed in to change notification settings - Fork 135
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for MNIST dataset format (#234)
* add mnist format * add mnist csv format * add mnist to documentation
- Loading branch information
1 parent
ef003ca
commit ca9f78e
Showing
14 changed files
with
940 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright (C) 2021 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import os | ||
import os.path as osp | ||
|
||
import numpy as np | ||
from datumaro.components.converter import Converter | ||
from datumaro.components.extractor import (AnnotationType, DatasetItem, | ||
Importer, Label, LabelCategories, SourceExtractor) | ||
|
||
|
||
class MnistCsvPath: | ||
IMAGE_SIZE = 28 | ||
NONE_LABEL = -1 | ||
|
||
class MnistCsvExtractor(SourceExtractor): | ||
def __init__(self, path, subset=None): | ||
if not osp.isfile(path): | ||
raise FileNotFoundError("Can't read annotation file '%s'" % path) | ||
|
||
if not subset: | ||
file_name = osp.splitext(osp.basename(path))[0] | ||
subset = file_name.rsplit('_', maxsplit=1)[-1] | ||
|
||
super().__init__(subset=subset) | ||
self._dataset_dir = osp.dirname(path) | ||
|
||
self._categories = self._load_categories() | ||
|
||
self._items = list(self._load_items(path).values()) | ||
|
||
def _load_categories(self): | ||
label_cat = LabelCategories() | ||
|
||
labels_file = osp.join(self._dataset_dir, 'labels.txt') | ||
if osp.isfile(labels_file): | ||
with open(labels_file, encoding='utf-8') as f: | ||
for line in f: | ||
line = line.strip() | ||
if not line: | ||
continue | ||
label_cat.add(line) | ||
else: | ||
for i in range(10): | ||
label_cat.add(str(i)) | ||
|
||
return { AnnotationType.label: label_cat } | ||
|
||
def _load_items(self, path): | ||
items = {} | ||
with open(path, 'r', encoding='utf-8') as f: | ||
annotation_table = f.readlines() | ||
|
||
metafile = osp.join(self._dataset_dir, 'meta_%s.csv' % self._subset) | ||
meta = [] | ||
if osp.isfile(metafile): | ||
with open(metafile, 'r', encoding='utf-8') as f: | ||
meta = f.readlines() | ||
|
||
for i, data in enumerate(annotation_table): | ||
data = data.split(',') | ||
item_anno = [] | ||
label = int(data[0]) | ||
if label != MnistCsvPath.NONE_LABEL: | ||
item_anno.append(Label(label)) | ||
|
||
if 0 < len(meta): | ||
meta[i] = meta[i].strip().split(',') | ||
|
||
# support for single-channel image only | ||
image = None | ||
if 1 < len(data): | ||
if 0 < len(meta) and 1 < len(meta[i]): | ||
image = np.array([int(pix) for pix in data[1:]], | ||
dtype='uint8').reshape(int(meta[i][-2]), int(meta[i][-1])) | ||
else: | ||
image = np.array([int(pix) for pix in data[1:]], | ||
dtype='uint8').reshape(28, 28) | ||
|
||
if 0 < len(meta) and len(meta[i]) in [1, 3]: | ||
i = meta[i][0] | ||
|
||
items[i] = DatasetItem(id=i, subset=self._subset, | ||
image=image, annotations=item_anno) | ||
return items | ||
|
||
class MnistCsvImporter(Importer): | ||
@classmethod | ||
def find_sources(cls, path): | ||
return cls._find_sources_recursive(path, '.csv', 'mnist_csv', | ||
file_filter=lambda p: not osp.basename(p).startswith('meta')) | ||
|
||
class MnistCsvConverter(Converter): | ||
DEFAULT_IMAGE_EXT = '.png' | ||
|
||
def apply(self): | ||
os.makedirs(self._save_dir, exist_ok=True) | ||
for subset_name, subset in self._extractor.subsets().items(): | ||
data = [] | ||
item_ids = {} | ||
image_sizes = {} | ||
for item in subset: | ||
anns = [a.label for a in item.annotations | ||
if a.type == AnnotationType.label] | ||
label = MnistCsvPath.NONE_LABEL | ||
if anns: | ||
label = anns[0] | ||
|
||
if item.has_image and self._save_images: | ||
image = item.image | ||
if not image.has_data: | ||
data.append([label, None]) | ||
else: | ||
if image.data.shape[0] != MnistCsvPath.IMAGE_SIZE or \ | ||
image.data.shape[1] != MnistCsvPath.IMAGE_SIZE: | ||
image_sizes[len(data)] = [image.data.shape[0], | ||
image.data.shape[1]] | ||
image = image.data.reshape(-1).astype(np.uint8).tolist() | ||
image.insert(0, label) | ||
data.append(image) | ||
else: | ||
data.append([label]) | ||
|
||
if item.id != str(len(data) - 1): | ||
item_ids[len(data) - 1] = item.id | ||
|
||
anno_file = osp.join(self._save_dir, 'mnist_%s.csv' % subset_name) | ||
self.save_in_csv(anno_file, data) | ||
|
||
# it is't in the original format, | ||
# this is for storng other names and sizes of images | ||
if len(item_ids) or len(image_sizes): | ||
meta = [] | ||
if len(item_ids) and len(image_sizes): | ||
# other names and sizes of images | ||
size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE] | ||
for i in range(len(data)): | ||
w, h = image_sizes.get(i, size) | ||
meta.append([item_ids.get(i, i), w, h]) | ||
|
||
elif len(item_ids): | ||
# other names of images | ||
for i in range(len(data)): | ||
meta.append([item_ids.get(i, i)]) | ||
|
||
elif len(image_sizes): | ||
# other sizes of images | ||
size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE] | ||
for i in range(len(data)): | ||
meta.append(image_sizes.get(i, size)) | ||
|
||
metafile = osp.join(self._save_dir, 'meta_%s.csv' % subset_name) | ||
self.save_in_csv(metafile, meta) | ||
|
||
self.save_labels() | ||
|
||
def save_in_csv(self, path, data): | ||
with open(path, 'w', encoding='utf-8') as f: | ||
for row in data: | ||
f.write(','.join([str(p) for p in row]) + "\n") | ||
|
||
def save_labels(self): | ||
labels_file = osp.join(self._save_dir, 'labels.txt') | ||
with open(labels_file, 'w', encoding='utf-8') as f: | ||
f.writelines(l.name + '\n' | ||
for l in self._extractor.categories().get( | ||
AnnotationType.label, LabelCategories()) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
# Copyright (C) 2021 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import gzip | ||
import os | ||
import os.path as osp | ||
|
||
import numpy as np | ||
from datumaro.components.converter import Converter | ||
from datumaro.components.extractor import (AnnotationType, DatasetItem, | ||
Importer, Label, LabelCategories, SourceExtractor) | ||
|
||
|
||
class MnistPath: | ||
TEST_LABELS_FILE = 't10k-labels-idx1-ubyte.gz' | ||
TEST_IMAGES_FILE = 't10k-images-idx3-ubyte.gz' | ||
LABELS_FILE = '-labels-idx1-ubyte.gz' | ||
IMAGES_FILE = '-images-idx3-ubyte.gz' | ||
IMAGE_SIZE = 28 | ||
NONE_LABEL = 255 | ||
|
||
class MnistExtractor(SourceExtractor): | ||
def __init__(self, path, subset=None): | ||
if not osp.isfile(path): | ||
raise FileNotFoundError("Can't read annotation file '%s'" % path) | ||
|
||
if not subset: | ||
file_name = osp.splitext(osp.basename(path))[0] | ||
if file_name.startswith('t10k'): | ||
subset = 'test' | ||
else: | ||
subset = file_name.split('-', maxsplit=1)[0] | ||
|
||
super().__init__(subset=subset) | ||
self._dataset_dir = osp.dirname(path) | ||
|
||
self._categories = self._load_categories() | ||
|
||
self._items = list(self._load_items(path).values()) | ||
|
||
def _load_categories(self): | ||
label_cat = LabelCategories() | ||
|
||
labels_file = osp.join(self._dataset_dir, 'labels.txt') | ||
if osp.isfile(labels_file): | ||
with open(labels_file, encoding='utf-8') as f: | ||
for line in f: | ||
line = line.strip() | ||
if not line: | ||
continue | ||
label_cat.add(line) | ||
else: | ||
for i in range(10): | ||
label_cat.add(str(i)) | ||
|
||
return { AnnotationType.label: label_cat } | ||
|
||
def _load_items(self, path): | ||
items = {} | ||
with gzip.open(path, 'rb') as lbpath: | ||
labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8) | ||
|
||
meta = [] | ||
metafile = osp.join(self._dataset_dir, self._subset + '-meta.gz') | ||
if osp.isfile(metafile): | ||
with gzip.open(metafile, 'rb') as f: | ||
meta = np.frombuffer(f.read(), dtype='<U32') | ||
meta = meta.reshape(len(labels), int(len(meta) / len(labels))) | ||
|
||
# support for single-channel image only | ||
images = None | ||
images_file = osp.join(self._dataset_dir, | ||
osp.basename(path).replace('labels-idx1', 'images-idx3')) | ||
if osp.isfile(images_file): | ||
with gzip.open(images_file, 'rb') as imgpath: | ||
images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16) | ||
if len(meta) == 0 or len(meta[0]) < 2: | ||
images = images.reshape(len(labels), MnistPath.IMAGE_SIZE, | ||
MnistPath.IMAGE_SIZE) | ||
|
||
pix_num = 0 | ||
for i, annotation in enumerate(labels): | ||
annotations = [] | ||
label = annotation | ||
if label != MnistPath.NONE_LABEL: | ||
annotations.append(Label(label)) | ||
|
||
image = None | ||
if images is not None: | ||
if 0 < len(meta) and 1 < len(meta[i]): | ||
h, w = int(meta[i][-2]), int(meta[i][-1]) | ||
image = images[pix_num : pix_num + h * w].reshape(h, w) | ||
pix_num += h * w | ||
else: | ||
image = images[i].reshape(MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE) | ||
|
||
if 0 < len(meta) and (len(meta[i]) == 1 or len(meta[i]) == 3): | ||
i = meta[i][0] | ||
|
||
items[i] = DatasetItem(id=i, subset=self._subset, | ||
image=image, annotations=annotations) | ||
return items | ||
|
||
class MnistImporter(Importer): | ||
@classmethod | ||
def find_sources(cls, path): | ||
return cls._find_sources_recursive(path, '.gz', 'mnist', | ||
file_filter=lambda p: osp.basename(p).split('-')[1] == 'labels') | ||
|
||
class MnistConverter(Converter): | ||
DEFAULT_IMAGE_EXT = '.png' | ||
|
||
def apply(self): | ||
os.makedirs(self._save_dir, exist_ok=True) | ||
for subset_name, subset in self._extractor.subsets().items(): | ||
labels = [] | ||
images = np.array([]) | ||
item_ids = {} | ||
image_sizes = {} | ||
for item in subset: | ||
anns = [a.label for a in item.annotations | ||
if a.type == AnnotationType.label] | ||
label = 255 | ||
if anns: | ||
label = anns[0] | ||
labels.append(label) | ||
|
||
if item.id != str(len(labels) - 1): | ||
item_ids[len(labels) - 1] = item.id | ||
|
||
if item.has_image and self._save_images: | ||
image = item.image | ||
if not image.has_data: | ||
image_sizes[len(images) - 1] = [0, 0] | ||
else: | ||
image = image.data | ||
if image.shape[0] != MnistPath.IMAGE_SIZE or \ | ||
image.shape[1] != MnistPath.IMAGE_SIZE: | ||
image_sizes[len(labels) - 1] = [image.shape[0], image.shape[1]] | ||
images = np.append(images, image.reshape(-1).astype(np.uint8)) | ||
|
||
if subset_name == 'test': | ||
labels_file = osp.join(self._save_dir, | ||
MnistPath.TEST_LABELS_FILE) | ||
else: | ||
labels_file = osp.join(self._save_dir, | ||
subset_name + MnistPath.LABELS_FILE) | ||
self.save_annotations(labels_file, labels) | ||
|
||
if 0 < len(images): | ||
if subset_name == 'test': | ||
images_file = osp.join(self._save_dir, | ||
MnistPath.TEST_IMAGES_FILE) | ||
else: | ||
images_file = osp.join(self._save_dir, | ||
subset_name + MnistPath.IMAGES_FILE) | ||
self.save_images(images_file, images) | ||
|
||
# it is't in the original format, | ||
# this is for storng other names and sizes of images | ||
if len(item_ids) or len(image_sizes): | ||
meta = [] | ||
if len(item_ids) and len(image_sizes): | ||
# other names and sizes of images | ||
size = [MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE] | ||
for i in range(len(labels)): | ||
w, h = image_sizes.get(i, size) | ||
meta.append([item_ids.get(i, i), w, h]) | ||
|
||
elif len(item_ids): | ||
# other names of images | ||
for i in range(len(labels)): | ||
meta.append([item_ids.get(i, i)]) | ||
|
||
elif len(image_sizes): | ||
# other sizes of images | ||
size = [MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE] | ||
for i in range(len(labels)): | ||
meta.append(image_sizes.get(i, size)) | ||
|
||
metafile = osp.join(self._save_dir, subset_name + '-meta.gz') | ||
with gzip.open(metafile, 'wb') as f: | ||
f.write(np.array(meta, dtype='<U32').tobytes()) | ||
|
||
self.save_labels() | ||
|
||
def save_annotations(self, path, data): | ||
with gzip.open(path, 'wb') as f: | ||
# magic number = 0x0801 (2049, hexadecimal representation) | ||
# this is used to verify the file with MNIST mark data | ||
f.write(np.array([0x0801, len(data)], dtype='>i4').tobytes()) | ||
f.write(np.array(data, dtype='uint8').tobytes()) | ||
|
||
def save_images(self, path, data): | ||
with gzip.open(path, 'wb') as f: | ||
# magic number = 0x0803 (2051, hexadecimal representation), | ||
# this is used to verify the file with MNIST image data | ||
f.write(np.array([0x0803, len(data), MnistPath.IMAGE_SIZE, | ||
MnistPath.IMAGE_SIZE], dtype='>i4').tobytes()) | ||
f.write(np.array(data, dtype='uint8').tobytes()) | ||
|
||
def save_labels(self): | ||
labels_file = osp.join(self._save_dir, 'labels.txt') | ||
with open(labels_file, 'w', encoding='utf-8') as f: | ||
f.writelines(l.name + '\n' | ||
for l in self._extractor.categories().get( | ||
AnnotationType.label, LabelCategories()) | ||
) |
Oops, something went wrong.