Skip to content

Commit

Permalink
add more folding testcases and fix store fp32 folding result to double
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif committed Aug 28, 2022
1 parent bcf0ac2 commit 9f85312
Show file tree
Hide file tree
Showing 3 changed files with 419 additions and 34 deletions.
5 changes: 5 additions & 0 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def truncmod(x, y, span):
return tvm.tir.truncmod(x, y, span)


@register
def truncdiv(x, y, span):
    """Script intrinsic that delegates to ``tvm.tir.truncdiv``.

    The operands ``x`` and ``y`` and the source ``span`` are forwarded
    unchanged, and whatever ``tvm.tir.truncdiv`` produces is returned.
    """
    quotient = tvm.tir.truncdiv(x, y, span)
    return quotient


@register
def ceildiv(x, y, span):
    """Script intrinsic that delegates to ``tvm.tir.ceildiv``.

    The operands ``x`` and ``y`` and the source ``span`` are forwarded
    unchanged, and whatever ``tvm.tir.ceildiv`` produces is returned.
    """
    quotient = tvm.tir.ceildiv(x, y, span)
    return quotient
Expand Down
46 changes: 32 additions & 14 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ inline bool IsIndexType(const DataType& type) {
}

/*! \brief Helper to get const folding result repr in int64. */
inline int64_t GetInt64FoldResultRepr(int64_t x, const DataType& dtype) {
inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) {
if (dtype.bits() < 64) {
x &= (1LL << dtype.bits()) - 1;
}
Expand All @@ -86,6 +86,20 @@ inline int64_t GetInt64FoldResultRepr(int64_t x, const DataType& dtype) {
return x;
}

/*!
 * \brief Widen an fp32 constant-folding result to the double value that gets stored.
 *
 * Infinities and NaN are returned as-is. A value that compares below the lowest
 * finite float becomes negative infinity, and a value that compares above the
 * largest finite float becomes positive infinity. Everything else is returned
 * as the plain float-to-double conversion.
 *
 * NOTE(review): for an argument that is exactly an fp32 value the two range
 * checks look unreachable; presumably they guard builds where float expressions
 * are evaluated with excess precision -- confirm before relying on or removing them.
 *
 * \param x The fp32 folding result.
 * \return The double representation of x, saturated to +/-infinity when it lies
 *         outside the finite float range.
 */
inline double GetFoldResultDoubleRepr(float x) {
  constexpr double kFloatLowest = std::numeric_limits<float>::lowest();
  constexpr double kFloatMax = std::numeric_limits<float>::max();
  constexpr double kDoubleInf = std::numeric_limits<double>::infinity();

  const double widened = static_cast<double>(x);
  // Non-finite inputs are already in their final form.
  const bool non_finite = std::isnan(widened) || std::isinf(widened);
  if (non_finite) return widened;
  // Saturate anything beyond the finite float range to the matching infinity.
  if (widened < kFloatLowest) return -kDoubleInf;
  if (widened > kFloatMax) return kDoubleInf;
  return widened;
}

#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using tir::FloatImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
Expand All @@ -110,13 +124,14 @@ inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
int64_t res = pa->value + pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) + static_cast<float>(fb->value));
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) +
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value + fb->value);
} else {
Expand All @@ -139,12 +154,13 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
int64_t res = pa->value - pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pb && pb->value == 0) return a;
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) - static_cast<float>(fb->value));
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) -
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value - fb->value);
} else {
Expand All @@ -162,7 +178,7 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
int64_t res = pa->value * pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 1) return b;
Expand All @@ -174,7 +190,8 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
}
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) * static_cast<float>(fb->value));
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) *
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value * fb->value);
} else {
Expand Down Expand Up @@ -202,7 +219,7 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
// NOTE: this assumes trunc div.
ICHECK_NE(pb->value, 0) << "Divide by zero";
int64_t res = pa->value / pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -213,7 +230,8 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
}
if (fa && fb && fb->value != 0) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) / static_cast<float>(fb->value));
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) /
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value / fb->value);
} else {
Expand All @@ -236,7 +254,7 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
int64_t res = pa->value % pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -256,7 +274,7 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
int64_t res = arith::floordiv(pa->value, pb->value);
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -267,8 +285,8 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
}
if (fa && fb && fb->value != 0) {
if (rtype.bits() == 32) {
return FloatImm(rtype,
std::floor(static_cast<float>(fa->value) / static_cast<float>(fb->value)));
return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast<float>(fa->value) /
static_cast<float>(fb->value))));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
} else {
Expand All @@ -291,7 +309,7 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
int64_t res = arith::floormod(pa->value, pb->value);
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand Down
Loading

0 comments on commit 9f85312

Please sign in to comment.