cifar10_utils.py (forked from deep-diver/AlexNet)

from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm
import tarfile
import pickle
import numpy as np
import skimage.transform


class DownloadProgress(tqdm):
    """tqdm progress bar wired to urlretrieve's reporthook callback."""
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num


def download(dataset_folder_path):
    """Fetch the CIFAR-10 tarball if it is missing, then extract it."""
    if not isfile('cifar-10-python.tar.gz'):
        with DownloadProgress(unit='B', unit_scale=True, miniters=1, desc='CIFAR-10 Dataset') as pbar:
            urlretrieve(
                'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
                'cifar-10-python.tar.gz',
                pbar.hook)
    else:
        print('cifar-10-python.tar.gz already exists')

    if not isdir(dataset_folder_path):
        # the with block closes the archive, so no explicit tar.close() is needed
        with tarfile.open('cifar-10-python.tar.gz') as tar:
            tar.extractall()
    else:
        print('cifar10 dataset already exists')
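
# Usage sketch (assumption: 'cifar-10-batches-py' is the folder name the
# official tarball extracts to):
#
#   download('cifar-10-batches-py')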


def convert_to_imagenet_size(images):
    """Resize a stack of 32x32 CIFAR-10 images to AlexNet's 224x224 input size."""
    tmp_images = []
    for image in images:
        tmp_image = skimage.transform.resize(image, (224, 224), mode='constant')
        tmp_images.append(tmp_image)
    return np.array(tmp_images)
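
# Rough cost note (not from the original file): resize() returns float64, so
# one 224x224x3 image is ~1.2 MB, and a full 10,000-image CIFAR-10 batch
# resized at once would need on the order of 12 GB, so resizing is best done
# in small chunks rather than over the whole dataset.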


def load_cifar10_batch(dataset_folder_path, batch_id):
    with open(dataset_folder_path + '/data_batch_' + str(batch_id), mode='rb') as file:
        # note the encoding type is 'latin1'
        batch = pickle.load(file, encoding='latin1')

    features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
    labels = batch['labels']
    return features, labels
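
# Shape note: each data_batch_N file holds 10,000 images, so `features` comes
# back as a (10000, 32, 32, 3) uint8 array and `labels` as a plain Python list
# of 10,000 ints in [0, 9].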


def one_hot_encode(x):
    encoded = np.zeros((len(x), 10))
    for idx, val in enumerate(x):
        encoded[idx][val] = 1
    return encoded
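
# e.g. one_hot_encode([1, 3]) ->
#   array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
#          [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])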


def _preprocess_and_save(one_hot_encode, features, labels, filename):
    labels = one_hot_encode(labels)
    with open(filename, 'wb') as file:
        pickle.dump((features, labels), file)


def preprocess_and_save_data(dataset_folder_path):
    n_batches = 5
    valid_features = []
    valid_labels = []

    for batch_i in range(1, n_batches + 1):
        features, labels = load_cifar10_batch(dataset_folder_path, batch_i)

        # reserve the last 10% of each batch as validation data
        index_of_validation = int(len(features) * 0.1)

        # preprocess the remaining 90% of the batch:
        # - one-hot encode the labels (the pixel values are left as-is)
        # - save to a new file named "preprocess_batch_" + batch_number,
        #   one file per batch
        _preprocess_and_save(one_hot_encode,
                             features[:-index_of_validation], labels[:-index_of_validation],
                             'preprocess_batch_' + str(batch_i) + '.p')

        # unlike the training data, the validation set is accumulated across
        # all batches: take the last 10% of each batch and append it to
        # valid_features / valid_labels
        valid_features.extend(features[-index_of_validation:])
        valid_labels.extend(labels[-index_of_validation:])

    # preprocess the stacked validation dataset
    _preprocess_and_save(one_hot_encode,
                         np.array(valid_features), np.array(valid_labels),
                         'preprocess_validation.p')

    # load the test dataset
    with open(dataset_folder_path + '/test_batch', mode='rb') as file:
        batch = pickle.load(file, encoding='latin1')

    # preprocess the test data
    test_features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
    test_labels = batch['labels']

    # preprocess and save all test data
    _preprocess_and_save(one_hot_encode,
                         np.array(test_features), np.array(test_labels),
                         'preprocess_testing.p')
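
# After preprocess_and_save_data() runs, the working directory holds seven
# pickle files: preprocess_batch_1.p .. preprocess_batch_5.p (9,000 samples
# each), preprocess_validation.p (5,000 samples), and preprocess_testing.p
# (10,000 samples).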


def batch_features_labels(features, labels, batch_size):
    """
    Split features and labels into batches
    """
    for start in range(0, len(features), batch_size):
        end = min(start + batch_size, len(features))
        yield features[start:end], labels[start:end]
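
# e.g. with 9,000 samples and batch_size=128 this yields 70 full batches of
# 128 followed by one final batch of 40.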


def load_preprocess_training_batch(batch_id, batch_size):
    """
    Load the preprocessed training data and return it in batches of <batch_size> or less
    """
    filename = 'preprocess_batch_' + str(batch_id) + '.p'
    with open(filename, mode='rb') as file:
        features, labels = pickle.load(file)

    # resize every 32x32 image up to AlexNet's 224x224 input size
    tmp_features = []
    for feature in features:
        tmp_feature = skimage.transform.resize(feature, (224, 224), mode='constant')
        tmp_features.append(tmp_feature)

    # return the training data in batches of size <batch_size> or less
    return batch_features_labels(tmp_features, labels, batch_size)
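

# Minimal end-to-end sketch (assumptions: network access for the ~160 MB
# download, and 'cifar-10-batches-py' matching the extracted folder name;
# batch_size is arbitrary):
if __name__ == '__main__':
    folder = 'cifar-10-batches-py'
    download(folder)
    preprocess_and_save_data(folder)

    # iterate over the first preprocessed training file in mini-batches
    for features, labels in load_preprocess_training_batch(batch_id=1, batch_size=64):
        print(features[0].shape, labels.shape)  # (224, 224, 3) and (64, 10)
        break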