Skip to content

Commit

Permalink
Support for MNIST dataset format (#234)
Browse files Browse the repository at this point in the history
* add mnist format

* add mnist csv format

* add mnist to documentation
  • Loading branch information
yasakova-anastasia authored May 14, 2021
1 parent ef003ca commit ca9f78e
Show file tree
Hide file tree
Showing 14 changed files with 940 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for CIFAR-10/100 dataset format (<https://github.com/openvinotoolkit/datumaro/pull/225>)
- Support COCO panoptic and stuff format (<https://github.com/openvinotoolkit/datumaro/pull/210>)
- Documentation file and integration tests for Pascal VOC format (<https://github.com/openvinotoolkit/datumaro/pull/228>)
- Support for MNIST and MNIST in CSV dataset formats (<https://github.com/openvinotoolkit/datumaro/pull/234>)

### Changed
- LabelMe format saves dataset items with their relative paths by subsets without changing names (<https://github.com/openvinotoolkit/datumaro/pull/200>)
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ CVAT annotations ---> Publication, statistics etc.
- [MOTS PNG](https://www.vision.rwth-aachen.de/page/mots)
- [ImageNet](http://image-net.org/)
- [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html) (`classification`)
- [MNIST](http://yann.lecun.com/exdb/mnist/) (`classification`)
- [MNIST in CSV](https://pjreddie.com/projects/mnist-in-csv/) (`classification`)
- [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/)
- [CVAT](https://github.com/opencv/cvat/blob/develop/cvat/apps/documentation/xml_format.md)
- [LabelMe](http://labelme.csail.mit.edu/Release3.0)
Expand Down
170 changes: 170 additions & 0 deletions datumaro/plugins/mnist_csv_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import os
import os.path as osp

import numpy as np
from datumaro.components.converter import Converter
from datumaro.components.extractor import (AnnotationType, DatasetItem,
Importer, Label, LabelCategories, SourceExtractor)


class MnistCsvPath:
    """Constants shared by the MNIST-in-CSV extractor and converter."""

    # Height and width, in pixels, assumed for an image when no other
    # size has been recorded for it.
    IMAGE_SIZE = 28

    # Value stored in the label column for an item without a label.
    NONE_LABEL = -1

class MnistCsvExtractor(SourceExtractor):
    """Reads one subset of a dataset stored in the MNIST-in-CSV layout.

    Every row of the annotation file is ``label,pix0,pix1,...``; a row
    holding only a label describes an item without an image. Two optional
    companion files in the same directory are honoured:

    - ``labels.txt`` - one label name per line;
    - ``meta_<subset>.csv`` - one row per item with the item id and/or
      the image height and width.
    """

    # Placeholder that an export may contain in the pixel column for an
    # image without data (the text form of ``None``). It is not a pixel.
    _NO_IMAGE_MARK = 'None'

    def __init__(self, path, subset=None):
        """Load the annotation file at ``path``.

        Args:
            path: Path to the ``.csv`` annotation file.
            subset: Subset name. When empty, it is taken from the part of
                the file name that follows the last ``_``.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        if not osp.isfile(path):
            raise FileNotFoundError("Can't read annotation file '%s'" % path)

        if not subset:
            file_name = osp.splitext(osp.basename(path))[0]
            subset = file_name.rsplit('_', maxsplit=1)[-1]

        super().__init__(subset=subset)
        self._dataset_dir = osp.dirname(path)

        self._categories = self._load_categories()

        self._items = list(self._load_items(path).values())

    def _load_categories(self):
        """Build label categories from ``labels.txt`` or the digits 0-9.

        Returns:
            A dict mapping ``AnnotationType.label`` to the categories.
        """
        label_cat = LabelCategories()

        labels_file = osp.join(self._dataset_dir, 'labels.txt')
        if osp.isfile(labels_file):
            with open(labels_file, encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    label_cat.add(line)
        else:
            for i in range(10):
                label_cat.add(str(i))

        return { AnnotationType.label: label_cat }

    def _load_items(self, path):
        """Parse the annotation file (and the meta file, if present).

        Returns:
            A dict of ``DatasetItem`` objects keyed by item id. The id is
            the row index unless the meta file provides a name.
        """
        items = {}
        with open(path, 'r', encoding='utf-8') as f:
            # A blank line cannot describe an item (every row has at least
            # a label) and would otherwise fail in int('').
            annotation_table = [line for line in f if line.strip()]

        metafile = osp.join(self._dataset_dir, 'meta_%s.csv' % self._subset)
        meta = []
        if osp.isfile(metafile):
            with open(metafile, 'r', encoding='utf-8') as f:
                # Meta lines are kept as is, including empty ones: an
                # empty line is a valid row holding an empty item id.
                meta = [line.strip().split(',') for line in f]

        for i, row in enumerate(annotation_table):
            data = row.strip().split(',')
            item_anno = []
            label = int(data[0])
            if label != MnistCsvPath.NONE_LABEL:
                item_anno.append(Label(label))

            item_meta = meta[i] if meta else []

            pixels = data[1:]
            if pixels == [self._NO_IMAGE_MARK]:
                # The image was exported without data; there is nothing
                # to decode, so treat the row as label-only.
                pixels = []

            # support for single-channel image only
            image = None
            if pixels:
                if 1 < len(item_meta):
                    # The last two meta columns are the image shape.
                    height, width = int(item_meta[-2]), int(item_meta[-1])
                else:
                    height = width = MnistCsvPath.IMAGE_SIZE
                image = np.array([int(pix) for pix in pixels],
                    dtype='uint8').reshape(height, width)

            # A meta row of 1 or 3 columns starts with the item name.
            item_id = i
            if len(item_meta) in (1, 3):
                item_id = item_meta[0]

            items[item_id] = DatasetItem(id=item_id, subset=self._subset,
                image=image, annotations=item_anno)
        return items

class MnistCsvImporter(Importer):
    """Finds MNIST-in-CSV annotation files below a directory."""

    @classmethod
    def find_sources(cls, path):
        """Return what ``_find_sources_recursive`` finds under ``path``.

        The search is given the ``.csv`` extension and the ``mnist_csv``
        name, and skips files whose base name starts with ``meta``.
        """
        def is_annotation_file(candidate):
            file_name = osp.basename(candidate)
            return not file_name.startswith('meta')

        return cls._find_sources_recursive(path, '.csv', 'mnist_csv',
            file_filter=is_annotation_file)

class MnistCsvConverter(Converter):
    """Writes a dataset in the MNIST-in-CSV layout.

    For every subset this produces ``mnist_<subset>.csv`` with one row per
    item (``label,pix0,pix1,...``) and, only when needed,
    ``meta_<subset>.csv`` with item names and/or image sizes. A single
    ``labels.txt`` with the label names is written at the end.
    """

    DEFAULT_IMAGE_EXT = '.png'

    def apply(self):
        """Export every subset of the extractor into the save directory."""
        os.makedirs(self._save_dir, exist_ok=True)
        default_size = [MnistCsvPath.IMAGE_SIZE, MnistCsvPath.IMAGE_SIZE]

        for subset_name, subset in self._extractor.subsets().items():
            data = []
            item_ids = {}
            image_sizes = {}
            for item in subset:
                anns = [a.label for a in item.annotations
                    if a.type == AnnotationType.label]
                label = anns[0] if anns else MnistCsvPath.NONE_LABEL

                # An item without image data is stored as a label-only
                # row. Writing a placeholder such as None into the pixel
                # column would produce text that cannot be read back as a
                # pixel value.
                row = [label]
                if item.has_image and self._save_images \
                        and item.image.has_data:
                    pixels = item.image.data
                    size = [pixels.shape[0], pixels.shape[1]]
                    if size != default_size:
                        # len(data) is the index this row is about to get.
                        image_sizes[len(data)] = size
                    row.extend(pixels.reshape(-1).astype(np.uint8).tolist())
                data.append(row)

                index = len(data) - 1
                if item.id != str(index):
                    item_ids[index] = item.id

            anno_file = osp.join(self._save_dir, 'mnist_%s.csv' % subset_name)
            self.save_in_csv(anno_file, data)

            # This file isn't part of the original format; it stores
            # item names and image sizes that differ from the defaults.
            if item_ids or image_sizes:
                meta = []
                for i in range(len(data)):
                    meta_row = []
                    if item_ids:
                        # Column present only if some item has a name
                        # other than its row index.
                        meta_row.append(item_ids.get(i, i))
                    if image_sizes:
                        # Columns (height, width) present only if some
                        # image has a non-default size.
                        meta_row.extend(image_sizes.get(i, default_size))
                    meta.append(meta_row)

                metafile = osp.join(self._save_dir,
                    'meta_%s.csv' % subset_name)
                self.save_in_csv(metafile, meta)

        self.save_labels()

    def save_in_csv(self, path, data):
        """Write ``data`` (an iterable of rows) to ``path``.

        Each row becomes one line of comma-separated ``str()`` values.
        """
        with open(path, 'w', encoding='utf-8') as f:
            f.writelines(','.join(str(p) for p in row) + "\n"
                for row in data)

    def save_labels(self):
        """Write the label names, one per line, to ``labels.txt``."""
        labels_file = osp.join(self._save_dir, 'labels.txt')
        with open(labels_file, 'w', encoding='utf-8') as f:
            f.writelines(l.name + '\n'
                for l in self._extractor.categories().get(
                    AnnotationType.label, LabelCategories())
            )
209 changes: 209 additions & 0 deletions datumaro/plugins/mnist_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Copyright (C) 2021 Intel Corporation
#
# SPDX-License-Identifier: MIT

import gzip
import os
import os.path as osp

import numpy as np
from datumaro.components.converter import Converter
from datumaro.components.extractor import (AnnotationType, DatasetItem,
Importer, Label, LabelCategories, SourceExtractor)


class MnistPath:
    """File names and constants of the gzip-compressed MNIST layout."""

    # Suffixes that follow the subset name in label / image file names.
    LABELS_FILE = '-labels-idx1-ubyte.gz'
    IMAGES_FILE = '-images-idx3-ubyte.gz'

    # File names used for the 'test' subset ('t10k' prefix).
    TEST_LABELS_FILE = 't10k' + LABELS_FILE
    TEST_IMAGES_FILE = 't10k' + IMAGES_FILE

    # Height and width, in pixels, assumed for an image when no other
    # size has been recorded for it.
    IMAGE_SIZE = 28

    # Label byte stored for an item without a label.
    NONE_LABEL = 255

class MnistExtractor(SourceExtractor):
    """Loads one subset from gzip-compressed MNIST label/image files.

    The extractor is pointed at a labels file. The matching images file,
    ``labels.txt`` and ``<subset>-meta.gz`` are looked up in the same
    directory and used when they exist.
    """

    def __init__(self, path, subset=None):
        """Read the labels file at ``path`` and its companion files.

        Args:
            path: Path to the labels file.
            subset: Subset name. When empty it is derived from the file
                name: ``test`` for names starting with ``t10k``, else the
                text before the first ``-``.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
        """
        if not osp.isfile(path):
            raise FileNotFoundError("Can't read annotation file '%s'" % path)

        if not subset:
            stem, _ = osp.splitext(osp.basename(path))
            if stem.startswith('t10k'):
                subset = 'test'
            else:
                subset = stem.split('-', maxsplit=1)[0]

        super().__init__(subset=subset)
        self._dataset_dir = osp.dirname(path)
        self._categories = self._load_categories()
        self._items = list(self._load_items(path).values())

    def _load_categories(self):
        """Return label categories from ``labels.txt`` or the digits 0-9."""
        label_cat = LabelCategories()

        labels_file = osp.join(self._dataset_dir, 'labels.txt')
        if not osp.isfile(labels_file):
            for digit in range(10):
                label_cat.add(str(digit))
        else:
            with open(labels_file, encoding='utf-8') as f:
                for name in map(str.strip, f):
                    if name:
                        label_cat.add(name)

        return { AnnotationType.label: label_cat }

    def _load_items(self, path):
        """Decode labels, optional meta and optional images into items.

        Returns:
            A dict of ``DatasetItem`` objects keyed by item id. The id is
            the label index unless the meta file provides a name.
        """
        # The first 8 bytes of the labels file are a header.
        with gzip.open(path, 'rb') as labels_stream:
            labels = np.frombuffer(labels_stream.read(),
                dtype=np.uint8, offset=8)

        meta = []
        metafile = osp.join(self._dataset_dir, self._subset + '-meta.gz')
        if osp.isfile(metafile):
            with gzip.open(metafile, 'rb') as meta_stream:
                meta = np.frombuffer(meta_stream.read(), dtype='<U32')
                # One row per label; the column count follows from sizes.
                meta = meta.reshape(len(labels),
                    int(len(meta) / len(labels)))

        has_meta = 0 < len(meta)
        side = MnistPath.IMAGE_SIZE

        # support for single-channel image only
        images = None
        images_name = osp.basename(path).replace('labels-idx1', 'images-idx3')
        images_file = osp.join(self._dataset_dir, images_name)
        if osp.isfile(images_file):
            # The first 16 bytes of the images file are a header.
            with gzip.open(images_file, 'rb') as images_stream:
                images = np.frombuffer(images_stream.read(),
                    dtype=np.uint8, offset=16)
                if not has_meta or len(meta[0]) < 2:
                    # No per-item sizes: every image has the default size.
                    images = images.reshape(len(labels), side, side)

        items = {}
        offset = 0  # position in the flat pixel buffer (sized images only)
        for index, label in enumerate(labels):
            if label == MnistPath.NONE_LABEL:
                annotations = []
            else:
                annotations = [Label(label)]

            row = meta[index] if has_meta else ()

            image = None
            if images is not None:
                if 1 < len(row):
                    # The last two meta columns hold the image shape.
                    h, w = int(row[-2]), int(row[-1])
                    image = images[offset : offset + h * w].reshape(h, w)
                    offset += h * w
                else:
                    image = images[index].reshape(side, side)

            # A meta row of 1 or 3 columns starts with the item name.
            item_id = row[0] if len(row) in (1, 3) else index

            items[item_id] = DatasetItem(id=item_id, subset=self._subset,
                image=image, annotations=annotations)
        return items

class MnistImporter(Importer):
    """Finds MNIST labels files below a directory."""

    @classmethod
    def find_sources(cls, path):
        """Return what ``_find_sources_recursive`` finds under ``path``.

        The search is given the ``.gz`` extension and the ``mnist`` name,
        and keeps only files accepted by ``_is_labels_file``.
        """
        return cls._find_sources_recursive(path, '.gz', 'mnist',
            file_filter=cls._is_labels_file)

    @staticmethod
    def _is_labels_file(path):
        """Tell whether the second ``-``-separated name part is ``labels``.

        A name without any ``-`` (for example an unrelated ``.gz`` archive
        in the same directory) is rejected instead of raising IndexError.
        """
        parts = osp.basename(path).split('-')
        return 1 < len(parts) and parts[1] == 'labels'

class MnistConverter(Converter):
    """Writes a dataset as gzip-compressed MNIST label/image files.

    For every subset this produces a labels file, an images file (when
    any pixels were collected) and, only when needed, ``<subset>-meta.gz``
    with item names and/or image sizes. A single ``labels.txt`` with the
    label names is written at the end.
    """

    DEFAULT_IMAGE_EXT = '.png'

    def apply(self):
        """Export every subset of the extractor into the save directory."""
        os.makedirs(self._save_dir, exist_ok=True)
        default_size = [MnistPath.IMAGE_SIZE, MnistPath.IMAGE_SIZE]

        for subset_name, subset in self._extractor.subsets().items():
            labels = []
            image_chunks = []  # flattened uint8 pixels, one array per image
            item_ids = {}
            image_sizes = {}
            for item in subset:
                anns = [a.label for a in item.annotations
                    if a.type == AnnotationType.label]
                labels.append(anns[0] if anns else MnistPath.NONE_LABEL)

                index = len(labels) - 1
                if item.id != str(index):
                    item_ids[index] = item.id

                if item.has_image and self._save_images:
                    image = item.image
                    if not image.has_data:
                        # Keyed by the item index, like every other entry
                        # (the pixel count is not an item index).
                        image_sizes[index] = [0, 0]
                    else:
                        pixels = image.data
                        size = [pixels.shape[0], pixels.shape[1]]
                        if size != default_size:
                            image_sizes[index] = size
                        image_chunks.append(
                            pixels.reshape(-1).astype(np.uint8))
                # NOTE: an item with no image at all gets no size entry.

            if subset_name == 'test':
                labels_file = osp.join(self._save_dir,
                    MnistPath.TEST_LABELS_FILE)
            else:
                labels_file = osp.join(self._save_dir,
                    subset_name + MnistPath.LABELS_FILE)
            self.save_annotations(labels_file, labels)

            # Join once instead of growing an array inside the loop.
            if image_chunks:
                images = np.concatenate(image_chunks)
            else:
                images = np.array([], dtype=np.uint8)

            if 0 < len(images):
                if subset_name == 'test':
                    images_file = osp.join(self._save_dir,
                        MnistPath.TEST_IMAGES_FILE)
                else:
                    images_file = osp.join(self._save_dir,
                        subset_name + MnistPath.IMAGES_FILE)
                self.save_images(images_file, images,
                    count=len(image_chunks))

            # This file isn't part of the original format; it stores
            # item names and image sizes that differ from the defaults.
            if item_ids or image_sizes:
                meta = []
                for i in range(len(labels)):
                    meta_row = []
                    if item_ids:
                        # Column present only if some item has a name
                        # other than its index.
                        meta_row.append(item_ids.get(i, i))
                    if image_sizes:
                        # Columns (height, width) present only if some
                        # image has a non-default size.
                        meta_row.extend(image_sizes.get(i, default_size))
                    meta.append(meta_row)

                metafile = osp.join(self._save_dir, subset_name + '-meta.gz')
                with gzip.open(metafile, 'wb') as f:
                    # Fixed-width strings: values longer than 32
                    # characters do not fit this dtype.
                    f.write(np.array(meta, dtype='<U32').tobytes())

        self.save_labels()

    def save_annotations(self, path, data):
        """Write the labels ``data`` to ``path`` as a gzip-compressed file.

        The payload is preceded by two big-endian 32-bit integers: the
        magic number and the number of labels.
        """
        with gzip.open(path, 'wb') as f:
            # magic number = 0x0801 (2049, hexadecimal representation)
            # this is used to verify the file with MNIST mark data
            f.write(np.array([0x0801, len(data)], dtype='>i4').tobytes())
            f.write(np.array(data, dtype='uint8').tobytes())

    def save_images(self, path, data, count=None):
        """Write the pixels ``data`` to ``path`` as a gzip-compressed file.

        Args:
            path: Destination file.
            data: Pixel values; either a flat sequence or an array whose
                first axis enumerates images.
            count: Number of images, stored in the header. When omitted it
                is the first-axis length of multi-dimensional data, or
                the number of default-sized images in flat data.
        """
        pixels = np.asarray(data, dtype='uint8')
        if count is None:
            if 1 < pixels.ndim:
                count = len(pixels)
            else:
                count = pixels.size // (MnistPath.IMAGE_SIZE ** 2)

        with gzip.open(path, 'wb') as f:
            # magic number = 0x0803 (2051, hexadecimal representation),
            # this is used to verify the file with MNIST image data.
            # The second field is the number of images, not of pixels.
            f.write(np.array([0x0803, count, MnistPath.IMAGE_SIZE,
                MnistPath.IMAGE_SIZE], dtype='>i4').tobytes())
            f.write(pixels.tobytes())

    def save_labels(self):
        """Write the label names, one per line, to ``labels.txt``."""
        labels_file = osp.join(self._save_dir, 'labels.txt')
        with open(labels_file, 'w', encoding='utf-8') as f:
            f.writelines(l.name + '\n'
                for l in self._extractor.categories().get(
                    AnnotationType.label, LabelCategories())
            )
Loading

0 comments on commit ca9f78e

Please sign in to comment.