From 10225c2dbc29c29b15af31f55e36bf32ab6cf155 Mon Sep 17 00:00:00 2001 From: Yuan-Chuan-YUE <69908243+Yuan-Chuan-YUE@users.noreply.github.com> Date: Sun, 22 Aug 2021 05:42:21 +0800 Subject: [PATCH] [CODEGEN][OpenCL]: fix tir.erf codegen to opencl directly (#8756) * register tir.erf to lower opencl directly * add opencl codegen unit test * change erf opencl codegen unit test for checking there is erf in the source not erff --- src/target/source/intrin_rule_opencl.cc | 3 +++ .../unittest/test_target_codegen_opencl.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 288bb2cfc069..64a50c3c84b1 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -49,6 +49,9 @@ TVM_REGISTER_OP("tir.round") TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>); +TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", + DispatchPureExtern<Direct>); + TVM_REGISTER_OP("tir.exp2") .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>); diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index 98340f0e6ac5..56392ec8cccc 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -17,6 +17,7 @@ import tvm from tvm import te import tvm.testing +import re target = "opencl" @@ -120,6 +121,25 @@ def check_max(dev, n, dtype): check_max(dev, 1, "float64") +def test_opencl_erf(): + def check_erf(dev, n, dtype): + A = te.placeholder((n,), name="A", dtype=dtype) + C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C") + s = te.create_schedule(C.op) + s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x")) + fun = tvm.build(s, [A, C], target) + source_str = fun.imported_modules[0].get_source() + matches = re.findall("erf", source_str) + error_matches = re.findall("erff", source_str) 
+ assert len(matches) == 1 and len(error_matches) == 0 + + dev = tvm.device(target, 0) + + check_erf(dev, 1, "float32") + check_erf(dev, 1, "float64") + + if __name__ == "__main__": test_opencl_ternary_expression() test_opencl_inf_nan() + test_opencl_erf()