Skip to content

Commit

Permalink
[TIR] Support fold constants in specialize process (apache#8803)
Browse files Browse the repository at this point in the history
* support fold constants in specialize

* replace Substitute() with VisitExpr() in specializer.
  • Loading branch information
wrongtest-intellif authored and shingjan committed Aug 23, 2021
1 parent e76ecd8 commit 6fd09ab
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 4 deletions.
49 changes: 45 additions & 4 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <functional>
Expand All @@ -45,6 +46,27 @@ inline bool IsParam(const PrimFunc& func, const Var& param) {

/**************** Specializer ****************/

// Defines a final VisitExpr_ override for the binary node type `BinaryNode`.
// Both operands are visited first. If neither operand changed, the original
// node is handed back untouched. Otherwise the expression is rebuilt through
// `BinaryFunc(a, b)` so that operands which were specialized to constants get
// a chance to be folded.
#define DEFINE_SPECIALIZER_BINARY_OP_MUTATE(BinaryNode, BinaryFunc)    \
  PrimExpr VisitExpr_(const BinaryNode* op) final {                    \
    PrimExpr lhs = VisitExpr(op->a);                                   \
    PrimExpr rhs = VisitExpr(op->b);                                   \
    const bool unchanged = lhs.same_as(op->a) && rhs.same_as(op->b);   \
    if (!unchanged) {                                                  \
      return BinaryFunc(lhs, rhs);                                     \
    }                                                                  \
    return GetRef<PrimExpr>(op);                                       \
  }
// Defines a final VisitExpr_ override for the unary node type `UnaryNode`.
// The single operand is visited first. If it did not change, the original node
// is handed back untouched. Otherwise the expression is rebuilt through
// `UnaryFunc(a)`.
#define DEFINE_SPECIALIZER_UNARY_OP_MUTATE(UnaryNode, UnaryFunc) \
  PrimExpr VisitExpr_(const UnaryNode* op) final {               \
    PrimExpr operand = VisitExpr(op->a);                         \
    if (!operand.same_as(op->a)) {                               \
      return UnaryFunc(operand);                                 \
    }                                                            \
    return GetRef<PrimExpr>(op);                                 \
  }

/*! \brief Mutator to specialize function and remove const parameters */
class PrimFuncSpecializer : public StmtExprMutator {
public:
Expand Down Expand Up @@ -157,14 +179,33 @@ class PrimFuncSpecializer : public StmtExprMutator {
}
}

DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AddNode, add);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(SubNode, sub);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MulNode, mul);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(DivNode, div);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(ModNode, truncmod);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorDivNode, floordiv);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorModNode, floormod);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MaxNode, max);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MinNode, min);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(EQNode, equal);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(NENode, not_equal);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LTNode, less);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LENode, less_equal);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GTNode, greater);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GENode, greater_equal);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AndNode, logical_and);
DEFINE_SPECIALIZER_BINARY_OP_MUTATE(OrNode, logical_or);
DEFINE_SPECIALIZER_UNARY_OP_MUTATE(NotNode, logical_not);

private:
Buffer MutateBuffer(const Buffer& buffer) const {
Buffer MutateBuffer(const Buffer& buffer) {
Array<PrimExpr> shape =
MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
MutateArray(buffer->shape, [this](const PrimExpr& e) { return VisitExpr(e); });
Array<PrimExpr> strides =
MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
MutateArray(buffer->strides, [this](const PrimExpr& e) { return VisitExpr(e); });

PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_);
PrimExpr elem_offset = VisitExpr(buffer->elem_offset);

if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
buffer->strides.same_as(strides)) {
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_tir_specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty
B[vi, vj] = A[vi, vj]


# TVMScript fixture used by test_specialize_with_const_folding.
# The symbolic var `n` appears only inside arithmetic expressions:
#   - the shape of A is [n // 8, 8] and the shape of B is [n]
#   - the block iteration extent is n - 1
#   - the stored value adds (n + 1) * 42
@tvm.script.tir
def param_in_arith_exprs(a: ty.handle, b: ty.handle) -> None:
n = tir.var("int32")
A = tir.match_buffer(a, [n // 8, 8], "int32")
B = tir.match_buffer(b, [n], "int32")
with tir.block([n - 1], "") as [vi]:
B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42


# Expected form of param_in_arith_exprs once its `b` parameter is bound to a
# 16-element buffer (see test_specialize_with_const_folding). With n = 16 the
# arithmetic on n is written out as constants:
#   n // 8 -> 2, n -> 16, n - 1 -> 15, (n + 1) * 42 -> 714
# `n` is still declared below but is not referenced by any other statement here.
@tvm.script.tir
def param_in_arith_exprs_n_16(a: ty.handle, b: ty.handle) -> None:
n = tir.var("int32")
A = tir.match_buffer(a, [2, 8], "int32")
B = tir.match_buffer(b, [16], "int32")
with tir.block([15], "") as [vi]:
B[vi] = A[vi // 8, vi % 8] + 714


def test_specialize_nothing():
func = matmul.specialize({})
assert func.same_as(matmul) # Pointer the same
Expand Down Expand Up @@ -191,9 +209,16 @@ def test_specialize_recursive_load():
pass


def test_specialize_with_const_folding():
    """Specialize `param_in_arith_exprs` on its second parameter with a
    16-element buffer and check the result is structurally equal to
    `param_in_arith_exprs_n_16`."""
    second_param = param_in_arith_exprs.params[1]
    binding = {second_param: tir.decl_buffer([16])}
    specialized = param_in_arith_exprs.specialize(binding)
    tvm.ir.assert_structural_equal(specialized, param_in_arith_exprs_n_16)


if __name__ == "__main__":
    # Run every test in this module, in declaration order, when executed as a script.
    for run_case in (
        test_specialize_nothing,
        test_specialize_matmul,
        test_specialize_elemwise,
        test_specialize_mem_copy,
        test_specialize_recursive_load,
        test_specialize_with_const_folding,
    ):
        run_case()

0 comments on commit 6fd09ab

Please sign in to comment.