Skip to content

Commit

Permalink
[ARITH] Improve vector simplification for float operands (apache#6043)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and trevor-m committed Sep 3, 2020
1 parent 3858b86 commit 857c0a0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/arith/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>

#include <cmath>
#include <tuple>

#include "const_fold.h"
Expand Down Expand Up @@ -145,6 +146,14 @@ class PEqualChecker<IntImm> {
bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; }
};

template <>
class PEqualChecker<FloatImm> {
public:
bool operator()(const FloatImm& lhs, const FloatImm& rhs) const {
return std::fabs(lhs->value - rhs->value) < 1e-20;
}
};

template <>
class PEqualChecker<tir::Var> {
public:
Expand Down
6 changes: 6 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
// Pattern var match FloatImm
PVar<FloatImm> c4;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
Expand All @@ -133,6 +135,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes));
TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes));
TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes));
TVM_TRY_REWRITE_IF(x + broadcast(c4, lanes), x, c4.Eval()->value == 0.0f);
}

if (IsIndexType(op->dtype)) {
Expand Down Expand Up @@ -416,13 +419,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {
PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
// Pattern var match FloatImm
PVar<FloatImm> c3;
// Pattern var for lanes in broadcast and ramp
PVar<int> lanes;
// Vector rules
if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes));
TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes));
TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes));
TVM_TRY_REWRITE_IF(broadcast(c3, lanes) * x, broadcast(c3, lanes), c3.Eval()->value == 0.0f);
}

if (IsIndexType(op->dtype)) {
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def test_vector_simplify():
(y + x).astype("int32x2"))
ck.verify(tvm.tir.Broadcast(0, 4) + y,
tvm.tir.Broadcast(y, 4))
ck.verify(tvm.tir.Ramp(x, 1, 4).astype('float32x4') + tvm.tir.Broadcast(0.0, 4),
tvm.tir.Ramp(x, 1, 4).astype('float32x4'))
# Sub rules
ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4),
tvm.tir.Ramp(x - y, 2, 4))
Expand All @@ -59,6 +61,8 @@ def test_vector_simplify():
tvm.tir.Ramp(x * 2, 8, 4))
ck.verify(tvm.tir.Broadcast(0, 4) * x,
tvm.tir.Broadcast(0, 4))
ck.verify(tvm.tir.Broadcast(0.0, 4) * x,
tvm.tir.Broadcast(0.0, 4))

## DivMod rules
tdiv = tvm.tir.truncdiv
Expand Down

0 comments on commit 857c0a0

Please sign in to comment.