From 389798283abcaefcaec57629a3b9a8a9449ea8c4 Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Mon, 20 Jun 2016 14:45:13 +0200 Subject: [PATCH] conftest files now use assertion rewriting Fix #1619 --- _pytest/assertion/__init__.py | 55 ++++++++++++++++++++++++--------- _pytest/assertion/rewrite.py | 57 +++++++++++++++++++++-------------- testing/test_assertrewrite.py | 34 +++++++++++++++++++++ testing/test_config.py | 11 +++++-- 4 files changed, 117 insertions(+), 40 deletions(-) diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 6921deb2a60..110e2eced23 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -4,6 +4,8 @@ import py import os import sys + +from _pytest.config import hookimpl from _pytest.monkeypatch import monkeypatch from _pytest.assertion import util @@ -42,9 +44,18 @@ def __init__(self, config, mode): self.trace = config.trace.root.get("assertion") -def pytest_configure(config): - mode = config.getvalue("assertmode") - if config.getvalue("noassert") or config.getvalue("nomagic"): +@hookimpl(tryfirst=True) +def pytest_load_initial_conftests(early_config, parser, args): +# def pytest_configure(config): + ns, ns_unknown_args = parser.parse_known_and_unknown_args(args) + mode = ns.assertmode + no_assert = ns.noassert + no_magic = ns.nomagic + # early_config = config + # no_assert = config.getvalue('noassert') + # no_magic = config.getvalue('nomagic') + # mode = config.getvalue('assertmode') + if no_assert or no_magic: mode = "plain" if mode == "rewrite": try: @@ -57,25 +68,30 @@ def pytest_configure(config): if (sys.platform.startswith('java') or sys.version_info[:3] == (2, 6, 0)): mode = "reinterp" + + early_config._assertstate = AssertionState(early_config, mode) + warn_about_missing_assertion(mode, early_config.pluginmanager) + if mode != "plain": _load_modules(mode) m = monkeypatch() - config._cleanup.append(m.undo) + early_config._cleanup.append(m.undo) m.setattr(py.builtin.builtins, 'AssertionError', reinterpret.AssertionError) # noqa + hook = None if mode == "rewrite": hook = rewrite.AssertionRewritingHook() # noqa + hook.set_config(early_config) sys.meta_path.insert(0, hook) - warn_about_missing_assertion(mode) - config._assertstate = AssertionState(config, mode) - config._assertstate.hook = hook - config._assertstate.trace("configured with mode set to %r" % (mode,)) + + early_config._assertstate.hook = hook + early_config._assertstate.trace("configured with mode set to %r" % (mode,)) def undo(): - hook = config._assertstate.hook + hook = early_config._assertstate.hook if hook is not None and hook in sys.meta_path: sys.meta_path.remove(hook) - config.add_cleanup(undo) + early_config.add_cleanup(undo) def pytest_collection(session): @@ -154,7 +170,8 @@ def _load_modules(mode): from _pytest.assertion import rewrite # noqa -def warn_about_missing_assertion(mode): +def warn_about_missing_assertion(mode, pluginmanager): + print('got here') try: assert False except AssertionError: @@ -166,10 +183,18 @@ def warn_about_missing_assertion(mode): else: specifically = "failing tests may report as passing" - sys.stderr.write("WARNING: " + specifically + - " because assert statements are not executed " - "by the underlying Python interpreter " - "(are you using python -O?)\n") + # temporarily disable capture so we can print our warning + capman = pluginmanager.getplugin('capturemanager') + try: + out, err = capman.suspendcapture() + sys.stderr.write("WARNING: " + specifically + + " because assert statements are not executed " + "by the underlying Python interpreter " + "(are you using python -O?)\n") + finally: + capman.resumecapture() + sys.stdout.write(out) + sys.stderr.write(err) # Expose this plugin's implementation for the pytest_assertrepr_compare hook diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 14b8e49db2b..efddc89203e 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -50,14 +50,14 @@ def __init__(self): self._register_with_pkg_resources() def set_session(self, session): - self.fnpats = session.config.getini("python_files") self.session = session + def set_config(self, config): + self.config = config + self.fnpats = config.getini("python_files") + def find_module(self, name, path=None): - if self.session is None: - return None - sess = self.session - state = sess.config._assertstate + state = self.config._assertstate state.trace("find_module called for: %s" % name) names = name.rsplit(".", 1) lastname = names[-1] @@ -86,24 +86,11 @@ def find_module(self, name, path=None): return None else: fn = os.path.join(pth, name.rpartition(".")[2] + ".py") + fn_pypath = py.path.local(fn) - # Is this a test file? - if not sess.isinitpath(fn): - # We have to be very careful here because imports in this code can - # trigger a cycle. - self.session = None - try: - for pat in self.fnpats: - if fn_pypath.fnmatch(pat): - state.trace("matched test file %r" % (fn,)) - break - else: - return None - finally: - self.session = sess - else: - state.trace("matched test file (was specified on cmdline): %r" % - (fn,)) + if not self._should_rewrite(fn_pypath, state): + return + # The requested module looks like a test file, so rewrite it. This is # the most magical part of the process: load the source, rewrite the # asserts, and load the rewritten source. We also cache the rewritten @@ -151,6 +138,32 @@ def find_module(self, name, path=None): self.modules[name] = co, pyc return self + def _should_rewrite(self, fn_pypath, state): + # always rewrite conftest files + fn = str(fn_pypath) + if fn_pypath.basename == 'conftest.py': + state.trace("rewriting conftest file: %r" % (fn,)) + return True + elif self.session is not None: + if self.session.isinitpath(fn): + state.trace("matched test file (was specified on cmdline): %r" % + (fn,)) + return True + else: + # modules not passed explicitly on the command line are only + # rewritten if they match the naming convention for test files + session = self.session # avoid a cycle here + self.session = None + try: + for pat in self.fnpats: + if fn_pypath.fnmatch(pat): + state.trace("matched test file %r" % (fn,)) + return True + finally: + self.session = session + del session + return False + def load_module(self, name): # If there is an existing module object named 'fullname' in # sys.modules, the loader must use that existing module. (Otherwise, diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index f43c424ca94..8d16bfc66d1 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -694,6 +694,40 @@ def test_foo(self): result = testdir.runpytest() result.stdout.fnmatch_lines('*1 passed*') + @pytest.mark.parametrize('initial_conftest', [True, False]) + @pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp']) + def test_conftest_assertion_rewrite(self, testdir, initial_conftest, mode): + """Test that conftest files are using assertion rewrite on import. + (#1619) + """ + testdir.tmpdir.join('foo/tests').ensure(dir=1) + conftest_path = 'conftest.py' if initial_conftest else 'foo/conftest.py' + contents = { + conftest_path: """ + import pytest + @pytest.fixture + def check_first(): + def check(values, value): + assert values.pop(0) == value + return check + """, + 'foo/tests/test_foo.py': """ + def test(check_first): + check_first([10, 30], 30) + """ + } + testdir.makepyfile(**contents) + result = testdir.runpytest_subprocess('--assert=%s' % mode) + if mode == 'plain': + expected = 'E AssertionError' + elif mode == 'rewrite': + expected = '*assert 10 == 30*' + elif mode == 'reinterp': + expected = '*AssertionError:*was re-run*' + else: + assert 0 + result.stdout.fnmatch_lines([expected]) + def test_issue731(testdir): testdir.makepyfile(""" diff --git a/testing/test_config.py b/testing/test_config.py index fe06540173e..69bea4a9c88 100644 --- a/testing/test_config.py +++ b/testing/test_config.py @@ -485,9 +485,14 @@ def pytest_load_initial_conftests(self): pm.register(m) hc = pm.hook.pytest_load_initial_conftests l = hc._nonwrappers + hc._wrappers - assert l[-1].function.__module__ == "_pytest.capture" - assert l[-2].function == m.pytest_load_initial_conftests - assert l[-3].function.__module__ == "_pytest.config" + expected = [ + "_pytest.config", + 'test_config', + '_pytest.assertion', + '_pytest.capture', + ] + assert [x.function.__module__ for x in l] == expected + class TestWarning: def test_warn_config(self, testdir):