diff --git a/CHANGELOG.md b/CHANGELOG.md index acf6730854e..3aba4a6d638 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed the potential for a race condition with `cached_path()` when extracting archives. Although the race condition is still possible if used with `force_extract=True`. - Fixed `wandb` callback to work in distributed training. +- Fixed `tqdm` logging into multiple files with `allennlp-optuna`. ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22 diff --git a/allennlp/common/logging.py b/allennlp/common/logging.py index a278ec2edb4..23a337a0ce1 100644 --- a/allennlp/common/logging.py +++ b/allennlp/common/logging.py @@ -126,4 +126,5 @@ def excepthook(exctype, value, traceback): # also log tqdm from allennlp.common.tqdm import logger as tqdm_logger + tqdm_logger.handlers.clear() tqdm_logger.addHandler(file_handler) diff --git a/tests/common/logging_test.py b/tests/common/logging_test.py index 643eb6ccbdf..c072efb3ecb 100644 --- a/tests/common/logging_test.py +++ b/tests/common/logging_test.py @@ -2,8 +2,9 @@ import logging import random -from allennlp.common.logging import AllenNlpLogger +from allennlp.common.logging import AllenNlpLogger, prepare_global_logging from allennlp.common.testing import AllenNlpTestCase +from allennlp.common.tqdm import Tqdm class TestLogging(AllenNlpTestCase): @@ -64,3 +65,18 @@ def test_getLogger(self): logger = logging.getLogger("test_logger") assert isinstance(logger, AllenNlpLogger) + + def test_reset_tqdm_logger_handlers(self): + serialization_dir_a = os.path.join(self.TEST_DIR, "test_a") + os.makedirs(serialization_dir_a, exist_ok=True) + prepare_global_logging(serialization_dir_a) + serialization_dir_b = os.path.join(self.TEST_DIR, "test_b") + os.makedirs(serialization_dir_b, exist_ok=True) + prepare_global_logging(serialization_dir_b) + # Use range(1) to make sure there should be only 2 lines in the file (0% and 100%) + for _ in Tqdm.tqdm(range(1)): + pass + with open(os.path.join(serialization_dir_a, "out.log"), "r") as f: + assert len(f.readlines()) == 0 + with open(os.path.join(serialization_dir_b, "out.log"), "r") as f: + assert len(f.readlines()) == 2