From 4c52dee37d5e54463f8dc7afdc0b2db2d1f7127b Mon Sep 17 00:00:00 2001 From: Desroziers Date: Tue, 2 Feb 2021 09:15:10 +0100 Subject: [PATCH] fix for checkpoint --- ignite/handlers/checkpoint.py | 28 ++++++++++-------------- tests/ignite/handlers/test_checkpoint.py | 3 --- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index e95d5f45ddc..016bf943307 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -256,7 +256,7 @@ def score_function(engine): def __init__( self, - to_save: Optional[Mapping], + to_save: Mapping, save_handler: Union[Callable, BaseSaveHandler], filename_prefix: str = "", score_function: Optional[Callable] = None, @@ -268,23 +268,19 @@ def __init__( greater_or_equal: bool = False, ) -> None: - if to_save is not None: # for compatibility with ModelCheckpoint - if not isinstance(to_save, collections.Mapping): - raise TypeError(f"Argument `to_save` should be a dictionary, but given {type(to_save)}") + if not isinstance(to_save, collections.Mapping): + raise TypeError(f"Argument `to_save` should be a dictionary, but given {type(to_save)}") - if len(to_save) < 1: - raise ValueError("No objects to checkpoint.") - - self._check_objects(to_save, "state_dict") + self._check_objects(to_save, "state_dict") - if include_self: - if not isinstance(to_save, collections.MutableMapping): - raise TypeError( - f"If `include_self` is True, then `to_save` must be mutable, but given {type(to_save)}." - ) + if include_self: + if not isinstance(to_save, collections.MutableMapping): + raise TypeError( + f"If `include_self` is True, then `to_save` must be mutable, but given {type(to_save)}." + ) - if "checkpointer" in to_save: - raise ValueError(f"Cannot have key 'checkpointer' if `include_self` is True: {to_save}") + if "checkpointer" in to_save: + raise ValueError(f"Cannot have key 'checkpointer' if `include_self` is True: {to_save}") if not (callable(save_handler) or isinstance(save_handler, BaseSaveHandler)): raise TypeError("Argument `save_handler` should be callable or inherit from BaseSaveHandler") @@ -746,7 +742,7 @@ def __init__( disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs) super(ModelCheckpoint, self).__init__( - to_save=None, + to_save={}, save_handler=disk_saver, filename_prefix=filename_prefix, score_function=score_function, diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index e5cd023ed34..bd36f16b753 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -45,9 +45,6 @@ def test_checkpoint_wrong_input(): with pytest.raises(TypeError, match=r"Argument `to_save` should be a dictionary"): Checkpoint([12], lambda x: x, "prefix") - with pytest.raises(ValueError, match=r"No objects to checkpoint."): - Checkpoint({}, lambda x: x, "prefix") - model = DummyModel() to_save = {"model": model}