Skip to content

Commit

Permalink
[UnitTests] Refactor the plugin-specific logic out into plugin.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Aug 20, 2021
1 parent de847fe commit 1609124
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 223 deletions.
4 changes: 1 addition & 3 deletions python/tvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
from .utils import known_failing_targets, requires_cuda, requires_cudagraph
from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc
from .utils import requires_tensorcore, requires_metal, requires_micro, requires_opencl
from .utils import _auto_parametrize_target, _count_num_fixture_uses
from .utils import _remove_global_fixture_definitions, _parametrize_correlated_parameters
from .utils import _pytest_target_params, identity_after, terminate_self
from .utils import identity_after, terminate_self

from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count
from ._ffi_api import test_wrap_callback, test_raise_error_callback, test_check_eq_callback
Expand Down
221 changes: 215 additions & 6 deletions python/tvm/testing/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@
"""

import collections

import pytest
import _pytest

import tvm.testing.utils
import tvm
from tvm.testing import utils


def pytest_configure(config):
Expand All @@ -51,21 +55,21 @@ def pytest_configure(config):
for markername, desc in markers.items():
config.addinivalue_line("markers", "{}: {}".format(markername, desc))

print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets())))
print("enabled targets:", "; ".join(map(lambda x: x[0], utils.enabled_targets())))
print("pytest marker:", config.option.markexpr)


def pytest_generate_tests(metafunc):
"""Called once per unit test, modifies/parametrizes it as needed."""
tvm.testing.utils._auto_parametrize_target(metafunc)
tvm.testing.utils._parametrize_correlated_parameters(metafunc)
_parametrize_correlated_parameters(metafunc)
_auto_parametrize_target(metafunc)


def pytest_collection_modifyitems(config, items):
"""Called after all tests are chosen, currently used for bookkeeping."""
# pylint: disable=unused-argument
tvm.testing.utils._count_num_fixture_uses(items)
tvm.testing.utils._remove_global_fixture_definitions(items)
_count_num_fixture_uses(items)
_remove_global_fixture_definitions(items)


@pytest.fixture
Expand All @@ -80,3 +84,208 @@ def pytest_sessionfinish(session, exitstatus):
if session.config.option.markexpr != "":
if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
session.exitstatus = pytest.ExitCode.OK


def _auto_parametrize_target(metafunc):
"""Automatically applies parametrize_targets
Used if a test function uses the "target" fixture, but isn't
already marked with @tvm.testing.parametrize_targets. Intended
for use in the pytest_generate_tests() handler of a conftest.py
file.
"""

def update_parametrize_target_arg(
argnames,
argvalues,
*args,
**kwargs,
):
args = [arg.strip() for arg in argnames.split(",") if arg.strip()]
if "target" in args:
target_i = args.index("target")

new_argvalues = []
for argvalue in argvalues:

if isinstance(argvalue, _pytest.mark.structures.ParameterSet):
# The parametrized value is already a
# pytest.param, so track any marks already
# defined.
param_set = argvalue.values
target = param_set[target_i]
additional_marks = argvalue.marks
elif len(args) == 1:
# Single value parametrization, argvalue is a list of values.
target = argvalue
param_set = (target,)
additional_marks = []
else:
# Multiple correlated parameters, argvalue is a list of tuple of values.
param_set = argvalue
target = param_set[target_i]
additional_marks = []

new_argvalues.append(
pytest.param(
*param_set, marks=_target_to_requirement(target) + additional_marks
)
)

try:
argvalues[:] = new_argvalues
except TypeError as e:
pyfunc = metafunc.definition.function
filename = pyfunc.__code__.co_filename
line_number = pyfunc.__code__.co_firstlineno
msg = (
f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) "
"is parametrized using a tuple of parameters instead of a list "
"of parameters."
)
raise TypeError(msg) from e

if "target" in metafunc.fixturenames:
# Update any explicit use of @pytest.mark.parmaetrize to
# parametrize over targets. This adds the appropriate
# @tvm.testing.requires_* markers for each target.
for mark in metafunc.definition.iter_markers("parametrize"):
update_parametrize_target_arg(*mark.args, **mark.kwargs)

# Check if any explicit parametrizations exist, and apply one
# if they do not. If the function is marked with either
# excluded or known failing targets, use these to determine
# the targets to be used.
parametrized_args = [
arg.strip()
for mark in metafunc.definition.iter_markers("parametrize")
for arg in mark.args[0].split(",")
]
if "target" not in parametrized_args:
excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", [])
xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", [])
metafunc.parametrize(
"target",
_pytest_target_params(None, excluded_targets, xfail_targets),
scope="session",
)


def _count_num_fixture_uses(items):
# Helper function, counts the number of tests that use each cached
# fixture. Should be called from pytest_collection_modifyitems().
for item in items:
is_skipped = item.get_closest_marker("skip") or any(
mark.args[0] for mark in item.iter_markers("skipif")
)
if is_skipped:
continue

for fixturedefs in item._fixtureinfo.name2fixturedefs.values():
# Only increment the active fixturedef, in a name has been overridden.
fixturedef = fixturedefs[-1]
if hasattr(fixturedef.func, "num_tests_use_this_fixture"):
fixturedef.func.num_tests_use_this_fixture[0] += 1


def _remove_global_fixture_definitions(items):
# Helper function, removes fixture definitions from the global
# variables of the modules they were defined in. This is intended
# to improve readability of error messages by giving a NameError
# if a test function accesses a pytest fixture but doesn't include
# it as an argument. Should be called from
# pytest_collection_modifyitems().

modules = set(item.module for item in items)

for module in modules:
for name in dir(module):
obj = getattr(module, name)
if hasattr(obj, "_pytestfixturefunction") and isinstance(
obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker
):
delattr(module, name)


def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None):
# Include unrunnable targets here. They get skipped by the
# pytest.mark.skipif in _target_to_requirement(), showing up as
# skipped tests instead of being hidden entirely.
if targets is None:
if excluded_targets is None:
excluded_targets = set()

if xfail_targets is None:
xfail_targets = set()

target_marks = []
for t in utils._get_targets():
# Excluded targets aren't included in the params at all.
if t["target_kind"] not in excluded_targets:

# Known failing targets are included, but are marked
# as expected to fail.
extra_marks = []
if t["target_kind"] in xfail_targets:
extra_marks.append(
pytest.mark.xfail(
reason='Known failing test for target "{}"'.format(t["target_kind"])
)
)

target_marks.append((t["target"], extra_marks))

else:
target_marks = [(target, []) for target in targets]

return [
pytest.param(target, marks=_target_to_requirement(target) + extra_marks)
for target, extra_marks in target_marks
]


def _target_to_requirement(target):
if isinstance(target, str):
target = tvm.target.Target(target)

# mapping from target to decorator
if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []):
return utils.requires_cudnn()
if target.kind.name == "cuda":
return utils.requires_cuda()
if target.kind.name == "rocm":
return utils.requires_rocm()
if target.kind.name == "vulkan":
return utils.requires_vulkan()
if target.kind.name == "nvptx":
return utils.requires_nvptx()
if target.kind.name == "metal":
return utils.requires_metal()
if target.kind.name == "opencl":
return utils.requires_opencl()
if target.kind.name == "llvm":
return utils.requires_llvm()
return []


def _parametrize_correlated_parameters(metafunc):
parametrize_needed = collections.defaultdict(list)

for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items():
fixturedef = fixturedefs[-1]
if hasattr(fixturedef.func, "parametrize_group") and hasattr(
fixturedef.func, "parametrize_values"
):
group = fixturedef.func.parametrize_group
values = fixturedef.func.parametrize_values
parametrize_needed[group].append((name, values))

for parametrize_group in parametrize_needed.values():
if len(parametrize_group) == 1:
name, values = parametrize_group[0]
metafunc.parametrize(name, values, indirect=True)
else:
names = ",".join(name for name, values in parametrize_group)
value_sets = zip(*[values for name, values in parametrize_group])
metafunc.parametrize(names, value_sets, indirect=True)
Loading

0 comments on commit 1609124

Please sign in to comment.