From 6fd09ab6ab72ac1f880d39477a5a0ffb0e4b609d Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sun, 22 Aug 2021 06:10:18 +0800 Subject: [PATCH] [TIR] Support fold constants in specialize process (#8803) * support fold constants in specialize * replace Substitute() with VisitExpr() in specializer. --- src/tir/ir/specialize.cc | 49 ++++++++++++++++++-- tests/python/unittest/test_tir_specialize.py | 25 ++++++++++ 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index aa5f271c20c2d..768787735a1fe 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -45,6 +46,27 @@ inline bool IsParam(const PrimFunc& func, const Var& param) { /**************** Specializer ****************/ +// Try to fold constants if an op's children get specialized to constants. +#define DEFINE_SPECIALIZER_BINARY_OP_MUTATE(BinaryNode, BinaryFunc) \ + PrimExpr VisitExpr_(const BinaryNode* op) final { \ + PrimExpr a = VisitExpr(op->a); \ + PrimExpr b = VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return BinaryFunc(a, b); \ + } \ + } +#define DEFINE_SPECIALIZER_UNARY_OP_MUTATE(UnaryNode, UnaryFunc) \ + PrimExpr VisitExpr_(const UnaryNode* op) final { \ + PrimExpr a = VisitExpr(op->a); \ + if (a.same_as(op->a)) { \ + return GetRef(op); \ + } else { \ + return UnaryFunc(a); \ + } \ + } + /*! 
\brief Mutator to specialize function and remove const parameters */ class PrimFuncSpecializer : public StmtExprMutator { public: @@ -157,14 +179,33 @@ class PrimFuncSpecializer : public StmtExprMutator { } } + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AddNode, add); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(SubNode, sub); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MulNode, mul); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(DivNode, div); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(ModNode, truncmod); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorDivNode, floordiv); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorModNode, floormod); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MaxNode, max); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MinNode, min); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(EQNode, equal); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(NENode, not_equal); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LTNode, less); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LENode, less_equal); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GTNode, greater); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GENode, greater_equal); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AndNode, logical_and); + DEFINE_SPECIALIZER_BINARY_OP_MUTATE(OrNode, logical_or); + DEFINE_SPECIALIZER_UNARY_OP_MUTATE(NotNode, logical_not); + private: - Buffer MutateBuffer(const Buffer& buffer) const { + Buffer MutateBuffer(const Buffer& buffer) { Array shape = - MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); }); + MutateArray(buffer->shape, [this](const PrimExpr& e) { return VisitExpr(e); }); Array strides = - MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); }); + MutateArray(buffer->strides, [this](const PrimExpr& e) { return VisitExpr(e); }); - PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_); + PrimExpr elem_offset = VisitExpr(buffer->elem_offset); if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) && buffer->strides.same_as(strides)) { diff --git 
a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py index 2e9f1110732a2..d6cfadaf1fbcd 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -145,6 +145,24 @@ def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty B[vi, vj] = A[vi, vj] +@tvm.script.tir +def param_in_arith_exprs(a: ty.handle, b: ty.handle) -> None: + n = tir.var("int32") + A = tir.match_buffer(a, [n // 8, 8], "int32") + B = tir.match_buffer(b, [n], "int32") + with tir.block([n - 1], "") as [vi]: + B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 + + +@tvm.script.tir +def param_in_arith_exprs_n_16(a: ty.handle, b: ty.handle) -> None: + n = tir.var("int32") + A = tir.match_buffer(a, [2, 8], "int32") + B = tir.match_buffer(b, [16], "int32") + with tir.block([15], "") as [vi]: + B[vi] = A[vi // 8, vi % 8] + 714 + + def test_specialize_nothing(): func = matmul.specialize({}) assert func.same_as(matmul) # Pointer the same @@ -191,9 +209,16 @@ def test_specialize_recursive_load(): pass +def test_specialize_with_const_folding(): + b = param_in_arith_exprs.params[1] + func = param_in_arith_exprs.specialize({b: tir.decl_buffer([16])}) + tvm.ir.assert_structural_equal(func, param_in_arith_exprs_n_16) + + if __name__ == "__main__": test_specialize_nothing() test_specialize_matmul() test_specialize_elemwise() test_specialize_mem_copy() test_specialize_recursive_load() + test_specialize_with_const_folding()