From 4886cd70da6375ee8565b6bf722cf5a8be9cc8e2 Mon Sep 17 00:00:00 2001 From: Christian Gebbe <> Date: Mon, 12 Aug 2024 13:44:31 +0200 Subject: [PATCH 1/5] test: add failing test using two callbacks --- tests/tests_pytorch/loggers/test_wandb.py | 68 +++++++++++++++++++---- 1 file changed, 56 insertions(+), 12 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index e9195f628348b..64ac916352a0b 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -374,6 +374,44 @@ def test_wandb_log_model(wandb_mock, tmp_path): ) wandb_mock.init().log_artifact.assert_called_with(wandb_mock.Artifact(), aliases=["latest", "best"]) + # Test wandb artifact with two checkpoint_callbacks + wandb_mock.init().log_artifact.reset_mock() + wandb_mock.init.reset_mock() + wandb_mock.Artifact.reset_mock() + logger = WandbLogger(save_dir=tmp_path, log_model=True) + logger.experiment.id = "1" + logger.experiment.name = "run_name" + trainer = Trainer( + default_root_dir=tmp_path, + logger=logger, + max_epochs=3, + limit_train_batches=3, + limit_val_batches=3, + callbacks=[ + ModelCheckpoint(monitor="epoch", save_top_k=2), + ModelCheckpoint(monitor="step", save_top_k=2), + ], + ) + trainer.fit(model) + for name, val, version in [("epoch", 0, 2), ("step", 3, 3)]: + wandb_mock.Artifact.assert_any_call( + name="model-1", + type="model", + metadata={ + "score": val, + "original_filename": f"epoch=0-step=3-v{version}.ckpt", + "ModelCheckpoint": { + "monitor": name, + "mode": "min", + "save_last": None, + "save_top_k": 2, + "save_weights_only": False, + "_every_n_train_steps": 0, + }, + }, + ) + wandb_mock.init().log_artifact.assert_any_call(wandb_mock.Artifact(), aliases=["latest"]) + def test_wandb_log_model_with_score(wandb_mock, tmp_path): """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar tensor.""" @@ -443,10 +481,12 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_image(key="samples", images=["1.jpg", "2.jpg"], step=5) wandb_mock.Image.assert_called_with("2.jpg") - wandb_mock.init().log.assert_called_once_with({ - "samples": [wandb_mock.Image(), wandb_mock.Image()], - "trainer/global_step": 5, - }) + wandb_mock.init().log.assert_called_once_with( + { + "samples": [wandb_mock.Image(), wandb_mock.Image()], + "trainer/global_step": 5, + } + ) # test log_image with captions wandb_mock.init().log.reset_mock() @@ -473,10 +513,12 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_audio(key="samples", audios=["1.mp3", "2.mp3"], step=5) wandb_mock.Audio.assert_called_with("2.mp3") - wandb_mock.init().log.assert_called_once_with({ - "samples": [wandb_mock.Audio(), wandb_mock.Audio()], - "trainer/global_step": 5, - }) + wandb_mock.init().log.assert_called_once_with( + { + "samples": [wandb_mock.Audio(), wandb_mock.Audio()], + "trainer/global_step": 5, + } + ) # test log_audio with captions wandb_mock.init().log.reset_mock() @@ -503,10 +545,12 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_video(key="samples", videos=["1.mp4", "2.mp4"], step=5) wandb_mock.Video.assert_called_with("2.mp4") - wandb_mock.init().log.assert_called_once_with({ - "samples": [wandb_mock.Video(), wandb_mock.Video()], - "trainer/global_step": 5, - }) + wandb_mock.init().log.assert_called_once_with( + { + "samples": [wandb_mock.Video(), wandb_mock.Video()], + "trainer/global_step": 5, + } + ) # test log_video with captions wandb_mock.init().log.reset_mock() From 777ef1b28be451a24d628f07e97376605efaa779 Mon Sep 17 00:00:00 2001 From: Christian Gebbe <> Date: Mon, 12 Aug 2024 13:18:49 +0200 Subject: [PATCH 2/5] fix: save all checkpoint callbacks to wandb --- src/lightning/pytorch/loggers/wandb.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index c5d995bff35a5..ae2a103369591 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -321,7 +321,7 @@ def __init__( self._prefix = prefix self._experiment = experiment self._logged_model_time: Dict[str, float] = {} - self._checkpoint_callback: Optional[ModelCheckpoint] = None + self._checkpoint_callbacks: Dict[str, ModelCheckpoint] = dict() # paths are processed as strings if save_dir is not None: @@ -587,7 +587,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: - self._checkpoint_callback = checkpoint_callback + self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback @staticmethod @rank_zero_only @@ -640,8 +640,9 @@ def finalize(self, status: str) -> None: # Currently, checkpoints only get logged on success return # log checkpoints as artifacts - if self._checkpoint_callback and self._experiment is not None: - self._scan_and_log_checkpoints(self._checkpoint_callback) + if self._experiment is not None: + for checkpoint_callback in self._checkpoint_callbacks.values(): + self._scan_and_log_checkpoints(checkpoint_callback) def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: import wandb From 9847e54952c58921d46574c6cde06d5041a19c13 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 12:06:54 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/loggers/test_wandb.py | 30 +++++++++-------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index 64ac916352a0b..461e302a7ea9b 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -481,12 +481,10 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_image(key="samples", images=["1.jpg", "2.jpg"], step=5) wandb_mock.Image.assert_called_with("2.jpg") - wandb_mock.init().log.assert_called_once_with( - { - "samples": [wandb_mock.Image(), wandb_mock.Image()], - "trainer/global_step": 5, - } - ) + wandb_mock.init().log.assert_called_once_with({ + "samples": [wandb_mock.Image(), wandb_mock.Image()], + "trainer/global_step": 5, + }) # test log_image with captions wandb_mock.init().log.reset_mock() @@ -513,12 +511,10 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_audio(key="samples", audios=["1.mp3", "2.mp3"], step=5) wandb_mock.Audio.assert_called_with("2.mp3") - wandb_mock.init().log.assert_called_once_with( - { - "samples": [wandb_mock.Audio(), wandb_mock.Audio()], - "trainer/global_step": 5, - } - ) + wandb_mock.init().log.assert_called_once_with({ + "samples": [wandb_mock.Audio(), wandb_mock.Audio()], + "trainer/global_step": 5, + }) # test log_audio with captions wandb_mock.init().log.reset_mock() @@ -545,12 +541,10 @@ def test_wandb_log_media(wandb_mock, tmp_path): wandb_mock.init().log.reset_mock() logger.log_video(key="samples", videos=["1.mp4", "2.mp4"], step=5) wandb_mock.Video.assert_called_with("2.mp4") - wandb_mock.init().log.assert_called_once_with( - { - "samples": [wandb_mock.Video(), wandb_mock.Video()], - "trainer/global_step": 5, - } - ) + wandb_mock.init().log.assert_called_once_with({ + "samples": [wandb_mock.Video(), wandb_mock.Video()], + "trainer/global_step": 5, + }) # test log_video with captions wandb_mock.init().log.reset_mock() From 3a3609c1f857f292f7cccb7fd843cd083de20ba1 Mon Sep 17 00:00:00 2001 From: Christian Gebbe <> Date: Mon, 12 Aug 2024 14:12:07 +0200 Subject: [PATCH 4/5] chore: fix mypy --- src/lightning/pytorch/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index ae2a103369591..fb2d452c525e5 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -321,7 +321,7 @@ def __init__( self._prefix = prefix self._experiment = experiment self._logged_model_time: Dict[str, float] = {} - self._checkpoint_callbacks: Dict[str, ModelCheckpoint] = dict() + self._checkpoint_callbacks: Dict[int, ModelCheckpoint] = dict() # paths are processed as strings if save_dir is not None: From 2888744b05577ea3b31c4b210648800f3247c34b Mon Sep 17 00:00:00 2001 From: Christian Gebbe <> Date: Mon, 12 Aug 2024 14:21:39 +0200 Subject: [PATCH 5/5] chore: fix ruff --- src/lightning/pytorch/loggers/wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index fb2d452c525e5..1e3ad994719ca 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -321,7 +321,7 @@ def __init__( self._prefix = prefix self._experiment = experiment self._logged_model_time: Dict[str, float] = {} - self._checkpoint_callbacks: Dict[int, ModelCheckpoint] = dict() + self._checkpoint_callbacks: Dict[int, ModelCheckpoint] = {} # paths are processed as strings if save_dir is not None: