Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Intrin] Adding a few missing math intrin #5011

Merged
merged 6 commits into from
Mar 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,16 +508,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
} \

TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(exp2);
TVM_DECLARE_INTRIN_UNARY(exp10);
TVM_DECLARE_INTRIN_UNARY(erf);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(log2);
TVM_DECLARE_INTRIN_UNARY(log10);
TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(cosh);
TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(sinh);
TVM_DECLARE_INTRIN_UNARY(atan);

namespace tir {
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from .op import exp, exp2, exp10, log, log2, log10
from .op import cos, sin, cosh, sinh, tan, tanh, atan
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum
Expand Down
98 changes: 98 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,38 @@ def exp(x):
return call_pure_intrin(x.dtype, "exp", x)


def exp2(x):
"""Calculate 2**x

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "exp2", x)


def exp10(x):
"""Calculate 10**x

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "exp10", x)


def erf(x):
"""Take gauss error function of the input x.

Expand Down Expand Up @@ -393,6 +425,38 @@ def log(x):
"""
return call_pure_intrin(x.dtype, "log", x)


def log2(x):
"""Take log2 of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log2", x)


def log10(x):
"""Take log10 of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log10", x)

def tan(x):
"""Take tan of input x.

Expand Down Expand Up @@ -424,6 +488,23 @@ def cos(x):
"""
return call_pure_intrin(x.dtype, "cos", x)


def cosh(x):
"""Take cosh of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "cosh", x)


def sin(x):
"""Take sin of input x.

Expand All @@ -439,6 +520,23 @@ def sin(x):
"""
return call_pure_intrin(x.dtype, "sin", x)


def sinh(x):
"""Take sin of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "sinh", x)


def atan(x):
"""Take atan of input x.

Expand Down
59 changes: 59 additions & 0 deletions src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,35 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
PrimExpr ret = tir::CallNode::make(
x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic);
*rv = ret;
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);

Expand Down Expand Up @@ -108,9 +131,45 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_one = make_const(x.dtype(), -1);
PrimExpr exp_negx = tir::CallNode::make(
x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
PrimExpr exp_posx = tir::CallNode::make(
x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
PrimExpr ret = (exp_posx + exp_negx) / two;
*rv = ret;
});

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
using tir::make_const;
using tir::make_zero;
PrimExpr e = targs[0];
const tir::CallNode* call = e.as<tir::CallNode>();
CHECK(call != nullptr);
const PrimExpr& x = call->args[0];
PrimExpr two = make_const(x.dtype(), 2);
PrimExpr neg_one = make_const(x.dtype(), -1);
PrimExpr exp_negx = tir::CallNode::make(
x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic);
PrimExpr exp_posx = tir::CallNode::make(
x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic);
PrimExpr ret = (exp_posx - exp_negx) / two;
*rv = ret;
});

} // namespace llvm
} // namespace codegen
} // namespace tvm
Expand Down
18 changes: 18 additions & 0 deletions src/target/llvm/intrin_rule_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf")
.set_body(DispatchExternLibDevice);

Expand All @@ -72,6 +78,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt")
.set_body(DispatchExternLibDevice);

Expand All @@ -87,9 +99,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan")
.set_body(DispatchExternLibDevice);

Expand Down
18 changes: 18 additions & 0 deletions src/target/llvm/intrin_rule_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf")
.set_body(DispatchExternOCML);

Expand All @@ -71,6 +77,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt")
.set_body(DispatchExternOCML);

Expand All @@ -86,9 +98,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh")
.set_body(DispatchExternOCML);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan")
.set_body(DispatchExternOCML);

Expand Down
18 changes: 18 additions & 0 deletions src/target/source/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,39 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
.set_body(DispatchExtern<CUDAMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
.set_body(DispatchExtern<CUDAFastMathTan>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan")
.set_body(DispatchExtern<CUDAMath>);

Expand Down
Loading