Skip to content

Commit

Permalink
[CODEGEN][OpenCL]: fix tir.erf codegen to opencl directly (apache#8756)
Browse files Browse the repository at this point in the history
* register tir.erf to lower opencl directly

* add opencl codegen unit test

* change the erf OpenCL codegen unit test to check that the generated source contains erf, not erff
  • Loading branch information
Yuan-Chuan-YUE authored and ylc committed Jan 13, 2022
1 parent 42aec85 commit 10225c2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/target/source/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ TVM_REGISTER_OP("tir.round")
// Attach the "opencl.FLowerIntrinsic" lowering attribute to tir.exp, using
// the DispatchPureExtern<Direct> rule.
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
DispatchPureExtern<Direct>);

// Attach the "opencl.FLowerIntrinsic" lowering attribute to tir.erf, using
// the DispatchPureExtern<Direct> rule (same rule as tir.exp / tir.exp2).
// NOTE(review): presumably `Direct` emits the intrinsic name unchanged, i.e.
// `erf` rather than a type-suffixed `erff` -- confirm against the definition
// of `Direct`, which is not visible here.
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
DispatchPureExtern<Direct>);

// Attach the "opencl.FLowerIntrinsic" lowering attribute to tir.exp2, using
// the DispatchPureExtern<Direct> rule.
TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);

Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_target_codegen_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
from tvm import te
import tvm.testing
import re

target = "opencl"

Expand Down Expand Up @@ -120,6 +121,25 @@ def check_max(dev, n, dtype):
check_max(dev, 1, "float64")


def test_opencl_erf():
    """Check that ``te.erf`` shows up as ``erf`` (and never ``erff``) in the
    source generated for the OpenCL target."""

    def check_erf(dev, n, dtype):
        """Build a one-axis erf kernel and inspect its generated source.

        Parameters
        ----------
        dev
            Device handle supplied by the caller; it is accepted but not
            used, because the kernel is only built and never executed.
        n : int
            Length of the 1-D input/output tensors.
        dtype : str
            Element type of the input placeholder.
        """
        data = te.placeholder((n,), name="A", dtype=dtype)
        result = te.compute(data.shape, lambda *idx: te.erf(data(*idx)), name="C")
        sched = te.create_schedule(result.op)
        stage = sched[result]
        stage.bind(stage.op.axis[0], te.thread_axis("threadIdx.x"))
        built = tvm.build(sched, [data, result], target)
        kernel_source = built.imported_modules[0].get_source()
        # "erf" is a substring of "erff", so both counts are needed: exactly
        # one "erf" occurrence overall and no "erff" occurrence at all.
        erf_count = len(re.findall("erf", kernel_source))
        erff_count = len(re.findall("erff", kernel_source))
        assert erf_count == 1 and erff_count == 0

    dev = tvm.device(target, 0)

    for dtype in ("float32", "float64"):
        check_erf(dev, 1, dtype)


if __name__ == "__main__":
    # Run the tests listed here, in order, when the file is executed as a
    # script instead of being collected by a test runner.
    for run_test in (
        test_opencl_ternary_expression,
        test_opencl_inf_nan,
        test_opencl_erf,
    ):
        run_test()

0 comments on commit 10225c2

Please sign in to comment.