Skip to content

Commit

Permalink
Fix/augmentations (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
vturrisi authored Oct 12, 2021
1 parent 09f5f89 commit d784b35
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 10 deletions.
2 changes: 1 addition & 1 deletion bash_files/pretrain/cifar/byol.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ python3 ../../../main_pretrain.py \
--lr 1.0 \
--classifier_lr 0.1 \
--weight_decay 1e-5 \
--batch_size 128 \
--batch_size 256 \
--num_workers 4 \
--brightness 0.4 \
--contrast 0.4 \
Expand Down
4 changes: 4 additions & 0 deletions bash_files/pretrain/custom/byol.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ python3 ../../../main_pretrain.py \
--contrast 0.4 \
--saturation 0.2 \
--hue 0.1 \
--color_jitter_prob 0.8 \
--gray_scale_prob 0.2 \
--horizontal_flip_prob 0.5 \
--gaussian_prob 1.0 0.1 \
--solarization_prob 0.0 0.2 \
--num_crops_per_aug 1 1 \
--name byol-400ep-custom \
--entity unitn-mhug \
--project solo-learn \
Expand Down
9 changes: 8 additions & 1 deletion solo/args/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,20 @@ def augmentations_args(parser: ArgumentParser):
# cropping
parser.add_argument("--num_crops_per_aug", type=int, default=[2], nargs="+")

# augmentations
# color jitter
parser.add_argument("--brightness", type=float, required=True, nargs="+")
parser.add_argument("--contrast", type=float, required=True, nargs="+")
parser.add_argument("--saturation", type=float, required=True, nargs="+")
parser.add_argument("--hue", type=float, required=True, nargs="+")
parser.add_argument("--color_jitter_prob", type=float, default=[0.8], nargs="+")

# other augmentation probabilities
parser.add_argument("--gray_scale_prob", type=float, default=[0.2], nargs="+")
parser.add_argument("--horizontal_flip_prob", type=float, default=[0.5], nargs="+")
parser.add_argument("--gaussian_prob", type=float, default=[0.5], nargs="+")
parser.add_argument("--solarization_prob", type=float, default=[0.0], nargs="+")

# cropping
parser.add_argument("--crop_size", type=int, default=[224], nargs="+")
parser.add_argument("--min_scale", type=float, default=[0.08], nargs="+")
parser.add_argument("--max_scale", type=float, default=[1.0], nargs="+")
Expand Down
18 changes: 18 additions & 0 deletions solo/args/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def additional_setup_pretrain(args: Namespace):
args.contrast,
args.saturation,
args.hue,
args.color_jitter_prob,
args.gray_scale_prob,
args.horizontal_flip_prob,
args.gaussian_prob,
args.solarization_prob,
args.crop_size,
Expand All @@ -86,6 +89,9 @@ def additional_setup_pretrain(args: Namespace):
"contrast",
"saturation",
"hue",
"color_jitter_prob",
"gray_scale_prob",
"horizontal_flip_prob",
"gaussian_prob",
"solarization_prob",
"crop_size",
Expand All @@ -108,6 +114,9 @@ def additional_setup_pretrain(args: Namespace):
contrast=contrast,
saturation=saturation,
hue=hue,
color_jitter_prob=color_jitter_prob,
gray_scale_prob=gray_scale_prob,
horizontal_flip_prob=horizontal_flip_prob,
gaussian_prob=gaussian_prob,
solarization_prob=solarization_prob,
crop_size=crop_size,
Expand All @@ -119,6 +128,9 @@ def additional_setup_pretrain(args: Namespace):
contrast,
saturation,
hue,
color_jitter_prob,
gray_scale_prob,
horizontal_flip_prob,
gaussian_prob,
solarization_prob,
crop_size,
Expand All @@ -129,6 +141,9 @@ def additional_setup_pretrain(args: Namespace):
args.contrast,
args.saturation,
args.hue,
args.color_jitter_prob,
args.gray_scale_prob,
args.horizontal_flip_prob,
args.gaussian_prob,
args.solarization_prob,
args.crop_size,
Expand All @@ -153,6 +168,9 @@ def additional_setup_pretrain(args: Namespace):
contrast=args.contrast[0],
saturation=args.saturation[0],
hue=args.hue[0],
color_jitter_prob=args.color_jitter_prob[0],
gray_scale_prob=args.gray_scale_prob[0],
horizontal_flip_prob=args.horizontal_flip_prob[0],
gaussian_prob=args.gaussian_prob[0],
solarization_prob=args.solarization_prob[0],
crop_size=args.crop_size[0],
Expand Down
4 changes: 2 additions & 2 deletions solo/utils/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def on_train_start(self, trainer: pl.Trainer, _):
self.initial_setup(trainer)
self.save_args(trainer)

def on_validation_end(self, trainer: pl.Trainer, _):
"""Tries to save current checkpoint at the end of each validation epoch.
def on_train_epoch_end(self, trainer: pl.Trainer, _):
"""Tries to save current checkpoint at the end of each train epoch.
Args:
trainer (pl.Trainer): pytorch lightning trainer object.
Expand Down
4 changes: 2 additions & 2 deletions solo/utils/dali_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def __init__(
saturation: float,
hue: float,
color_jitter_prob: float = 0.8,
gray_scale_prob: float = 0.8,
gray_scale_prob: float = 0.2,
horizontal_flip_prob: float = 0.5,
gaussian_prob: float = 0.5,
solarization_prob: float = 0.0,
Expand Down Expand Up @@ -406,7 +406,7 @@ def __init__(
saturation: float,
hue: float,
color_jitter_prob: float = 0.8,
gray_scale_prob: float = 0.8,
gray_scale_prob: float = 0.2,
horizontal_flip_prob: float = 0.5,
gaussian_prob: float = 0.5,
solarization_prob: float = 0.0,
Expand Down
8 changes: 4 additions & 4 deletions solo/utils/pretrain_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
saturation: float,
hue: float,
color_jitter_prob: float = 0.8,
gray_scale_prob: float = 0.8,
gray_scale_prob: float = 0.2,
horizontal_flip_prob: float = 0.5,
gaussian_prob: float = 0.5,
solarization_prob: float = 0.0,
Expand Down Expand Up @@ -238,7 +238,7 @@ def __init__(
saturation: float,
hue: float,
color_jitter_prob: float = 0.8,
gray_scale_prob: float = 0.8,
gray_scale_prob: float = 0.2,
horizontal_flip_prob: float = 0.5,
gaussian_prob: float = 0.5,
solarization_prob: float = 0.0,
Expand Down Expand Up @@ -298,7 +298,7 @@ def __init__(
saturation: float,
hue: float,
color_jitter_prob: float = 0.8,
gray_scale_prob: float = 0.8,
gray_scale_prob: float = 0.2,
horizontal_flip_prob: float = 0.5,
gaussian_prob: float = 0.5,
solarization_prob: float = 0.0,
Expand Down Expand Up @@ -357,7 +357,7 @@ def __init__(
saturation: float,
hue: float,
color_jitter_prob: float = 0.8,
gray_scale_prob: float = 0.8,
gray_scale_prob: float = 0.2,
horizontal_flip_prob: float = 0.5,
gaussian_prob: float = 0.5,
solarization_prob: float = 0.0,
Expand Down
3 changes: 3 additions & 0 deletions tests/args/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def test_argparse_augmentations():
assert "contrast" in actions
assert "saturation" in actions
assert "hue" in actions
assert "color_jitter_prob" in actions
assert "gray_scale_prob" in actions
assert "horizontal_flip_prob" in actions
assert "gaussian_prob" in actions
assert "solarization_prob" in actions
assert "min_scale" in actions
15 changes: 15 additions & 0 deletions tests/args/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ def test_additional_setup_pretrain():
"contrast": [0.4],
"saturation": [0.2],
"hue": [0.1],
"color_jitter_prob": [0.8],
"gray_scale_prob": [0.2],
"horizontal_flip_prob": [0.5],
"gaussian_prob": [1.0, 0.1],
"solarization_prob": [0.2, 0.1],
"min_scale": [0.08],
Expand Down Expand Up @@ -234,6 +237,9 @@ def test_additional_setup_pretrain():
"contrast": [0.4],
"saturation": [0.2],
"hue": [0.1],
"color_jitter_prob": [0.8],
"gray_scale_prob": [0.2],
"horizontal_flip_prob": [0.5],
"gaussian_prob": [0.5],
"solarization_prob": [0.5],
"min_scale": [0.08],
Expand Down Expand Up @@ -264,6 +270,9 @@ def test_additional_setup_pretrain():
"contrast": [0.4],
"saturation": [0.2],
"hue": [0.1],
"color_jitter_prob": [0.8],
"gray_scale_prob": [0.2],
"horizontal_flip_prob": [0.5],
"gaussian_prob": [0.5],
"solarization_prob": [0.5],
"min_scale": [0.08],
Expand Down Expand Up @@ -294,6 +303,9 @@ def test_additional_setup_pretrain():
"contrast": [0.4],
"saturation": [0.2],
"hue": [0.1],
"color_jitter_prob": [0.8],
"gray_scale_prob": [0.2],
"horizontal_flip_prob": [0.5],
"gaussian_prob": [0.5, 0.2],
"solarization_prob": [0.5, 0.3],
"min_scale": [0.08],
Expand Down Expand Up @@ -330,6 +342,9 @@ def test_additional_setup_pretrain():
"contrast": [0.4],
"saturation": [0.2],
"hue": [0.1],
"color_jitter_prob": [0.8],
"gray_scale_prob": [0.2],
"horizontal_flip_prob": [0.5],
"gaussian_prob": [0.5, 0.2],
"solarization_prob": [0.5, 0.3],
"num_crops_per_aug": [1, 1],
Expand Down

0 comments on commit d784b35

Please sign in to comment.