disable cuda int8 schedule for non-cuda gpu target (apache#9014)
masahi authored and ylc committed Sep 29, 2021
1 parent 17114c9 · commit 9496a19
Showing 2 changed files with 8 additions and 7 deletions.
python/tvm/relay/op/strategy/cuda.py (8 additions, 3 deletions)
@@ -144,7 +144,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     if groups == 1:
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
-            if data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8"):
+            if (
+                target.kind.name == "cuda"
+                and data.dtype in ("int8", "uint8")
+                and kernel.dtype in ("int8", "uint8")
+            ):
                 assert data.dtype == kernel.dtype
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
@@ -293,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                     "Unsupported shape for conv2d HWNC.\
                     Need to satisfy tensor core schedule."
                 )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
+        elif target.kind.name == "cuda" and layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
@@ -353,7 +357,8 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         ic_chunk = in_channels // 4
 
         if (
-            data.dtype in ["int8", "uint8"]
+            target.kind.name == "cuda"
+            and data.dtype in ["int8", "uint8"]
             and kernel.dtype in ["int8", "uint8"]
             and channels % groups == 0
             and out_channels % groups == 0
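Context for the guards added above: the CUDA strategy file is registered not only for the "cuda" target kind but also for the generic "gpu" key, so GPU targets such as Vulkan that have no strategy of their own resolve to this same function (the exact registration keys are recalled from upstream TVM and should be treated as an assumption here). The new checks therefore key off target.kind.name, which a short runnable sketch can illustrate:

# Minimal sketch of the dispatch property the new guards rely on:
# constructing a Target does not require the corresponding runtime,
# and target.kind.name is what separates CUDA from other GPU kinds.
import tvm

for name in ("cuda", "vulkan", "rocm"):
    target = tvm.target.Target(name)
    print(name, "->", target.kind.name, "| cuda int8 eligible:", target.kind.name == "cuda")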
tests/python/relay/test_op_level2.py (0 additions, 4 deletions)
@@ -325,10 +325,6 @@ def test_run(
         kernel_size,
     ):
         target = tvm.target.Target(target)
-        if target.kind.name == "vulkan" and dtype == "int8":
-            # The schedule selection incorrectly picks an
-            # implementation that requires NCHWc packed input.
-            pytest.xfail("Known failing test for vulkan")
 
         x = relay.var("x", shape=dshape, dtype=dtype)
         w = relay.var("w", shape=kshape, dtype=dtype)
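The vulkan xfail above is removed because the strategy change makes it unnecessary: with the target.kind.name guard in place, Vulkan no longer selects the CUDA int8 implementation that required NCHWc-packed input, so the int8 test is expected to pass. A hedged repro sketch of the scenario the guard fixes (hypothetical shapes; assumes a TVM build with the Vulkan runtime enabled, and sketches the failure mode rather than the test itself):

# Compile an int8 NCHW conv2d for a non-CUDA GPU target. Before this
# commit the CUDA-only int8 schedule could be selected for "vulkan";
# with the guard, compilation falls back to the generic GPU schedule.
import tvm
from tvm import relay

dshape = (1, 32, 56, 56)  # hypothetical input shape (NCHW)
kshape = (32, 32, 3, 3)   # hypothetical kernel shape (OIHW)
x = relay.var("x", shape=dshape, dtype="int8")
w = relay.var("w", shape=kshape, dtype="int8")
y = relay.nn.conv2d(x, w, padding=(1, 1), channels=32, kernel_size=(3, 3))
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=tvm.target.Target("vulkan"))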
