From a52e860edcbb822312470b4074c4ad79b0ce5e28 Mon Sep 17 00:00:00 2001 From: Ismael Balafrej Date: Tue, 30 Nov 2021 15:03:27 +0100 Subject: [PATCH 1/5] Automatic download of mnist dataset --- src/lava/utils/dataloader/mnist.npy | 3 -- src/lava/utils/dataloader/mnist.py | 67 ++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 18 deletions(-) delete mode 100644 src/lava/utils/dataloader/mnist.npy diff --git a/src/lava/utils/dataloader/mnist.npy b/src/lava/utils/dataloader/mnist.npy deleted file mode 100644 index 8287ed8bc..000000000 --- a/src/lava/utils/dataloader/mnist.npy +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c0efd66d9eb375c5cc3e05b7141901cf27413ee50d2d76b750e7a0454dcfa6ba -size 54950493 diff --git a/src/lava/utils/dataloader/mnist.py b/src/lava/utils/dataloader/mnist.py index 81eb4c0eb..640ee3157 100644 --- a/src/lava/utils/dataloader/mnist.py +++ b/src/lava/utils/dataloader/mnist.py @@ -7,6 +7,17 @@ class MnistDataset: + mirrors = [ + "http://yann.lecun.com/exdb/mnist/", + "https://ossci-datasets.s3.amazonaws.com/mnist/", + "https://storage.googleapis.com/cvdf-datasets/mnist/", + ] + + files = [ + "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", + ] + def __init__(self, data_path=os.path.join(os.path.dirname(__file__), 'mnist.npy')): """data_path (str): Path to mnist.npy file containing the MNIST @@ -14,7 +25,7 @@ def __init__(self, data_path=os.path.join(os.path.dirname(__file__), if not os.path.exists(data_path): # Download MNIST from internet and convert it to .npy os.makedirs(os.path.join(os.path.dirname(__file__), 'temp'), - exist_ok=False) + exist_ok=True) MnistDataset. 
\ download_mnist(path=os.path.join( os.path.dirname(__file__), @@ -24,16 +35,29 @@ def __init__(self, data_path=os.path.join(os.path.dirname(__file__), # GUnzip, Parse and save MNIST data as .npy MnistDataset.decompress_convert_save( download_path=os.path.join(os.path.dirname(__file__), 'temp'), - save_path=os.path.dirname(data_path)) + save_path=data_path) self.data = np.load(data_path, allow_pickle=True) - # ToDo: Populate this method with a proper wget download from MNIST website @staticmethod def download_mnist(path=os.path.join(os.path.dirname(__file__), 'temp')): - pass + import urllib.request + import urllib.error + + for file in MnistDataset.files: + err = None + for mirror in MnistDataset.mirrors: + try: + res = urllib.request.urlopen(f"{mirror}{file}") + with open(os.path.join(path, file), "wb") as f: + f.write(res.read()) + break + except urllib.error.URLError as exception: + err = exception + continue + else: + print("Failed to download mnist dataset") + raise err - # ToDo: Populate this method with proper code to decompress, parse, - # and save MNIST as mnist.npy @staticmethod def decompress_convert_save( download_path=os.path.join(os.path.dirname(__file__), 'temp'), @@ -42,16 +66,29 @@ def decompress_convert_save( download_path (str): path of downloaded raw MNIST dataset in IDX format save_path (str): path at which processed npy file will be saved + + After loading data = np.load(), data is a np.array of np.arrays. + train_imgs = data[0][0]; shape = 60000 x 28 x 28 + test_imgs = data[1][0]; shape = 10000 x 28 x 28 + train_labels = data[0][1]; shape = 60000 x 1 + test_labels = data[1][1]; shape = 10000 x 1 """ - # Gunzip, parse, and save as .npy - # Format of .npy: - # After loading data = np.load(), data is a np.array of np.arrays. 
- # train_imgs = data[0][0]; shape = 60000 x 28 x 28 - # test_imgs = data[1][0]; shape = 10000 x 28 x 28 - # train_labels = data[0][1]; shape = 60000 x 1 - # test_labels = data[1][1]; shape = 10000 x 1 - # save as 'mnist.npy' in save_path - pass + + import gzip + arrays = [] + for file in MnistDataset.files: + with gzip.open(os.path.join(download_path, file), "rb") as f: + if "images" in file: + arr = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28) + else: + arr = np.frombuffer(f.read(), np.uint8, offset=8) + arrays.append(arr) + + np.save( + save_path, + np.array([[arrays[0], arrays[1]], [arrays[2], arrays[3]]], dtype="object"), + ) + @property def train_images(self): From 088acb22bbb3b3237700f1e8b8f9421cd88d1ca5 Mon Sep 17 00:00:00 2001 From: Ismael Balafrej Date: Tue, 30 Nov 2021 15:29:02 +0100 Subject: [PATCH 2/5] Reduced max line-length of mnist dataloader for linter --- src/lava/utils/dataloader/mnist.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/lava/utils/dataloader/mnist.py b/src/lava/utils/dataloader/mnist.py index 640ee3157..1a4cf2b23 100644 --- a/src/lava/utils/dataloader/mnist.py +++ b/src/lava/utils/dataloader/mnist.py @@ -79,17 +79,19 @@ def decompress_convert_save( for file in MnistDataset.files: with gzip.open(os.path.join(download_path, file), "rb") as f: if "images" in file: - arr = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28) + arr = np.frombuffer(f.read(), np.uint8, offset=16) + arr = arr.reshape(-1, 28, 28) else: arr = np.frombuffer(f.read(), np.uint8, offset=8) arrays.append(arr) np.save( save_path, - np.array([[arrays[0], arrays[1]], [arrays[2], arrays[3]]], dtype="object"), + np.array( + [[arrays[0], arrays[1]], [arrays[2], arrays[3]]], + dtype="object"), ) - @property def train_images(self): From d55ec3ad0e3e261e23187dca88e2f151bab976d7 Mon Sep 17 00:00:00 2001 From: Ismael Balafrej Date: Tue, 30 Nov 2021 18:08:06 +0100
Subject: [PATCH 3/5] Removing security linting from mnist dataset download --- src/lava/utils/dataloader/mnist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lava/utils/dataloader/mnist.py b/src/lava/utils/dataloader/mnist.py index 1a4cf2b23..f4b1476e0 100644 --- a/src/lava/utils/dataloader/mnist.py +++ b/src/lava/utils/dataloader/mnist.py @@ -47,7 +47,9 @@ def download_mnist(path=os.path.join(os.path.dirname(__file__), 'temp')): err = None for mirror in MnistDataset.mirrors: try: - res = urllib.request.urlopen(f"{mirror}{file}") + # Disabling security linter and using hardcoded + # URLs specified above + res = urllib.request.urlopen(f"{mirror}{file}") # nosec with open(os.path.join(path, file), "wb") as f: f.write(res.read()) break From 2ad0cea9798dc39c0401d6f915bd643adafdff0c Mon Sep 17 00:00:00 2001 From: Ismael Balafrej Date: Tue, 30 Nov 2021 19:21:59 +0100 Subject: [PATCH 4/5] Forcing urls in mnist dataloader to start with http --- src/lava/utils/dataloader/mnist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lava/utils/dataloader/mnist.py b/src/lava/utils/dataloader/mnist.py index f4b1476e0..d2488bc03 100644 --- a/src/lava/utils/dataloader/mnist.py +++ b/src/lava/utils/dataloader/mnist.py @@ -49,7 +49,9 @@ def download_mnist(path=os.path.join(os.path.dirname(__file__), 'temp')): try: # Disabling security linter and using hardcoded # URLs specified above - res = urllib.request.urlopen(f"{mirror}{file}") # nosec + url = f"{mirror}{file}" + assert url.lower().startswith("http") + res = urllib.request.urlopen(url) # nosec with open(os.path.join(path, file), "wb") as f: f.write(res.read()) break From b3b9687bde25d4eb3315e34ca55a89edfded531e Mon Sep 17 00:00:00 2001 From: Ismael Balafrej Date: Tue, 30 Nov 2021 20:41:48 +0100 Subject: [PATCH 5/5] Fixing linting security issue with urllib --- src/lava/utils/dataloader/mnist.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git 
a/src/lava/utils/dataloader/mnist.py b/src/lava/utils/dataloader/mnist.py index d2488bc03..72107c7b7 100644 --- a/src/lava/utils/dataloader/mnist.py +++ b/src/lava/utils/dataloader/mnist.py @@ -47,14 +47,16 @@ def download_mnist(path=os.path.join(os.path.dirname(__file__), 'temp')): err = None for mirror in MnistDataset.mirrors: try: - # Disabling security linter and using hardcoded - # URLs specified above url = f"{mirror}{file}" - assert url.lower().startswith("http") - res = urllib.request.urlopen(url) # nosec - with open(os.path.join(path, file), "wb") as f: - f.write(res.read()) - break + if url.lower().startswith("http"): + # Disabling security linter and using hardcoded + # URLs specified above + res = urllib.request.urlopen(url) # nosec + with open(os.path.join(path, file), "wb") as f: + f.write(res.read()) + break + else: + raise ValueError("Url does not start with http") except urllib.error.URLError as exception: err = exception continue