diff --git a/src/lava/utils/dataloader/mnist.npy b/src/lava/utils/dataloader/mnist.npy deleted file mode 100644 index 8287ed8bc..000000000 --- a/src/lava/utils/dataloader/mnist.npy +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c0efd66d9eb375c5cc3e05b7141901cf27413ee50d2d76b750e7a0454dcfa6ba -size 54950493 diff --git a/src/lava/utils/dataloader/mnist.py b/src/lava/utils/dataloader/mnist.py index 81eb4c0eb..72107c7b7 100644 --- a/src/lava/utils/dataloader/mnist.py +++ b/src/lava/utils/dataloader/mnist.py @@ -7,6 +7,17 @@ class MnistDataset: + mirrors = [ + "http://yann.lecun.com/exdb/mnist/", + "https://ossci-datasets.s3.amazonaws.com/mnist/", + "https://storage.googleapis.com/cvdf-datasets/mnist/", + ] + + files = [ + "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", + ] + def __init__(self, data_path=os.path.join(os.path.dirname(__file__), 'mnist.npy')): """data_path (str): Path to mnist.npy file containing the MNIST @@ -14,7 +25,7 @@ def __init__(self, data_path=os.path.join(os.path.dirname(__file__), if not os.path.exists(data_path): # Download MNIST from internet and convert it to .npy os.makedirs(os.path.join(os.path.dirname(__file__), 'temp'), - exist_ok=False) + exist_ok=True) MnistDataset. 
\ download_mnist(path=os.path.join( os.path.dirname(__file__), @@ -24,16 +35,35 @@ def __init__(self, data_path=os.path.join(os.path.dirname(__file__), # GUnzip, Parse and save MNIST data as .npy MnistDataset.decompress_convert_save( download_path=os.path.join(os.path.dirname(__file__), 'temp'), - save_path=os.path.dirname(data_path)) + save_path=data_path) self.data = np.load(data_path, allow_pickle=True) - # ToDo: Populate this method with a proper wget download from MNIST website @staticmethod def download_mnist(path=os.path.join(os.path.dirname(__file__), 'temp')): - pass + import urllib.request + import urllib.error + + for file in MnistDataset.files: + err = None + for mirror in MnistDataset.mirrors: + try: + url = f"{mirror}{file}" + if url.lower().startswith("http"): + # Disabling security linter and using hardcoded + # URLs specified above + res = urllib.request.urlopen(url) # nosec + with open(os.path.join(path, file), "wb") as f: + f.write(res.read()) + break + else: + raise ValueError("Url does not start with http") + except urllib.error.URLError as exception: + err = exception + continue + else: + print("Failed to download mnist dataset") + raise err - # ToDo: Populate this method with proper code to decompress, parse, - # and save MNIST as mnist.npy @staticmethod def decompress_convert_save( download_path=os.path.join(os.path.dirname(__file__), 'temp'), @@ -42,16 +72,31 @@ def decompress_convert_save( download_path (str): path of downloaded raw MNIST dataset in IDX format save_path (str): path at which processed npy file will be saved + + After loading data = np.load(), data is a np.array of np.arrays. + train_imgs = data[0][0]; shape = 60000 x 28 x 28 + test_imgs = data[1][0]; shape = 10000 x 28 x 28 + train_labels = data[0][1]; shape = 60000 x 1 + test_labels = data[1][1]; shape = 10000 x 1 """ - # Gunzip, parse, and save as .npy - # Format of .npy: - # After loading data = np.load(), data is a np.array of np.arrays. 
- # train_imgs = data[0][0]; shape = 60000 x 28 x 28 - # test_imgs = data[1][0]; shape = 10000 x 28 x 28 - # train_labels = data[0][1]; shape = 60000 x 1 - # test_labels = data[1][1]; shape = 10000 x 1 - # save as 'mnist.npy' in save_path - pass + + import gzip + arrays = [] + for file in MnistDataset.files: + with gzip.open(os.path.join(download_path, file), "rb") as f: + if "images" in file: + arr = np.frombuffer(f.read(), np.uint8, offset=16) + arr = arr.reshape(-1, 28, 28) + else: + arr = np.frombuffer(f.read(), np.uint8, offset=8) + arrays.append(arr) + + np.save( + save_path, + np.array( + [[arrays[0], arrays[1]], [arrays[2], arrays[3]]], + dtype="object"), + ) @property def train_images(self):