Skip to content

Commit

Permalink
[Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff (
Browse files Browse the repository at this point in the history
…apache#6078)

* [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Co-authored-by: Sergei Grechanik <[email protected]>

* fix lint

* fix clang-format

* add comments and magic number

* clang-lint

* address some comments

* remove FreeVarsVisitor

* fix constexpr lint

* fix lint

* fix lint

* add Map.Merge

* lint

* change Array::Concat & Map::Merge to global functions

* fix lint

* move functions to global

* static -> inline

Co-authored-by: Sergei Grechanik <[email protected]>
  • Loading branch information
2 people authored and Trevor Morris committed Aug 26, 2020
1 parent ccf9f6d commit cd3485f
Show file tree
Hide file tree
Showing 16 changed files with 1,767 additions and 59 deletions.
7 changes: 7 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class Analyzer;

using tir::Var;

enum DivMode {
/*! \brief Truncated division. */
kTruncDiv,
/*! \brief Floor division. */
kFloorDiv
};

/*!
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
Expand Down
24 changes: 24 additions & 0 deletions include/tvm/arith/int_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ using tir::IterVar;
using tir::Var;
using tir::VarNode;

// According to experiments two best simplifications orders were can->rw and rw->can->rw,
// but rw->can->rw is better for a couple of cases.
// Also we should end with rw because it factors multipliers out.
constexpr int kSimplifyRewriteCanonicalRewrite = 3;

/*!
* \brief Represent integer grouped bounds which are classified into
* lower bounds (inclusive), upper bounds (inclusive) and equalities.
Expand Down Expand Up @@ -251,6 +256,15 @@ class IntConstraintsTransform : public ObjectRef {
TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst,
Map<Var, PrimExpr> src_to_dst, Map<Var, PrimExpr> dst_to_src);

/*!
* \brief Chain-compose two IntConstraintsTransform together.
* this->dst must be the same as other->src.
* @param other another IntConstraintsTransform whose src is same as this->dst.
* @return composed IntConstraintsTransform(this->src, other->dst)
* with its variables and ranges are properly modified.
*/
IntConstraintsTransform operator+(const IntConstraintsTransform& other) const;

TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
};

Expand Down Expand Up @@ -306,6 +320,16 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol
*/
PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve);

/*!
* \brief Combine the information into an array of (in)equalities.
* \param variables The variables in \p bounds.
* It is used to determine the iteration order to avoid indeterministic results.
* \param bounds grouped boundary of the variables.
* \param relations other relations.
*/
Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
const Array<PrimExpr>& relations);

/*!
* \brief Solve linear inequalities and infer the range of each variable.
* \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
Expand Down
16 changes: 16 additions & 0 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,22 @@ class Map : public ObjectRef {
MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
};

/*!
* \brief Merge two Maps.
* \param lhs the first Map to merge.
* \param rhs the second Map to merge.
* @return The merged Array. Original Maps are kept unchanged.
*/
template <typename K, typename V,
typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
inline Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {
for (const auto& p : rhs) {
lhs.Set(p.first, p.second);
}
return std::move(lhs);
}

} // namespace tvm

namespace tvm {
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,21 @@ class Array : public ObjectRef {
}
};

/*!
* \brief Concat two Arrays.
* \param lhs first Array to be concatenated.
* \param rhs second Array to be concatenated.
* \return The concatenated Array. Original Arrays are kept unchanged.
*/
template <typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) {
for (const auto& x : rhs) {
lhs.push_back(x);
}
return std::move(lhs);
}

// Specialize make_object<ArrayNode> to make sure it is correct.
template <>
inline ObjectPtr<ArrayNode> make_object() {
Expand Down
9 changes: 8 additions & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,20 @@ struct ExprDeepEqual {
};

/*!
* \brief Find undefined vars in the statment.
* \brief Find undefined vars in the statement.
* \param stmt The function to be checked.
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
TVM_DLL Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
* \brief Find undefined vars in the expression.
* \param expr The expression to be checked.
* \return Array of undefined vars.
*/
TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);

