diff --git a/conftest.py b/conftest.py
index f591fe970de8..28859fd4a17b 100644
--- a/conftest.py
+++ b/conftest.py
@@ -14,36 +14,5 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
-from pytest import ExitCode
-import tvm
-import tvm.testing
-
-
-def pytest_configure(config):
-    print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets())))
-    print("pytest marker:", config.option.markexpr)
-
-
-@pytest.fixture
-def dev(target):
-    return tvm.device(target)
-
-
-def pytest_generate_tests(metafunc):
-    tvm.testing._auto_parametrize_target(metafunc)
-    tvm.testing._parametrize_correlated_parameters(metafunc)
-
-
-def pytest_collection_modifyitems(config, items):
-    tvm.testing._count_num_fixture_uses(items)
-    tvm.testing._remove_global_fixture_definitions(items)
-
-
-def pytest_sessionfinish(session, exitstatus):
-    # Don't exit with an error if we select a subset of tests that doesn't
-    # include anything
-    if session.config.option.markexpr != "":
-        if exitstatus == ExitCode.NO_TESTS_COLLECTED:
-            session.exitstatus = ExitCode.OK
+pytest_plugins = ["tvm.testing.plugin"]
diff --git a/pytest.ini b/pytest.ini
deleted file mode 100644
index 675f8fe9b5a0..000000000000
--- a/pytest.ini
+++ /dev/null
@@ -1,26 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-[pytest]
-markers =
-    gpu: mark a test as requiring a gpu
-    tensorcore: mark a test as requiring a tensorcore
-    cuda: mark a test as requiring cuda
-    opencl: mark a test as requiring opencl
-    rocm: mark a test as requiring rocm
-    vulkan: mark a test as requiring vulkan
-    metal: mark a test as requiring metal
-    llvm: mark a test as requiring llvm
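Note: the `markers` section deleted from pytest.ini above is re-registered
programmatically by the new plugin's pytest_configure() hook (see plugin.py
below). For readers unfamiliar with that pytest API, a minimal standalone
sketch of the same pattern, not part of this patch:

    # conftest.py -- hypothetical standalone example
    def pytest_configure(config):
        # Equivalent to one line under "markers =" in pytest.ini
        config.addinivalue_line("markers", "gpu: mark a test as requiring a gpu")

Registering markers in code lets the plugin carry its marker definitions with
it, instead of requiring every consuming project to copy the ini section.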
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index bd1ada4fa284..268c86e888e4 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=redefined-builtin, wildcard-import
 """Utility Python functions for TVM testing"""
 from .utils import assert_allclose, assert_prim_expr_equal, check_bool_expr_is_true
@@ -23,9 +24,7 @@
 from .utils import known_failing_targets, requires_cuda, requires_cudagraph
 from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc
 from .utils import requires_tensorcore, requires_metal, requires_micro, requires_opencl
-from .utils import _auto_parametrize_target, _count_num_fixture_uses
-from .utils import _remove_global_fixture_definitions, _parametrize_correlated_parameters
-from .utils import _pytest_target_params, identity_after, terminate_self
+from .utils import identity_after, terminate_self
 from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count
 from ._ffi_api import test_wrap_callback, test_raise_error_callback, test_check_eq_callback
diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py
new file mode 100644
index 000000000000..06b4fa4f65eb
--- /dev/null
+++ b/python/tvm/testing/plugin.py
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pytest plugin for using tvm testing extensions.
+
+TVM provides utilities for testing across all supported targets, and
+for parametrizing tests across many inputs.  For more information on
+using these features, see the documentation in the tvm.testing
+module.
+
+These utilities are enabled by default for all tests in the TVM
+repository, but may also be used in external projects.  To enable
+them, add the following line to the test script, or to the
+conftest.py in the same directory as the test scripts:
+
+    pytest_plugins = ['tvm.testing.plugin']
+
+"""
+
+import collections
+
+import pytest
+import _pytest
+
+import tvm
+from tvm.testing import utils
+
+
+MARKERS = {
+    "gpu": "mark a test as requiring a gpu",
+    "tensorcore": "mark a test as requiring a tensorcore",
+    "cuda": "mark a test as requiring cuda",
+    "opencl": "mark a test as requiring opencl",
+    "rocm": "mark a test as requiring rocm",
+    "vulkan": "mark a test as requiring vulkan",
+    "metal": "mark a test as requiring metal",
+    "llvm": "mark a test as requiring llvm",
+}
+
+
+def pytest_configure(config):
+    """Runs at pytest configure time, defines marks to be used later."""
+
+    for markername, desc in MARKERS.items():
+        config.addinivalue_line("markers", "{}: {}".format(markername, desc))
+
+    print("enabled targets:", "; ".join(map(lambda x: x[0], utils.enabled_targets())))
+    print("pytest marker:", config.option.markexpr)
+
+
+def pytest_generate_tests(metafunc):
+    """Called once per unit test, modifies/parametrizes it as needed."""
+    _parametrize_correlated_parameters(metafunc)
+    _auto_parametrize_target(metafunc)
+
+
+def pytest_collection_modifyitems(config, items):
+    """Called after all tests are collected; currently used for bookkeeping."""
+    # pylint: disable=unused-argument
+    _count_num_fixture_uses(items)
+    _remove_global_fixture_definitions(items)
+
+
+@pytest.fixture
+def dev(target):
+    """Gives tests access to the device corresponding to the current target."""
+    return tvm.device(target)
+
+
+def pytest_sessionfinish(session, exitstatus):
+    # Don't exit with an error if we select a subset of tests that doesn't
+    # include anything.
+    if session.config.option.markexpr != "":
+        if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
+            session.exitstatus = pytest.ExitCode.OK
+
+
+def _auto_parametrize_target(metafunc):
+    """Automatically applies parametrize_targets.
+
+    Used if a test function uses the "target" fixture, but isn't
+    already marked with @tvm.testing.parametrize_targets.  Intended
+    for use in the pytest_generate_tests() handler of a conftest.py
+    file.
+
+    """
+
+    def update_parametrize_target_arg(
+        argnames,
+        argvalues,
+        *args,
+        **kwargs,
+    ):
+        args = [arg.strip() for arg in argnames.split(",") if arg.strip()]
+        if "target" in args:
+            target_i = args.index("target")
+
+            new_argvalues = []
+            for argvalue in argvalues:
+
+                if isinstance(argvalue, _pytest.mark.structures.ParameterSet):
+                    # The parametrized value is already a
+                    # pytest.param, so track any marks already
+                    # defined.
+                    param_set = argvalue.values
+                    target = param_set[target_i]
+                    additional_marks = argvalue.marks
+                elif len(args) == 1:
+                    # Single-value parametrization, argvalue is a list of values.
+                    target = argvalue
+                    param_set = (target,)
+                    additional_marks = []
+                else:
+                    # Multiple correlated parameters, argvalue is a list of tuples of values.
+                    param_set = argvalue
+                    target = param_set[target_i]
+                    additional_marks = []
+
+                new_argvalues.append(
+                    pytest.param(
+                        *param_set, marks=_target_to_requirement(target) + additional_marks
+                    )
+                )
+
+            try:
+                argvalues[:] = new_argvalues
+            except TypeError as err:
+                pyfunc = metafunc.definition.function
+                filename = pyfunc.__code__.co_filename
+                line_number = pyfunc.__code__.co_firstlineno
+                msg = (
+                    f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) "
+                    "is parametrized using a tuple of parameters instead of a list "
+                    "of parameters."
+                )
+                raise TypeError(msg) from err
+
+    if "target" in metafunc.fixturenames:
+        # Update any explicit use of @pytest.mark.parametrize to
+        # parametrize over targets.  This adds the appropriate
+        # @tvm.testing.requires_* markers for each target.
+        for mark in metafunc.definition.iter_markers("parametrize"):
+            update_parametrize_target_arg(*mark.args, **mark.kwargs)
+
+        # Check if any explicit parametrizations exist, and apply one
+        # if they do not.  If the function is marked with either
+        # excluded or known failing targets, use these to determine
+        # the targets to be used.
+        parametrized_args = [
+            arg.strip()
+            for mark in metafunc.definition.iter_markers("parametrize")
+            for arg in mark.args[0].split(",")
+        ]
+        if "target" not in parametrized_args:
+            excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", [])
+            xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", [])
+            metafunc.parametrize(
+                "target",
+                _pytest_target_params(None, excluded_targets, xfail_targets),
+                scope="session",
+            )
+
+
+def _count_num_fixture_uses(items):
+    # Helper function, counts the number of tests that use each cached
+    # fixture.  Should be called from pytest_collection_modifyitems().
+    for item in items:
+        is_skipped = item.get_closest_marker("skip") or any(
+            mark.args[0] for mark in item.iter_markers("skipif")
+        )
+        if is_skipped:
+            continue
+
+        for fixturedefs in item._fixtureinfo.name2fixturedefs.values():
+            # Only increment the active fixturedef, if a name has been overridden.
+            fixturedef = fixturedefs[-1]
+            if hasattr(fixturedef.func, "num_tests_use_this_fixture"):
+                fixturedef.func.num_tests_use_this_fixture[0] += 1
+
+
+def _remove_global_fixture_definitions(items):
+    # Helper function, removes fixture definitions from the global
+    # variables of the modules they were defined in.  This is intended
+    # to improve readability of error messages by giving a NameError
+    # if a test function accesses a pytest fixture but doesn't include
+    # it as an argument.  Should be called from
+    # pytest_collection_modifyitems().
+
+    modules = set(item.module for item in items)
+
+    for module in modules:
+        for name in dir(module):
+            obj = getattr(module, name)
+            if hasattr(obj, "_pytestfixturefunction") and isinstance(
+                obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker
+            ):
+                delattr(module, name)
+
+
+def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None):
+    # Include unrunnable targets here.  They get skipped by the
+    # pytest.mark.skipif in _target_to_requirement(), showing up as
+    # skipped tests instead of being hidden entirely.
+    if targets is None:
+        if excluded_targets is None:
+            excluded_targets = set()
+
+        if xfail_targets is None:
+            xfail_targets = set()
+
+        target_marks = []
+        for t in utils._get_targets():
+            # Excluded targets aren't included in the params at all.
+            if t["target_kind"] not in excluded_targets:
+
+                # Known failing targets are included, but are marked
+                # as expected to fail.
+                extra_marks = []
+                if t["target_kind"] in xfail_targets:
+                    extra_marks.append(
+                        pytest.mark.xfail(
+                            reason='Known failing test for target "{}"'.format(t["target_kind"])
+                        )
+                    )
+
+                target_marks.append((t["target"], extra_marks))
+
+    else:
+        target_marks = [(target, []) for target in targets]
+
+    return [
+        pytest.param(target, marks=_target_to_requirement(target) + extra_marks)
+        for target, extra_marks in target_marks
+    ]
+
+
+def _target_to_requirement(target):
+    if isinstance(target, str):
+        target = tvm.target.Target(target)
+
+    # mapping from target to decorator
+    if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []):
+        return utils.requires_cudnn()
+    if target.kind.name == "cuda":
+        return utils.requires_cuda()
+    if target.kind.name == "rocm":
+        return utils.requires_rocm()
+    if target.kind.name == "vulkan":
+        return utils.requires_vulkan()
+    if target.kind.name == "nvptx":
+        return utils.requires_nvptx()
+    if target.kind.name == "metal":
+        return utils.requires_metal()
+    if target.kind.name == "opencl":
+        return utils.requires_opencl()
+    if target.kind.name == "llvm":
+        return utils.requires_llvm()
+    return []
+
+
+def _parametrize_correlated_parameters(metafunc):
+    parametrize_needed = collections.defaultdict(list)
+
+    for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items():
+        fixturedef = fixturedefs[-1]
+        if hasattr(fixturedef.func, "parametrize_group") and hasattr(
+            fixturedef.func, "parametrize_values"
+        ):
+            group = fixturedef.func.parametrize_group
+            values = fixturedef.func.parametrize_values
+            parametrize_needed[group].append((name, values))
+
+    for parametrize_group in parametrize_needed.values():
+        if len(parametrize_group) == 1:
+            name, values = parametrize_group[0]
+            metafunc.parametrize(name, values, indirect=True)
+        else:
+            names = ",".join(name for name, values in parametrize_group)
+            value_sets = zip(*[values for name, values in parametrize_group])
+            metafunc.parametrize(names, value_sets, indirect=True)
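With the plugin in place, any test that accepts a `target` argument is
automatically parametrized over the enabled targets by
_auto_parametrize_target(), with the matching requires_* marks attached. A
minimal sketch of a test relying on this machinery (a hypothetical test
file, not part of this patch):

    import numpy as np
    import tvm
    import tvm.testing

    # "target" and "dev" are supplied by the plugin; the test runs once per
    # enabled target and is skipped where the target is unavailable.
    def test_roundtrip_through_device(target, dev):
        x = np.arange(10, dtype="float32")
        y = tvm.nd.array(x, device=dev).numpy()
        np.testing.assert_allclose(x, y)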
+ """ -import collections import copy import copyreg import ctypes @@ -65,7 +72,6 @@ def test_something(): import time import pickle import pytest -import _pytest import numpy as np import tvm import tvm.arith @@ -768,153 +774,6 @@ def requires_rpc(*args): return _compose(args, _requires_rpc) -def _target_to_requirement(target): - if isinstance(target, str): - target = tvm.target.Target(target) - - # mapping from target to decorator - if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []): - return requires_cudnn() - if target.kind.name == "cuda": - return requires_cuda() - if target.kind.name == "rocm": - return requires_rocm() - if target.kind.name == "vulkan": - return requires_vulkan() - if target.kind.name == "nvptx": - return requires_nvptx() - if target.kind.name == "metal": - return requires_metal() - if target.kind.name == "opencl": - return requires_opencl() - if target.kind.name == "llvm": - return requires_llvm() - return [] - - -def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None): - # Include unrunnable targets here. They get skipped by the - # pytest.mark.skipif in _target_to_requirement(), showing up as - # skipped tests instead of being hidden entirely. - if targets is None: - if excluded_targets is None: - excluded_targets = set() - - if xfail_targets is None: - xfail_targets = set() - - target_marks = [] - for t in _get_targets(): - # Excluded targets aren't included in the params at all. - if t["target_kind"] not in excluded_targets: - - # Known failing targets are included, but are marked - # as expected to fail. - extra_marks = [] - if t["target_kind"] in xfail_targets: - extra_marks.append( - pytest.mark.xfail( - reason='Known failing test for target "{}"'.format(t["target_kind"]) - ) - ) - - target_marks.append((t["target"], extra_marks)) - - else: - target_marks = [(target, []) for target in targets] - - return [ - pytest.param(target, marks=_target_to_requirement(target) + extra_marks) - for target, extra_marks in target_marks - ] - - -def _auto_parametrize_target(metafunc): - """Automatically applies parametrize_targets - - Used if a test function uses the "target" fixture, but isn't - already marked with @tvm.testing.parametrize_targets. Intended - for use in the pytest_generate_tests() handler of a conftest.py - file. - - """ - - def update_parametrize_target_arg( - argnames, - argvalues, - *args, - **kwargs, - ): - args = [arg.strip() for arg in argnames.split(",") if arg.strip()] - if "target" in args: - target_i = args.index("target") - - new_argvalues = [] - for argvalue in argvalues: - - if isinstance(argvalue, _pytest.mark.structures.ParameterSet): - # The parametrized value is already a - # pytest.param, so track any marks already - # defined. - param_set = argvalue.values - target = param_set[target_i] - additional_marks = argvalue.marks - elif len(args) == 1: - # Single value parametrization, argvalue is a list of values. - target = argvalue - param_set = (target,) - additional_marks = [] - else: - # Multiple correlated parameters, argvalue is a list of tuple of values. 
-                    param_set = argvalue
-                    target = param_set[target_i]
-                    additional_marks = []
-
-                new_argvalues.append(
-                    pytest.param(
-                        *param_set, marks=_target_to_requirement(target) + additional_marks
-                    )
-                )
-
-            try:
-                argvalues[:] = new_argvalues
-            except TypeError as e:
-                pyfunc = metafunc.definition.function
-                filename = pyfunc.__code__.co_filename
-                line_number = pyfunc.__code__.co_firstlineno
-                msg = (
-                    f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) "
-                    "is parametrized using a tuple of parameters instead of a list "
-                    "of parameters."
-                )
-                raise TypeError(msg) from e
-
-    if "target" in metafunc.fixturenames:
-        # Update any explicit use of @pytest.mark.parmaetrize to
-        # parametrize over targets.  This adds the appropriate
-        # @tvm.testing.requires_* markers for each target.
-        for mark in metafunc.definition.iter_markers("parametrize"):
-            update_parametrize_target_arg(*mark.args, **mark.kwargs)
-
-        # Check if any explicit parametrizations exist, and apply one
-        # if they do not.  If the function is marked with either
-        # excluded or known failing targets, use these to determine
-        # the targets to be used.
-        parametrized_args = [
-            arg.strip()
-            for mark in metafunc.definition.iter_markers("parametrize")
-            for arg in mark.args[0].split(",")
-        ]
-        if "target" not in parametrized_args:
-            excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", [])
-            xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", [])
-            metafunc.parametrize(
-                "target",
-                _pytest_target_params(None, excluded_targets, xfail_targets),
-                scope="session",
-            )
-
-
 def parametrize_targets(*args):
     """Parametrize a test over a specific set of targets.
@@ -1164,28 +1023,6 @@ def fixture_func(*_cls, request):
     return outputs
 
 
-def _parametrize_correlated_parameters(metafunc):
-    parametrize_needed = collections.defaultdict(list)
-
-    for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items():
-        fixturedef = fixturedefs[-1]
-        if hasattr(fixturedef.func, "parametrize_group") and hasattr(
-            fixturedef.func, "parametrize_values"
-        ):
-            group = fixturedef.func.parametrize_group
-            values = fixturedef.func.parametrize_values
-            parametrize_needed[group].append((name, values))
-
-    for parametrize_group in parametrize_needed.values():
-        if len(parametrize_group) == 1:
-            name, values = parametrize_group[0]
-            metafunc.parametrize(name, values, indirect=True)
-        else:
-            names = ",".join(name for name, values in parametrize_group)
-            value_sets = zip(*[values for name, values in parametrize_group])
-            metafunc.parametrize(names, value_sets, indirect=True)
-
-
 def fixture(func=None, *, cache_return_value=False):
     """Convenience function to define pytest fixtures.
@@ -1319,7 +1156,9 @@ def _fixture_cache(func):
     # Can't use += on a bound method's property.  Therefore, this is a
     # list rather than a variable so that it can be accessed from the
     # pytest_collection_modifyitems().
-    num_uses_remaining = [0]
+    num_tests_use_this_fixture = [0]
+
+    num_times_fixture_used = 0
 
     # Using functools.lru_cache would require the function arguments
     # to be hashable, which wouldn't allow caching fixtures that
@@ -1344,6 +1183,14 @@ def get_cache_key(*args, **kwargs):
 
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
+        if num_tests_use_this_fixture[0] == 0:
+            raise RuntimeError(
+                "Fixture use count is 0.  "
+                "This can occur if tvm.testing.plugin isn't registered.  "
+                "If this fixture is used outside of the TVM test directory, "
+                "please add `pytest_plugins = ['tvm.testing.plugin']` to your conftest.py."
+            )
+
         try:
             cache_key = get_cache_key(*args, **kwargs)
 
@@ -1364,52 +1211,17 @@ def wrapper(*args, **kwargs):
         finally:
             # Clear the cache once all tests that use a particular fixture
             # have completed.
-            num_uses_remaining[0] -= 1
-            if not num_uses_remaining[0]:
+            nonlocal num_times_fixture_used
+            num_times_fixture_used += 1
+            if num_times_fixture_used >= num_tests_use_this_fixture[0]:
                 cache.clear()
 
-    # Set in the pytest_collection_modifyitems()
-    wrapper.num_uses_remaining = num_uses_remaining
+    # Set in the pytest_collection_modifyitems(), by _count_num_fixture_uses
+    wrapper.num_tests_use_this_fixture = num_tests_use_this_fixture
 
     return wrapper
 
 
-def _count_num_fixture_uses(items):
-    # Helper function, counts the number of tests that use each cached
-    # fixture.  Should be called from pytest_collection_modifyitems().
-    for item in items:
-        is_skipped = item.get_closest_marker("skip") or any(
-            mark.args[0] for mark in item.iter_markers("skipif")
-        )
-        if is_skipped:
-            continue
-
-        for fixturedefs in item._fixtureinfo.name2fixturedefs.values():
-            # Only increment the active fixturedef, in a name has been overridden.
-            fixturedef = fixturedefs[-1]
-            if hasattr(fixturedef.func, "num_uses_remaining"):
-                fixturedef.func.num_uses_remaining[0] += 1
-
-
-def _remove_global_fixture_definitions(items):
-    # Helper function, removes fixture definitions from the global
-    # variables of the modules they were defined in.  This is intended
-    # to improve readability of error messages by giving a NameError
-    # if a test function accesses a pytest fixture but doesn't include
-    # it as an argument.  Should be called from
-    # pytest_collection_modifyitems().
-
-    modules = set(item.module for item in items)
-
-    for module in modules:
-        for name in dir(module):
-            obj = getattr(module, name)
-            if hasattr(obj, "_pytestfixturefunction") and isinstance(
-                obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker
-            ):
-                delattr(module, name)
-
-
 def identity_after(x, sleep):
     """Testing function to return identity after sleep
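A note on the counter handling above: the count lives in a one-element list
rather than a plain integer so that the plugin, which only holds a reference
to the wrapper function, can mutate the value in place; rebinding an
attribute to a new int would not be visible inside the closure. A minimal
sketch of the pattern (illustrative only, not part of this patch):

    def make_wrapper():
        num_tests_use_this_fixture = [0]  # mutable cell, shared by reference

        def wrapper():
            # Reads whatever value external code most recently stored.
            return num_tests_use_this_fixture[0]

        # Exposed so pytest_collection_modifyitems() can increment it.
        wrapper.num_tests_use_this_fixture = num_tests_use_this_fixture
        return wrapper

    w = make_wrapper()
    w.num_tests_use_this_fixture[0] += 1
    assert w() == 1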
" + "If using outside of the TVM test directory, " + "please add `pytest_plugins = ['tvm.testing.plugin']` to your conftest.py" + ) + try: cache_key = get_cache_key(*args, **kwargs) @@ -1364,52 +1211,17 @@ def wrapper(*args, **kwargs): finally: # Clear the cache once all tests that use a particular fixture # have completed. - num_uses_remaining[0] -= 1 - if not num_uses_remaining[0]: + nonlocal num_times_fixture_used + num_times_fixture_used += 1 + if num_times_fixture_used >= num_tests_use_this_fixture[0]: cache.clear() - # Set in the pytest_collection_modifyitems() - wrapper.num_uses_remaining = num_uses_remaining + # Set in the pytest_collection_modifyitems(), by _count_num_fixture_uses + wrapper.num_tests_use_this_fixture = num_tests_use_this_fixture return wrapper -def _count_num_fixture_uses(items): - # Helper function, counts the number of tests that use each cached - # fixture. Should be called from pytest_collection_modifyitems(). - for item in items: - is_skipped = item.get_closest_marker("skip") or any( - mark.args[0] for mark in item.iter_markers("skipif") - ) - if is_skipped: - continue - - for fixturedefs in item._fixtureinfo.name2fixturedefs.values(): - # Only increment the active fixturedef, in a name has been overridden. - fixturedef = fixturedefs[-1] - if hasattr(fixturedef.func, "num_uses_remaining"): - fixturedef.func.num_uses_remaining[0] += 1 - - -def _remove_global_fixture_definitions(items): - # Helper function, removes fixture definitions from the global - # variables of the modules they were defined in. This is intended - # to improve readability of error messages by giving a NameError - # if a test function accesses a pytest fixture but doesn't include - # it as an argument. Should be called from - # pytest_collection_modifyitems(). - - modules = set(item.module for item in items) - - for module in modules: - for name in dir(module): - obj = getattr(module, name) - if hasattr(obj, "_pytestfixturefunction") and isinstance( - obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker - ): - delattr(module, name) - - def identity_after(x, sleep): """Testing function to return identity after sleep diff --git a/tests/python/unittest/test_tvm_testing_features.py b/tests/python/unittest/test_tvm_testing_features.py index 8885f55bbf4b..4c9c5d91901a 100644 --- a/tests/python/unittest/test_tvm_testing_features.py +++ b/tests/python/unittest/test_tvm_testing_features.py @@ -199,7 +199,7 @@ def test_num_uses_cached(self): class TestAutomaticMarks: @staticmethod def check_marks(request, target): - parameter = tvm.testing._pytest_target_params([target])[0] + parameter = tvm.testing.plugin._pytest_target_params([target])[0] required_marks = [decorator.mark for decorator in parameter.marks] applied_marks = list(request.node.iter_markers()) @@ -239,6 +239,11 @@ def uncacheable_fixture(self): return self.EmptyClass() def test_uses_uncacheable(self, request): + # Normally the num_tests_use_this_fixture would be set before + # anything runs. For this test case only, because we are + # delaying the use of the fixture, we need to manually + # increment it. + self.uncacheable_fixture.num_tests_use_this_fixture[0] += 1 with pytest.raises(TypeError): request.getfixturevalue("uncacheable_fixture")