From a11bd49942b6598eac35cc01bd127fe52e63c9ae Mon Sep 17 00:00:00 2001 From: Andrew Ilyas Date: Wed, 26 Jan 2022 07:53:03 -0500 Subject: [PATCH 1/4] update transforms and tests --- ffcv/transforms/random_resized_crop.py | 25 +++++++----- ffcv/transforms/translate.py | 8 +--- tests/test_augmentations.py | 54 ++++++++------------------ 3 files changed, 34 insertions(+), 53 deletions(-) diff --git a/ffcv/transforms/random_resized_crop.py b/ffcv/transforms/random_resized_crop.py index dd38dfe9..24d2403a 100644 --- a/ffcv/transforms/random_resized_crop.py +++ b/ffcv/transforms/random_resized_crop.py @@ -8,9 +8,13 @@ from ..pipeline.allocation_query import AllocationQuery from ..pipeline.operation import Operation from ..pipeline.state import State +from ..pipeline.compiler import Compiler class RandomResizedCrop(Operation): - """Crop a random portion of image with random aspect ratio and resize it to a given size. + """Crop a random portion of image with random aspect ratio and resize it to + a given size. Chances are you do not want to use this augmentation and + instead want to include RRC as part of the decoder, by using the + :class:`~ffcv.fields.rgb_image.ResizedCropRGBImageDecoder` class. 
Parameters ---------- @@ -28,19 +32,20 @@ def __init__(self, scale: Tuple[float, float], ratio: Tuple[float, float], size: self.size = size def generate_code(self) -> Callable: - scale, ratio = self.scale, self.ratio + scale, ratio = np.array(self.scale), np.array(self.ratio) + my_range = Compiler.get_iterator() def random_resized_crop(im, dst): - i, j, h, w = fast_crop.get_random_crop(im.shape[0], - im.shape[1], - scale, - ratio) - fast_crop.resize_crop(im, i, i + h, j, j + w, dst) + n, h, w, _ = im.shape + for ind in my_range(n): + i, j, c_h, c_w = fast_crop.get_random_crop(h, w, scale, ratio) + fast_crop.resize_crop(im[ind], i, i + c_h, j, j + c_w, dst[ind]) return dst - + + random_resized_crop.is_parallel = True return random_resized_crop def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: - assert previous_state.jit_mode - return replace(previous_state, shape=(self.size, self.size, 3)), AllocationQuery((self.size, self.size, 3), dtype=np.dtype('uint8')) + return replace(previous_state, jit_mode=True, shape=(self.size, self.size, 3)), \ + AllocationQuery((self.size, self.size, 3), dtype=previous_state.dtype) diff --git a/ffcv/transforms/translate.py b/ffcv/transforms/translate.py index e53e157e..a40890b6 100644 --- a/ffcv/transforms/translate.py +++ b/ffcv/transforms/translate.py @@ -33,15 +33,11 @@ def generate_code(self) -> Callable: def translate(images, dst): n, h, w, _ = images.shape - # y_coords = randint(low=0, high=2 * pad + 1, size=(n,)) - # x_coords = randint(low=0, high=2 * pad + 1, size=(n,)) - # dst = fill - - dst[:, pad:pad+h, pad:pad+w] = images for i in my_range(n): + dst[i] = 0 + dst[i, pad:pad+h, pad:pad+w] = images[i] y_coord = randint(low=0, high=2 * pad + 1) x_coord = randint(low=0, high=2 * pad + 1) - # images[i] = dst[i, y_coords[i]:y_coords[i]+h, x_coords[i]:x_coords[i]+w] images[i] = dst[i, y_coord:y_coord+h, x_coord:x_coord+w] return images diff --git a/tests/test_augmentations.py 
b/tests/test_augmentations.py index f4a44530..d7a55995 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -29,7 +29,7 @@ ToTorchImage() ] -def run_test(length, pipeline, compile=False): +def run_test(length, pipeline, should_compile=False, aug_name=''): my_dataset = Subset(CIFAR10(root='/tmp', train=True, download=True), range(length)) with NamedTemporaryFile() as handle: @@ -42,7 +42,7 @@ def run_test(length, pipeline, compile=False): writer.from_indexed_dataset(my_dataset, chunksize=10) - Compiler.set_enabled(compile) + Compiler.set_enabled(should_compile) loader = Loader(name, batch_size=7, num_workers=2, pipelines={ 'image': pipeline, @@ -57,18 +57,16 @@ def run_test(length, pipeline, compile=False): tot_indices = 0 tot_images = 0 - for (images, labels), (original_images, original_labels) in zip(loader, unaugmented_loader): - print(images.shape, original_images.shape) + for it_num, ((images, labels), (original_images, original_labels)) in enumerate(zip(loader, unaugmented_loader)): tot_indices += labels.shape[0] tot_images += images.shape[0] for label, original_label in zip(labels, original_labels): assert_that(label).is_equal_to(original_label) - if SAVE_IMAGES: + if SAVE_IMAGES and it_num == 0: save_image(make_grid(ch.concat([images, original_images])/255., images.shape[0]), - os.path.join(IMAGES_TMP_PATH, str(uuid.uuid4()) + '.jpeg') - ) + os.path.join(IMAGES_TMP_PATH, aug_name + '-' + str(uuid.uuid4()) + '.jpeg')) assert_that(tot_indices).is_equal_to(len(my_dataset)) assert_that(tot_images).is_equal_to(len(my_dataset)) @@ -80,7 +78,7 @@ def test_cutout(): Cutout(8), ToTensor(), ToTorchImage() - ], comp) + ], comp, 'cutout') def test_flip(): @@ -90,7 +88,7 @@ def test_flip(): RandomHorizontalFlip(1.0), ToTensor(), ToTorchImage() - ], comp) + ], comp, 'flip') def test_module_wrapper(): @@ -100,7 +98,7 @@ def test_module_wrapper(): ToTensor(), ToTorchImage(), ModuleWrapper(tvt.Grayscale(3)), - ], comp) + ], comp, 'module') def 
test_mixup(): @@ -110,7 +108,7 @@ def test_mixup(): ImageMixup(.5, False), ToTensor(), ToTorchImage() - ], comp) + ], comp, 'mixup') def test_poison(): @@ -125,8 +123,7 @@ def test_poison(): Poison(mask, alpha, list(range(100))), ToTensor(), ToTorchImage() - ], comp) - + ], comp, 'poison') def test_random_resized_crop(): for comp in [True, False]: @@ -137,7 +134,7 @@ def test_random_resized_crop(): size=32), ToTensor(), ToTorchImage() - ], comp) + ], comp, 'rrc') def test_translate(): @@ -147,7 +144,7 @@ def test_translate(): RandomTranslate(padding=10), ToTensor(), ToTorchImage() - ], comp) + ], comp, 'translate') ## Torchvision Transforms @@ -157,7 +154,7 @@ def test_torchvision_greyscale(): ToTensor(), ToTorchImage(), tvt.Grayscale(3), - ]) + ], aug_name='tv_grey') def test_torchvision_centercrop_pad(): run_test(100, [ @@ -166,7 +163,7 @@ def test_torchvision_centercrop_pad(): ToTorchImage(), tvt.CenterCrop(10), tvt.Pad(11) - ]) + ], aug_name='tv_crop_pad') def test_torchvision_random_affine(): run_test(100, [ @@ -174,7 +171,7 @@ def test_torchvision_random_affine(): ToTensor(), ToTorchImage(), tvt.RandomAffine(25), - ]) + ], aug_name='tv_random_affine') def test_torchvision_random_crop(): run_test(100, [ @@ -183,7 +180,7 @@ def test_torchvision_random_crop(): ToTorchImage(), tvt.Pad(10), tvt.RandomCrop(size=32), - ]) + ], aug_name='tv_randcrop') def test_torchvision_color_jitter(): run_test(100, [ @@ -191,21 +188,4 @@ def test_torchvision_color_jitter(): ToTensor(), ToTorchImage(), tvt.ColorJitter(.5, .5, .5, .5), - ]) - - -if __name__ == '__main__': - # test_cutout() - test_flip() - # test_module_wrapper() - # test_mixup() - # test_poison() - # test_random_resized_crop() - # test_translate() - - ## Torchvision Transforms - # test_torchvision_greyscale() - # test_torchvision_centercrop_pad() - # test_torchvision_random_affine() - # test_torchvision_random_crop() - # test_torchvision_color_jitter() + ], aug_name='tv_colorjitter') \ No newline at end of file From 
3487682ba1287ff512c4935e1509947e31267a8e Mon Sep 17 00:00:00 2001 From: Aleksey Date: Tue, 29 Mar 2022 15:10:32 +0400 Subject: [PATCH 2/4] Update rgb_image.py fix typo --- ffcv/fields/rgb_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py index eb1a1511..b6420f11 100644 --- a/ffcv/fields/rgb_image.py +++ b/ffcv/fields/rgb_image.py @@ -97,7 +97,7 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca min_height = heights.min() min_width = widths.min() if min_width != max_width or max_height != min_height: - msg = """SimpleRGBImageDecoder ony supports constant image, + msg = """SimpleRGBImageDecoder only supports constant image, consider RandomResizedCropRGBImageDecoder or CenterCropRGBImageDecoder instead.""" raise TypeError(msg) From e5e95fd7e21c71b754aa772957af9e84cc26bdcb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 1 Sep 2022 14:42:58 +0200 Subject: [PATCH 3/4] increment loading index regardless of CUDA --- ffcv/loader/epoch_iterator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffcv/loader/epoch_iterator.py b/ffcv/loader/epoch_iterator.py index b54fa96b..658ff86c 100644 --- a/ffcv/loader/epoch_iterator.py +++ b/ffcv/loader/epoch_iterator.py @@ -98,7 +98,7 @@ def run(self): event = ch.cuda.Event() event.record(ch.cuda.default_stream()) events[just_finished_slot] = event - b_ix += 1 + b_ix += 1 except StopIteration: self.output_queue.put(None) From af1015aa90d84274c6e4ea2b24056012eeb73b46 Mon Sep 17 00:00:00 2001 From: Andrew Ilyas Date: Mon, 27 Feb 2023 09:36:13 -0500 Subject: [PATCH 4/4] Update main.yml --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bf17fde0..bf3b2076 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,7 @@ jobs: mv docs docs_src cd docs_src pip install -U sphinx 
karma-sphinx-theme - pip install -U numpy==1.20 numba tqdm + pip install -U numpy numba tqdm pip install --upgrade -U pygments make html cp -r _build/html ../docs