/*!
* \brief Analyze the side effect
* \param expr The expression to be checked.
Expand Down
7 changes: 0 additions & 7 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,6 @@ class CanonicalExprNode : public PrimExprNode {
TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode);
};

enum DivMode {
/*! \brief Truncated division. */
kTruncDiv,
/*! \brief Floor division. */
kFloorDiv
};

inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
if (mode == kTruncDiv) {
return truncmod(a, b);
Expand Down
46 changes: 46 additions & 0 deletions src/arith/int_constraints.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,32 @@
namespace tvm {
namespace arith {

Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
const Array<PrimExpr>& relations) {
Array<PrimExpr> res;
// use variables to keep the order of iteration
// so as to get rid of any non-determinism.
CHECK_EQ(variables.size(), bounds.size());
for (const auto v : variables) {
CHECK(bounds.count(v));
const auto& bnds = bounds[v];
PrimExpr lhs = bnds->coef * v;
for (const PrimExpr& rhs : bnds->equal) {
res.push_back(tir::EQ(lhs, rhs));
}
for (const PrimExpr& rhs : bnds->lower) {
res.push_back(tir::GE(lhs, rhs));
}
for (const PrimExpr& rhs : bnds->upper) {
res.push_back(tir::LE(lhs, rhs));
}
}
for (const PrimExpr& e : relations) {
res.push_back(e);
}
return res;
}

IntGroupBounds::IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
Array<PrimExpr> upper) {
CHECK(coef.dtype().is_int() || coef.dtype().is_uint())
Expand Down Expand Up @@ -231,6 +257,26 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstrai
data_ = std::move(node);
}

IntConstraintsTransform IntConstraintsTransform::operator+(
const IntConstraintsTransform& other) const {
CHECK(other->src.same_as(operator->()->dst));
Map<Var, PrimExpr> dst_to_src;
Map<Var, PrimExpr> src_to_dst;

Analyzer ana_first;
ana_first.Bind(operator->()->src->ranges);
for (auto p : other->dst_to_src) {
dst_to_src.Set(p.first, ana_first.Simplify(Substitute(p.second, operator->()->dst_to_src)));
}

Analyzer ana_second;
ana_second.Bind(other->dst->ranges);
for (auto p : operator->()->src_to_dst) {
src_to_dst.Set(p.first, ana_second.Simplify(Substitute(p.second, other->src_to_dst)));
}
return IntConstraintsTransform(operator->()->src, other->dst, src_to_dst, dst_to_src);
}

TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);

TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform")
Expand Down
55 changes: 15 additions & 40 deletions src/arith/solve_linear_inequality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,35 +94,6 @@ struct ExprLess {
}
};

/*!
* \brief Combine the information into an array of (in)equalities.
*/
Array<PrimExpr> as_conditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
const Array<PrimExpr>& relations) {
Array<PrimExpr> res;
// use variables to keep the order of iteration
// so as to get rid of any non-determinism.
CHECK_EQ(variables.size(), bounds.size());
for (const auto v : variables) {
CHECK(bounds.count(v));
const auto& bnds = bounds[v];
PrimExpr lhs = bnds->coef * v;
for (const PrimExpr& rhs : bnds->equal) {
res.push_back(tir::EQ(lhs, rhs));
}
for (const PrimExpr& rhs : bnds->lower) {
res.push_back(tir::GE(lhs, rhs));
}
for (const PrimExpr& rhs : bnds->upper) {
res.push_back(tir::LE(lhs, rhs));
}
}
for (const PrimExpr& e : relations) {
res.push_back(e);
}
return res;
}

