Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

[PROPOSAL] Switch to pytest style test classes, use plain asserts #4204

Merged
merged 10 commits into from
May 7, 2020
41 changes: 10 additions & 31 deletions allennlp/common/testing/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
import pathlib
import shutil
import tempfile
from typing import Any, Iterable
from unittest import TestCase, mock
from unittest import mock

import torch
import pytest

from allennlp.common.checks import log_pytorch_version_info

TEST_DIR = tempfile.mkdtemp(prefix="allennlp_tests")


class AllenNlpTestCase(TestCase):
class AllenNlpTestCase:
"""
A custom subclass of `unittest.TestCase` that disables some of the more verbose AllenNLP
logging and that creates and destroys a temp directory as a test fixture.
Expand All @@ -25,7 +25,7 @@ class AllenNlpTestCase(TestCase):
TESTS_ROOT = MODULE_ROOT / "tests"
FIXTURES_ROOT = TESTS_ROOT / "fixtures"

def setUp(self):
def setup_method(self):
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", level=logging.DEBUG
)
Expand Down Expand Up @@ -54,37 +54,16 @@ def _cleanup_archive_dir_without_logging(path: str):
)
self.mock_cleanup_archive_dir = self.patcher.start()

def tearDown(self):
def teardown_method(self):
shutil.rmtree(self.TEST_DIR)
self.patcher.stop()


def parametrize(arg_names: Iterable[str], arg_values: Iterable[Iterable[Any]]):
"""
Decorator to create parameterized tests.
_available_devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

# Parameters

arg_names : `Iterable[str]`, required.
Argument names to pass to the test function.
arg_values : `Iterable[Iterable[Any]]`, required.
Iterable of values to pass to each of the args.
The decorated test will be run for each inner iterable.
def multi_device(test_method):
"""

def decorator(func):
def wrapper(*args, **kwargs):
for arg_value in arg_values:
kwargs_extra = {name: value for name, value in zip(arg_names, arg_value)}
func(*args, **kwargs, **kwargs_extra)

return wrapper

return decorator


_available_devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
multi_device = parametrize(("device",), [(device,) for device in _available_devices])
"""
Decorator that provides an argument `device` of type `str` for each available PyTorch device.
"""
Decorator that provides an argument `device` of type `str` for each available PyTorch device.
"""
return pytest.mark.parametrize("device", _available_devices)(pytest.mark.gpu(test_method))
10 changes: 3 additions & 7 deletions allennlp/tests/commands/docstring_help_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,16 @@ def test_docstring_help(self):
subcommand = match.group(2)
actual_output = _subcommand_help_output(subcommand)

