Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Analyzer Infra, ConstIntBound, Modular #2668

Merged
merged 5 commits into from
Mar 2, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
323 changes: 267 additions & 56 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,279 @@
#include <vector>
#include <unordered_map>
#include <memory>
#include <limits>
#include "expr.h"

namespace tvm {

// forward delcare Tensor
class Tensor;

/*! \brief namespace of arithmetic */
namespace arith {
//-------------------------------------------------------
// Base integer analysis API.
//
// We have multiple type of analyzers to do relaxed
// integer set analysis(bound analysis, modulo) and
// equivalence checking and simplification.
//
// Importantly, each analyzer may need result from
// another analyzer.
//-------------------------------------------------------

// Forward declare Analyzer
class Analyzer;
/*!
* \brief reference class to ConstIntBoundNode
* \sa ConstIntBoundNode
*/
class ConstIntBound;
/*!
* \brief Constant integer up and lower bound(inclusive).
* Useful for value bound analysis.
*
* set = [min_value, max_value]
*/
class ConstIntBoundNode : public Node {
public:
int64_t min_value;
int64_t max_value;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("min_value", &min_value);
v->Visit("max_value", &max_value);
}

TVM_DLL static ConstIntBound make(int64_t min_value, int64_t max_value);

/*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*! \brief Number to represent -inf */
static const constexpr int64_t kNegInf = -kPosInf;

This comment was marked as resolved.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to make sure -kPosInf == kNegInf


static constexpr const char* _type_key = "arith.ConstIntBound";
TVM_DECLARE_NODE_TYPE_INFO(ConstIntBoundNode, Node);
};

TVM_DEFINE_NODE_REF(ConstIntBound, ConstIntBoundNode);

/*!
* \brief Analyzer to get constant integer bound over expression.
*/
class ConstIntBoundAnalyzer {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ConstIntBound operator()(const Expr& expr);

/*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const ConstIntBound& info,
bool override = false);
/*!
* \brief Bind variable to a range.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const Var& var, const Range& range);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit ConstIntBoundAnalyzer(Analyzer* parent);
~ConstIntBoundAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const Expr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief reference of ModularSetNode
* \sa ModularSetNode
*/
class ModularSet;
/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { coeff * x + base | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
class ModularSetNode : public Node {
public:
/*! \brief linear co-efficient */
int64_t coeff;
/*! \brief The base */
int64_t base;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("coeff", &coeff);
v->Visit("base", &base);
}

TVM_DLL static ModularSet make(int64_t coeff, int64_t base);

static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSetNode, Node);
};

TVM_DEFINE_NODE_REF(ModularSet, ModularSetNode);

/*!
* \brief Analyzer to get modular information over expression.
*/
class ModularSetAnalyzer {
public:
/*!
* \brief analyze the expr
* \param expr The expression of interest.
* \return the result of the analysis.
*/
ModularSet operator()(const Expr& expr);
/*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
const ModularSet& info,
bool override = false);

private:
friend class Analyzer;
friend class ConstraintContext;
explicit ModularSetAnalyzer(Analyzer* parent);
~ModularSetAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const Expr& constraint);
struct Entry;
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief A RAII constraint context.
*
* \code
*
* Var("x");
* arith::Analyzer analyzer;
* {
* arith::ConstraintContext cctx(&analyzer, x % 3 == 0);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* }
* // constraint no longer in effect.
* CHECK_NE(analyzer.modular_set(x)->coeff, 3);
*
* \endcode
*/
class ConstraintContext {
public:
/*!
* \brief Construct a constraint context.
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, const Expr& constraint) DMLC_THROW_EXCEPTION;
/*! \brief destructor */
~ConstraintContext() DMLC_THROW_EXCEPTION {
exit_();
}

private:
/*! \brief function to be called in recovery */
std::function<void()> exit_;
};

/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we proof expr >= val.

* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};

//-----------------------------------------------
// Integer set abstraction API.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
* \brief Sign of an expression or set.
*/
Expand Down Expand Up @@ -118,42 +383,6 @@ class IntSet : public NodeRef {
static IntSet interval(Expr min, Expr max);
};

/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { coeff * x + base | x in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*
* This is useful to decide if the index is dividable by certain value.
* For example, if index = 0 + 4 x, then we know it can be divided by 4.
*/
struct ModularEntry {
/*! \brief linear co-efficient */
int coeff{1};
/*! \brief The base */
int base{0};

/*! \return entry represent everything */
static ModularEntry everything() {
// always safe to set 0 + x, so it can be everything.
ModularEntry e;
e.coeff = 1;
e.base = 0;
return e;
}
/*!
* \brief Add two modular entries together to get a new modular entry.
* \param a The left operand.
* \param b The right operand.
* \return The combined modular entry.
*/
static ModularEntry Add(const ModularEntry& a,
const ModularEntry& b);
};

/*!
* \brief Base class of all IntSet containers.
*/
Expand Down Expand Up @@ -300,24 +529,6 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);

/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);

/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
namespace tvm {
namespace ir {

using HalideIR::Internal::BaseExprNode;
using HalideIR::Internal::ExprNode;
using HalideIR::Internal::StmtNode;
using HalideIR::Internal::IRNodeType;
Expand Down
Loading