From 5d3e20f0c3fab342191fa6e62ba3150cc56cb175 Mon Sep 17 00:00:00 2001 From: Alexander Koenig Date: Mon, 11 Apr 2022 13:13:39 +0200 Subject: [PATCH 1/5] Add Nvidia DALI tutorial --- examples/11.DALI_Squirrel_Integration.py | 141 +++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 examples/11.DALI_Squirrel_Integration.py diff --git a/examples/11.DALI_Squirrel_Integration.py b/examples/11.DALI_Squirrel_Integration.py new file mode 100644 index 0000000..2893945 --- /dev/null +++ b/examples/11.DALI_Squirrel_Integration.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import time +import types +from collections import defaultdict +from typing import Any, Dict, List + +import cupy +import numpy as np +import nvidia.dali.fn as fn +import nvidia.dali.types as types +import torch +from nvidia.dali import pipeline_def +from nvidia.dali.plugin.pytorch import DALIGenericIterator +from squirrel.catalog import Catalog +from squirrel.iterstream import Composable + +#### INSTALLATION +# pip install cupy-cuda110 squirrel-core squirrel-datasets-core +# pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110 +#### INSTALLATION END + +# some metadata for cifar100 +BATCH_SIZE = 16 +DS_SPLIT = "train" +DS = "cifar100" +DS_LEN = 50000 +DS_VERSION = 1 # uses Huggingface version of the dataset +NUM_OUTPUTS = 3 # img, fine_label, coarse_label + + +def cupy_collate(records: List[Dict[str, Any]]) -> List[cupy.ndarray]: + """Converts a batch of a list of dicts to a list of cupy arrays. + + Args: + records (List[Dict[str, Any]]): A batch to convert. + + Returns: + List[cupy.ndarray]: Re-formatted batch of cupy arrays. For CIFAR-100, + this returns List[batch_imgs, batch_fine_labels, batch_coarse_labels]. + """ + batch = defaultdict(list) + for elem in records: + for k, v in elem.items(): + batch[k].append(np.array(v)) + return [cupy.asarray(np.array(v), dtype=np.uint8) for v in batch.values()] + + +class SquirrelGpuIterator(object): + """External source iterator for a Squirrel driver. By using the cupy_collate function + for batch collation, we support the cuda array interface. Hence, data is returned + on the GPU. The below example was mainly based on the below tutorial. + https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/external_input.html + """ + + def __init__( + self, + iterable: Composable, + batch_size: int, + ds_len: int, + ): + """Constructs the SquirrelGpuIterator. + + Args: + iterable (Composable): A Squirrel iterable object you'd like to use with DALI. + batch_size (int): How many samples are contained in one batch. + ds_len (int): How many samples are in your dataset. We need to know this here, because, as of now, + Squirrel drivers do not implement the __len__() method. + """ + self.batch_size = batch_size + self.it = iter(iterable.batched(self.batch_size, collation_fn=cupy_collate, drop_last_if_not_full=False)) + self.ds_len = ds_len + + def __iter__(self) -> SquirrelGpuIterator: + """Returns the iterator and sets variables that DALI uses internally for tracking iteration progress. + + Returns: + SquirrelGpuIterator: The iterator object. + """ + self.i = 0 + self.n = self.ds_len + return self + + def __next__(self) -> List[cupy.ndarray]: + """Advances the index by batch_size and returns the next item of the iterator. + + Returns: + List[cupy.ndarray]: The next batch of your dataset (on GPU). + """ + + self.i = (self.i + self.batch_size) % self.n + return next(self.it) + + +cat = Catalog.from_plugins() +it = cat[DS][DS_VERSION].get_driver().get_iter(DS_SPLIT) +source = SquirrelGpuIterator(it, BATCH_SIZE, DS_LEN) + + +@pipeline_def +def pipeline(): + imgs, fine_labels, coarse_labels = fn.external_source( + source=it, + num_outputs=NUM_OUTPUTS, + device="gpu", + dtype=types.UINT8, + ) + enhanced = fn.brightness_contrast(imgs, contrast=2) + return enhanced, fine_labels, coarse_labels + + +pipe = pipeline(batch_size=BATCH_SIZE, num_threads=2, device_id=0) +pipe.build() + +# this is the DALI equivalent of torch.utils.data.DataLoader +dali_iter = DALIGenericIterator([pipe], ["img", "fine_label", "coarse_label"]) + +print("Iterating over dataset ...") +t = time.time() +for idx, item in enumerate(dali_iter): + if idx == 0: + # item is a list of length 1 + assert type(item) == list + assert len(item) == 1 + + # let's see what's in the list, it's a dict with our data! + data = item[0] + assert type(data) == dict + + # the data is already in torch format and on GPU! + assert type(data["img"]) == torch.Tensor + assert data["img"].shape == torch.Size([BATCH_SIZE, 32, 32, 3]) + assert data["img"].device == torch.device("cuda:0") + assert data["fine_label"].shape[0] == BATCH_SIZE + + print("Data in torch format!") + +samples_ps = DS_LEN / (time.time() - t) +print(f"Iteration speed: {samples_ps} samples per second") + +print("Done. Have a nice day.") From 7919a8ea680a1a1438c1de00d92827de883002b2 Mon Sep 17 00:00:00 2001 From: Alexander Koenig Date: Mon, 11 Apr 2022 13:37:34 +0200 Subject: [PATCH 2/5] Formatting --- examples/11.DALI_Squirrel_Integration.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/11.DALI_Squirrel_Integration.py b/examples/11.DALI_Squirrel_Integration.py index 2893945..6ada934 100644 --- a/examples/11.DALI_Squirrel_Integration.py +++ b/examples/11.DALI_Squirrel_Integration.py @@ -1,9 +1,8 @@ from __future__ import annotations import time -import types from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple import cupy import numpy as np @@ -11,21 +10,21 @@ import nvidia.dali.types as types import torch from nvidia.dali import pipeline_def +from nvidia.dali.pipeline import DataNode from nvidia.dali.plugin.pytorch import DALIGenericIterator from squirrel.catalog import Catalog from squirrel.iterstream import Composable -#### INSTALLATION +# INSTALLATION # pip install cupy-cuda110 squirrel-core squirrel-datasets-core # pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110 -#### INSTALLATION END # some metadata for cifar100 BATCH_SIZE = 16 DS_SPLIT = "train" DS = "cifar100" DS_LEN = 50000 -DS_VERSION = 1 # uses Huggingface version of the dataset +DS_VERSION = 1 # uses Huggingface version of the dataset NUM_OUTPUTS = 3 # img, fine_label, coarse_label @@ -98,7 +97,12 @@ def __next__(self) -> List[cupy.ndarray]: @pipeline_def -def pipeline(): +def pipeline() -> Tuple[DataNode]: + """DALI pipeline defining the data processing graph. + + Returns: + Tuple[DataNode]: The outputs of the operators. + """ imgs, fine_labels, coarse_labels = fn.external_source( source=it, num_outputs=NUM_OUTPUTS, From 265490f112010fb8223e17e8dbb7e1a638ca7e40 Mon Sep 17 00:00:00 2001 From: Alexander Koenig Date: Mon, 11 Apr 2022 16:19:06 +0200 Subject: [PATCH 3/5] Fix iterator --- examples/11.DALI_Squirrel_Integration.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/11.DALI_Squirrel_Integration.py b/examples/11.DALI_Squirrel_Integration.py index 6ada934..3a77c25 100644 --- a/examples/11.DALI_Squirrel_Integration.py +++ b/examples/11.DALI_Squirrel_Integration.py @@ -16,11 +16,11 @@ from squirrel.iterstream import Composable # INSTALLATION -# pip install cupy-cuda110 squirrel-core squirrel-datasets-core +# pip install cupy-cuda113 squirrel-core squirrel-datasets-core # pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110 # some metadata for cifar100 -BATCH_SIZE = 16 +BATCH_SIZE = 256 DS_SPLIT = "train" DS = "cifar100" DS_LEN = 50000 @@ -104,7 +104,7 @@ def pipeline() -> Tuple[DataNode]: Tuple[DataNode]: The outputs of the operators. """ imgs, fine_labels, coarse_labels = fn.external_source( - source=it, + source=source, num_outputs=NUM_OUTPUTS, device="gpu", dtype=types.UINT8, @@ -112,7 +112,7 @@ def pipeline() -> Tuple[DataNode]: enhanced = fn.brightness_contrast(imgs, contrast=2) return enhanced, fine_labels, coarse_labels - +print("Building pipeline ...") pipe = pipeline(batch_size=BATCH_SIZE, num_threads=2, device_id=0) pipe.build() From 8e70e723f963bf891ea8f3975e723f70160da00e Mon Sep 17 00:00:00 2001 From: Alexander Koenig Date: Wed, 4 May 2022 15:42:48 +0200 Subject: [PATCH 4/5] Lint --- examples/11.DALI_Squirrel_Integration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/11.DALI_Squirrel_Integration.py b/examples/11.DALI_Squirrel_Integration.py index 3a77c25..b2d53cf 100644 --- a/examples/11.DALI_Squirrel_Integration.py +++ b/examples/11.DALI_Squirrel_Integration.py @@ -112,6 +112,7 @@ def pipeline() -> Tuple[DataNode]: enhanced = fn.brightness_contrast(imgs, contrast=2) return enhanced, fine_labels, coarse_labels + print("Building pipeline ...") pipe = pipeline(batch_size=BATCH_SIZE, num_threads=2, device_id=0) pipe.build() From 18128391b01b7d383af58668123d5a423f93d3a1 Mon Sep 17 00:00:00 2001 From: Winfried Loetzsch Date: Wed, 4 May 2022 16:43:27 +0200 Subject: [PATCH 5/5] Add instructions --- examples/11.DALI_Squirrel_Integration.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/11.DALI_Squirrel_Integration.py b/examples/11.DALI_Squirrel_Integration.py index b2d53cf..0714d33 100644 --- a/examples/11.DALI_Squirrel_Integration.py +++ b/examples/11.DALI_Squirrel_Integration.py @@ -15,9 +15,24 @@ from squirrel.catalog import Catalog from squirrel.iterstream import Composable -# INSTALLATION -# pip install cupy-cuda113 squirrel-core squirrel-datasets-core -# pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110 +""" + INSTRUCTIONS: This tutorial showcases how to use Nvidia DALI in combination with squirrel + by defining an "external source" for a DALI Pipeline + (https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/external_input.html) + DALI is a library that is highly optimised for computations on the GPU. + With this integration we get very fast image augmentations and other image processing that DALI offers, + right out of the box. + + SYSTEM: Code was tested on the following system: + - OS: Ubuntu 20.04.3 LTS + - NVIDIA-SMI 450.119.04, Driver Version: 450.119.04, CUDA Version: 11.0 + - One Tesla T4 GPU + - squirrel-core==0.12.3, squirrel-datasets-core==0.1.2 + + INSTALLATION + pip install cupy-cuda113 squirrel-core squirrel-datasets-core + pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110 + """ # some metadata for cifar100 BATCH_SIZE = 256