From e8e2303dcbc71b63e04b7dd21a1a83a956770871 Mon Sep 17 00:00:00 2001 From: Victor Turrisi Date: Tue, 25 Jan 2022 14:08:49 +0100 Subject: [PATCH] small compatibility fixes (#220) --- solo/args/utils.py | 2 -- tests/utils/test_checkpointer.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/solo/args/utils.py b/solo/args/utils.py index 055c828b..d4a15740 100644 --- a/solo/args/utils.py +++ b/solo/args/utils.py @@ -251,12 +251,10 @@ def additional_setup_linear(args: Namespace): # create backbone-specific arguments args.backbone_args = {"cifar": args.dataset in ["cifar10", "cifar100"]} - if "resnet" not in args.backbone: # dataset related for all transformers crop_size = args.crop_size[0] args.backbone_args["img_size"] = crop_size - if "vit" in args.backbone: args.backbone_args["patch_size"] = args.patch_size diff --git a/tests/utils/test_checkpointer.py b/tests/utils/test_checkpointer.py index 6eaeb604..d2fe574a 100644 --- a/tests/utils/test_checkpointer.py +++ b/tests/utils/test_checkpointer.py @@ -91,7 +91,8 @@ def test_checkpointer(): "optimizer_states", "lr_schedulers", ] - assert list(ckpt.keys()) == expected_keys + ckpt_keys = list(ckpt.keys()) + assert all(k in ckpt_keys for k in expected_keys) parser = argparse.ArgumentParser() ckpt_callback.add_checkpointer_args(parser)