Skip to content

Commit

Permalink
test_loggers_pickle_all
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Sep 26, 2024
1 parent 4e37bb8 commit 199fec4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/lightning/fabric/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@

_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")

_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

Expand Down
10 changes: 7 additions & 3 deletions tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import (
Expand Down Expand Up @@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class):
pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.")


def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
"""Verify that pickling trainer with logger works."""
_patch_comet_atexit(monkeypatch)

Expand All @@ -184,7 +184,11 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
trainer = Trainer(max_epochs=1, logger=logger)
pkl_bytes = pickle.dumps(trainer)

with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with (
pytest.warns(FutureWarning, match="`weights_only=False`")
if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
else nullcontext()
):
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})

Expand Down

0 comments on commit 199fec4

Please sign in to comment.