Skip to content

Commit

Permalink
[ONNX] Enable GPU in ONNX importer tests (apache#7438)
Browse files Browse the repository at this point in the history
* remove hardcoded target and ctx

* fix c-codgen for floating point mod

* MDisable onnx gpu test for argmin / argmax so we can get this fix merged. Matt or myself will fix later but we don't have time right now.

* lint

* fix black

* Add flag to task_python_frontend.sh to only run GPU enabled tests on GPU

* black again

* Enable GPU for test_nonzero

* Respond to comments

* Don't test batch matmul on CUDA

* Turn cuda off for dynamic batch matmul test

* Fix task script

* Flaky test

* another flaky test

Co-authored-by: mbrookhart <[email protected]>
  • Loading branch information
2 people authored and trevor-m committed May 11, 2021
1 parent e94b6e7 commit f56c352
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
16 changes: 15 additions & 1 deletion src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,21 @@ void CodeGenC::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenC::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, this);
if (op->dtype.is_int() || op->dtype.is_uint()) {
PrintBinaryExpr(op, "%", os, this);
} else {
ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got "
<< op->dtype;
if (op->dtype.bits() == 32) {
PrintBinaryExpr(op, "fmodf", os, this);
} else if (op->dtype.bits() == 64) {
PrintBinaryExpr(op, "fmod", os, this);
} else {
ICHECK(false)
<< "Non single or double precision floating point in Mod, expected 32 or 64 bits but got "
<< op->dtype.bits() << " bits.";
}
}
}
void CodeGenC::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, this);
Expand Down
38 changes: 14 additions & 24 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,7 @@ def test_double_reshape():
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_expand():
def _test_expand(name, data, shape, ref_data, dtype="int32"):
shape_array = np.array(shape)
Expand Down Expand Up @@ -757,8 +756,7 @@ def add_noop_to_input_attr(attr_name, attr):
verify_with_ort_with_inputs(model, [indata], opset=10, freeze_params=True, use_vm=True)


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_slice():
x = np.random.randn(20, 10, 5).astype(np.float32)
_test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1))
Expand Down Expand Up @@ -978,8 +976,7 @@ def test_gather_nd():
verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2])


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_onehot():
indices_shape = [10]
indices_array = np.random.randint(low=0, high=9, size=indices_shape, dtype="int32")
Expand Down Expand Up @@ -1091,7 +1088,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, target, dev):
verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, targets=[target])


# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
# TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed
@tvm.testing.parametrize_targets("llvm")
def test_batch_matmul(target, dev):
verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), (2, 3, 4, 4), target, dev)
Expand Down Expand Up @@ -1146,7 +1143,7 @@ def verify_model(ex, a_shape, b_shape):
verify_model(ex, [a * 3 for a in a_shape], [b * 3 for b in b_shape])


# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
# TODO(mbrookhart, electriclilies): Add CUDA as a target once batch matmul is fixed
@tvm.testing.parametrize_targets("llvm")
def test_batch_matmul_dynamic_model(target, dev):
verify_simple_dynamic_model((2, 3, 4, 3), (2, 3, 3, 4), target, dev)
Expand Down Expand Up @@ -1319,8 +1316,7 @@ def verify_upsample3d_trilinear():
tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_upsample():
verify_upsample_nearest()
verify_upsample_bilinear()
Expand Down Expand Up @@ -1497,7 +1493,8 @@ def verify_argreduce(input_dim, op_name, axis=None, keepdims=None):
verify_with_ort_with_inputs(model, [a_np1])


@tvm.testing.uses_gpu
# TODO (mbrookhart, electriclilies) Fix argmin on GPU and enable this test
# @tvm.testing.uses_gpu
def test_forward_arg_min_max():
"""Verify argmin and argmax"""
verify_argreduce([3, 4, 4], "ArgMin")
Expand Down Expand Up @@ -1540,8 +1537,7 @@ def verify_constantofshape(input_dim, value, dtype):
verify_with_ort_with_inputs(model, [input_np], use_vm=True)


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_constantofshape():
verify_constantofshape((2, 3, 4, 5), 10, "float32")
verify_constantofshape((3, 3), 0, "int32")
Expand Down Expand Up @@ -1627,8 +1623,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0):
verify_with_ort_with_inputs(model, inputs, opset=11, use_vm=True)


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_pad():
verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], "constant", 0.0)
verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], "constant", 0.0)
Expand Down Expand Up @@ -2120,8 +2115,7 @@ def verify_tile_v6(indata, repeats, outdata):
verify_with_ort_with_inputs(model, [indata, repeats], use_vm=True, opset=6)


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_tile():
x = np.random.rand(2, 3, 4, 5).astype(np.float32)
repeats = np.random.randint(low=1, high=10, size=(np.ndim(x),)).astype(np.int64)
Expand Down Expand Up @@ -2293,8 +2287,7 @@ def verify_batch_norm(in_shape):
verify_batch_norm([16, 16, 10, 10])


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_batch_norm_dynamic_subgraph():
def verify_batch_norm_dynamic_subgraph(in_shape, o_shape):

Expand Down Expand Up @@ -3312,8 +3305,7 @@ def test_gru():
)


# TODO(mbrookhart): enable once VM supports heterogenous execution
# @tvm.testing.uses_gpu
@tvm.testing.uses_gpu
def test_resize():
def verify(ishape, oshape, scales, mode, coord_trans):
nodes = [
Expand Down Expand Up @@ -3420,9 +3412,7 @@ def verify_nonzero(indata, outdata, dtype):

model = helper.make_model(graph, producer_name="nonzero_test")

verify_with_ort_with_inputs(
model, [indata], targets=["llvm"], dtype="int64", use_vm=True, opset=9
)
verify_with_ort_with_inputs(model, [indata], dtype="int64", use_vm=True, opset=9)

input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 1], [0, 0, 1]]
Expand Down
5 changes: 4 additions & 1 deletion tests/scripts/task_python_frontend.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ echo "Running relay MXNet frontend test..."
run_pytest cython python-frontend-mxnet tests/python/frontend/mxnet

echo "Running relay ONNX frontend test..."
run_pytest cython python-frontend-onnx tests/python/frontend/onnx
# Enable tvm.testing decorators in the ONNX importer test (not enabling in the other tests because we
# they do not consistently use the decorators to indicate that tests should run on GPU)
# In the future, we should enable tvm.testing decorators for all the test files.
PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" run_pytest cython python-frontend-onnx tests/python/frontend/onnx

echo "Running relay CoreML frontend test..."
run_pytest cython python-frontend-coreml tests/python/frontend/coreml
Expand Down

0 comments on commit f56c352

Please sign in to comment.