void DebugPrint(
const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set,
Expand Down Expand Up @@ -290,7 +261,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t

// Simplify each inequality into the form `expr <= 0` and add to current formulas
for (const PrimExpr& ineq : system_to_solve->relations) {
AddInequality(&current_ineq_set_to_solve, NormalizeComparisons()(analyzer.Simplify(ineq, 3)),
AddInequality(&current_ineq_set_to_solve,
NormalizeComparisons()(analyzer.Simplify(ineq, kSimplifyRewriteCanonicalRewrite)),
&analyzer);
}

Expand All @@ -307,8 +279,9 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// Add bounds from vranges
if (system_to_solve->ranges.count(v)) {
const Range& range = system_to_solve->ranges[v];
PrimExpr range_lbound = analyzer.Simplify(range->min, 3);
PrimExpr range_ubound = analyzer.Simplify(range->min + range->extent - 1, 3);
PrimExpr range_lbound = analyzer.Simplify(range->min, kSimplifyRewriteCanonicalRewrite);
PrimExpr range_ubound =
analyzer.Simplify(range->min + range->extent - 1, kSimplifyRewriteCanonicalRewrite);
coef_neg.push_back({-1, range_lbound});
coef_pos.push_back({1, -range_ubound});
}
Expand All @@ -329,7 +302,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// we need rewrite_simplify -> canonical_simplify -> rewrite_simplify
// to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0
// with steps = 2 it's (y*2) - 10 <= 0
new_ineq = NormalizeComparisons()(analyzer.Simplify(new_ineq, 3));
new_ineq =
NormalizeComparisons()(analyzer.Simplify(new_ineq, kSimplifyRewriteCanonicalRewrite));
AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer);
}
}
Expand All @@ -354,7 +328,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t

for (const auto& pos : coef_pos) {
PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second;
bound = analyzer.Simplify(bound, 3);
bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite);
// Don't add if any of the existing bounds is better
if (std::any_of(upper_bounds.begin(), upper_bounds.end(),
[&bound, &analyzer](const PrimExpr& o) {
Expand All @@ -375,7 +349,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
}
for (const auto& neg : coef_neg) {
PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second;
bound = analyzer.Simplify(bound, 3);
bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite);
// Don't add if any of the existing bounds is better
if (std::any_of(lower_bounds.begin(), lower_bounds.end(),
[&bound, &analyzer](const PrimExpr& o) {
Expand Down Expand Up @@ -414,7 +388,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t
// Everything that is left goes to res.relations
Array<PrimExpr> other_conditions;
for (const PrimExpr& e : current_ineq_set_to_solve) {
PrimExpr e_simp = analyzer.Simplify(e, 3);
PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite);
if (is_const_int(e_simp, 0)) {
// contradiction detected
other_conditions = {const_false()};
Expand Down Expand Up @@ -465,7 +439,8 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) {
// There is an equation of the form `v == expr`, so this variable can be completely removed.
// Note that we use the 0-th expression because they are ordered by complexity,
// so it must be the simplest one.
Range best_range(bnd->equal[0], analyzer.Simplify(bnd->equal[0] + 1, 3));
Range best_range(bnd->equal[0],
analyzer.Simplify(bnd->equal[0] + 1, kSimplifyRewriteCanonicalRewrite));
res_ranges.Set(var, best_range);
vranges.Set(var, best_range);
} else {
Expand All @@ -491,7 +466,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) {
arith::Analyzer analyzer;
analyzer.Bind(vranges);
for (const PrimExpr& old_cond :
as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) {
AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) {
if (!analyzer.CanProve(old_cond)) {
// those not represented in vranges (res_ranges)
res_relations.push_back(old_cond);
Expand Down Expand Up @@ -584,7 +559,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ

// Add the original conditions (with variables substituted) to the resulting conditions
for (const PrimExpr& old_cond :
as_conditions(inequalities->variables, solved_bounds, solved_other_relations)) {
AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) {
PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst));
if (!is_const_int(new_cond, 1)) {
// those not represented in vranges (res_ranges)
Expand Down Expand Up @@ -615,7 +590,7 @@ TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition")
LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets "
<< args.size();
}
*ret = as_conditions(problem->variables, ret_ineq.first, ret_ineq.second);
*ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second);
});

TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange").set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down
1 change: 0 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,6 @@ RELAY_REGISTER_OP("scatter_add")
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_support_level(10);

////

// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);
Expand Down
Loading

0 comments on commit cd3485f

Please sign in to comment.