self.assertEqual(
expected_output,
actual_output,
assert expected_output == actual_output, (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this give me those pretty comparisons, that show immediately where the difference is, even in large data?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sir!

f"The documentation for the subcommand usage"
f" in the module {module_info.name}"
f" does not match the output of running"
f" `{str_call_subcommand_help}`."
f" Please update the docstring to match the"
f" output.",
f" output."
)
else:
self.assertIn(
module_info.name,
[parent_module.__name__ + ".subcommand"],
assert module_info.name in [parent_module.__name__ + ".subcommand"], (
f"The documentation for the subcommand usage was not found within the docstring of"
f" the module {module_info.name}",
)
9 changes: 5 additions & 4 deletions allennlp/tests/commands/evaluate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
from flaky import flaky
import pytest

from allennlp.commands.evaluate import evaluate_from_args, Evaluate, evaluate
from allennlp.common.testing import AllenNlpTestCase
Expand Down Expand Up @@ -32,8 +33,8 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore


class TestEvaluate(AllenNlpTestCase):
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()

self.parser = argparse.ArgumentParser(description="Testing")
subparsers = self.parser.add_subparsers(title="Commands", metavar="")
Expand All @@ -44,7 +45,7 @@ def test_evaluate_calculates_average_loss(self):
outputs = [{"loss": torch.Tensor([loss])} for loss in losses]
data_loader = DummyDataLoader(outputs)
metrics = evaluate(DummyModel(), data_loader, -1, "")
self.assertAlmostEqual(metrics["loss"], 8.0)
assert metrics["loss"] == pytest.approx(8.0)

def test_evaluate_calculates_average_loss_with_weights(self):
losses = [7.0, 9.0, 8.0]
Expand All @@ -56,7 +57,7 @@ def test_evaluate_calculates_average_loss_with_weights(self):
]
data_loader = DummyDataLoader(outputs)
metrics = evaluate(DummyModel(), data_loader, -1, "batch_weight")
self.assertAlmostEqual(metrics["loss"], (70 + 18 + 12) / 13.5)
assert metrics["loss"] == pytest.approx((70 + 18 + 12) / 13.5)

@flaky
def test_evaluate_from_args(self):
Expand Down
12 changes: 6 additions & 6 deletions allennlp/tests/commands/find_learning_rate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def is_matplotlib_installed():


class TestFindLearningRate(AllenNlpTestCase):
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()
self.params = lambda: Params(
{
"model": {
Expand Down Expand Up @@ -122,12 +122,12 @@ def test_find_learning_rate_args(self):
assert args.serialization_dir == "serialization_dir"

# config is required
with self.assertRaises(SystemExit) as cm:
with pytest.raises(SystemExit) as cm:
parser.parse_args(["find-lr", "-s", "serialization_dir"])
assert cm.exception.code == 2 # argparse code for incorrect usage

# serialization dir is required
with self.assertRaises(SystemExit) as cm:
with pytest.raises(SystemExit) as cm:
parser.parse_args(["find-lr", "path/to/params"])
assert cm.exception.code == 2 # argparse code for incorrect usage

Expand All @@ -154,8 +154,8 @@ def test_find_learning_rate_multi_gpu(self):


class TestSearchLearningRate(AllenNlpTestCase):
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()
params = Params(
{
"model": {
Expand Down
8 changes: 4 additions & 4 deletions allennlp/tests/commands/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def test_fails_on_unknown_command(self):
"--silent",
]

with self.assertRaises(SystemExit) as cm:
with pytest.raises(SystemExit) as cm:
main()

assert cm.exception.code == 2 # argparse code for incorrect usage
assert cm.value.code == 2 # argparse code for incorrect usage

def test_subcommand_overrides(self):
called = False
Expand Down Expand Up @@ -134,9 +134,9 @@ def test_file_plugin_loaded(self):
sys.argv = ["allennlp"]

available_plugins = set(discover_plugins())
self.assertSetEqual(set(), available_plugins)
assert available_plugins == set()

with pushd(plugins_root):
main()
subcommands_available = Subcommand.list_available()
self.assertIn("d", subcommands_available)
assert "d" in subcommands_available
10 changes: 5 additions & 5 deletions allennlp/tests/commands/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@


class TestPredict(AllenNlpTestCase):
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()
self.classifier_model_path = (
self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
)
Expand Down Expand Up @@ -215,7 +215,7 @@ class FakeDatasetReader(TextClassificationJsonReader):
"--predictor",
"test-predictor",
]
with self.assertRaises(NotImplementedError):
with pytest.raises(NotImplementedError):
main()

def test_base_predictor(self):
Expand Down Expand Up @@ -296,10 +296,10 @@ def test_fails_without_required_args(self):
"/path/to/archive",
] # executable # command # archive, but no input file

with self.assertRaises(SystemExit) as cm:
with pytest.raises(SystemExit) as cm:
main()

assert cm.exception.code == 2 # argparse code for incorrect usage
assert cm.value.code == 2 # argparse code for incorrect usage

def test_can_specify_predictor(self):
@Predictor.register("classification-explicit")
Expand Down
4 changes: 2 additions & 2 deletions allennlp/tests/commands/print_results_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@


class TestPrintResults(AllenNlpTestCase):
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()

self.out_dir1 = pathlib.Path(tempfile.mkdtemp(prefix="hi"))
self.out_dir2 = pathlib.Path(tempfile.mkdtemp(prefix="hi"))
Expand Down
8 changes: 4 additions & 4 deletions allennlp/tests/commands/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,12 @@ def test_train_args(self):
assert args.serialization_dir == "serialization_dir"

# config is required
with self.assertRaises(SystemExit) as cm:
with pytest.raises(SystemExit) as cm:
args = parser.parse_args(["train", "-s", "serialization_dir"])
assert cm.exception.code == 2 # argparse code for incorrect usage

# serialization dir is required
with self.assertRaises(SystemExit) as cm:
with pytest.raises(SystemExit) as cm:
args = parser.parse_args(["train", "path/to/params"])
assert cm.exception.code == 2 # argparse code for incorrect usage

Expand Down Expand Up @@ -534,8 +534,8 @@ def test_train_nograd_regex(self):


class TestDryRun(AllenNlpTestCase):
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()

self.params = Params(
{
Expand Down
4 changes: 2 additions & 2 deletions allennlp/tests/common/file_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def head_callback(_):


class TestFileUtils(AllenNlpTestCase):
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()
self.glove_file = self.FIXTURES_ROOT / "embeddings/glove.6B.100d.sample.txt.gz"
with open(self.glove_file, "rb") as glove:
self.glove_bytes = glove.read()
Expand Down
4 changes: 2 additions & 2 deletions allennlp/tests/common/logging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


class TestLogging(AllenNlpTestCase):
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()
logger = logging.getLogger(str(random.random()))
self.test_log_file = os.path.join(self.TEST_DIR, "test.log")
logger.addHandler(logging.FileHandler(self.test_log_file))
Expand Down
12 changes: 6 additions & 6 deletions allennlp/tests/common/plugins_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@

class TestPlugins(AllenNlpTestCase):
@overrides
def setUp(self):
super().setUp()
def setup_method(self):
super().setup_method()
self.plugins_root = self.FIXTURES_ROOT / "plugins"

def test_no_plugins(self):
available_plugins = set(discover_plugins())
self.assertSetEqual(set(), available_plugins)
assert available_plugins == set()

def test_file_plugin(self):
available_plugins = set(discover_plugins())
self.assertSetEqual(set(), available_plugins)
assert available_plugins == set()

with pushd(self.plugins_root):
available_plugins = set(discover_plugins())
self.assertSetEqual({"d"}, available_plugins)
assert available_plugins == {"d"}

import_plugins()
subcommands_available = Subcommand.list_available()
self.assertIn("d", subcommands_available)
assert "d" in subcommands_available
18 changes: 7 additions & 11 deletions allennlp/tests/common/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@

from allennlp.common.testing import AllenNlpTestCase, multi_device

actual_devices = set()

class TestTesting(AllenNlpTestCase):
def test_multi_device(self):
actual_devices = set()

@multi_device
def dummy_func(_self, device: str):
# Have `self` as in class test functions.
nonlocal actual_devices
actual_devices.add(device)

dummy_func(self)
class TestTesting(AllenNlpTestCase):
@multi_device
def test_multi_device(self, device: str):
actual_devices.add(device)

def test_devices_accounted_for(self):
expected_devices = {"cpu", "cuda"} if torch.cuda.is_available() else {"cpu"}
self.assertSetEqual(expected_devices, actual_devices)
assert expected_devices == actual_devices
11 changes: 6 additions & 5 deletions allennlp/tests/data/dataset_readers/dataset_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from allennlp.data.dataset_readers.dataset_reader import _LazyInstances


class TestDatasetReader:
cache_directory = str(AllenNlpTestCase.FIXTURES_ROOT / "data_cache" / "with_prefix")
class TestDatasetReader(AllenNlpTestCase):
def setup_method(self):
super().setup_method()
self.cache_directory = str(AllenNlpTestCase.FIXTURES_ROOT / "data_cache" / "with_prefix")

@pytest.fixture(autouse=True)
def cache_directory_fixture(self):
yield self.cache_directory
def teardown_method(self):
super().teardown_method()
if os.path.exists(self.cache_directory):
shutil.rmtree(self.cache_directory)

Expand Down
Loading