Skip to content

Commit

Permalink
[ARITH] Move old modular analysis to Analyzer
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 25, 2019
1 parent 21de9c2 commit 140ee63
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 132 deletions.
39 changes: 19 additions & 20 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ class ConstIntBoundAnalyzer {
* \return the result of the analysis.
*/
ConstIntBound operator()(const Expr& expr);
/*! \brief reset and clear all internal states. */
void Reset();

/*!
* \brief Update constant int bound information of var.
*
Expand All @@ -87,6 +86,13 @@ class ConstIntBoundAnalyzer {
void Update(const Var& var,
const ConstIntBound& info,
bool override = false);
/*!
* \brief Bind variable to a range.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const Var& var, const Range& range);

private:
friend class Analyzer;
Expand Down Expand Up @@ -244,7 +250,17 @@ class Analyzer {
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const Var& var, const Expr& expr);
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we proof expr >= val.
Expand Down Expand Up @@ -513,23 +529,6 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);

// Temporary entry for modular
// TODO(tqchen) use Analyzer.
struct ModularEntry {
int64_t coeff{1};
int64_t base{0};
};

/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);

// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
Expand Down
7 changes: 6 additions & 1 deletion src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
self->Bind(args[0], args[1]);
auto& sptr = args[1].node_sptr();
if (sptr->is_type<Range::ContainerType>()) {
self->Bind(args[0], args[1].operator Range());
} else {
self->Bind(args[0], args[1].operator Expr());
}
});
} else if (name == "enter_constraint_context") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
Expand Down
9 changes: 8 additions & 1 deletion src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@ Analyzer::Analyzer()
modular_set(this) {
}

void Analyzer::Bind(const Var& var, const Expr& expr) {
void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Var var(v.node_);
this->const_int_bound.Update(var, this->const_int_bound(expr));
this->modular_set.Update(var, this->modular_set(expr));
}

void Analyzer::Bind(const VarExpr& v, const Range& range) {
Var var(v.node_);
this->const_int_bound.Bind(var, range);
// skip modular_set
}

ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) {
// entering the scope.
auto f0 = analyzer->const_int_bound.EnterConstraint(constraint);
Expand Down
23 changes: 21 additions & 2 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,28 @@ struct ConstIntBoundAnalyzer::Entry {
class ConstIntBoundAnalyzer::Impl :
public ExprFunctor<ConstIntBoundAnalyzer::Entry(const Expr&)> {
public:
void Bind(const Var& var, const Range& range) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Entry ret;
ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
Update(var, ret, false);
}

void Update(const Var& var,
const ConstIntBound& info,
const Entry& info,
bool override) {
if (!override) {
CHECK(!var_map_.count(var));
}
var_map_[var] = MakeBound(info->min_value, info->max_value);
var_map_[var] = info;
}

void Update(const Var& var,
const ConstIntBound& info,
bool override) {
Update(var, MakeBound(info->min_value, info->max_value), override);
}

// Override visitor behaviors
Expand Down Expand Up @@ -358,6 +373,10 @@ void ConstIntBoundAnalyzer::Update(const Var& var,
impl_->Update(var, info, override);
}

void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
impl_->Bind(var, range);
}

std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) {
return nullptr;
}
Expand Down
17 changes: 0 additions & 17 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,22 +340,5 @@ ModularSetAnalyzer::~ModularSetAnalyzer() {
delete impl_;
}


ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
Analyzer ana;
for (const auto& kv : mod_map) {
auto v = kv.second;
ana.modular_set.Update(
GetRef<Var>(kv.first), ModularSetNode::make(v.coeff, v.base));
}
auto mod = ana.modular_set(e);
ModularEntry ret;
ret.coeff = mod->coeff;
ret.base = mod->base;
return ret;
}

} // namespace arith
} // namespace tvm
59 changes: 0 additions & 59 deletions src/codegen/codegen_common.h

This file was deleted.

27 changes: 14 additions & 13 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <tvm/runtime/c_runtime_api.h>
#include "codegen_llvm.h"
#include "codegen_cpu.h"
#include "../codegen_common.h"
#include "../../pass/ir_util.h"
#include "../../arithmetic/compute_expr.h"

Expand Down Expand Up @@ -84,9 +83,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
void CodeGenLLVM::InitFuncState() {
var_map_.clear();
alias_var_set_.clear();
align_map_.clear();
alloc_storage_info_.clear();
volatile_buf_.clear();
analyzer_.reset(new arith::Analyzer());
}

void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
Expand Down Expand Up @@ -381,14 +380,16 @@ void CodeGenLLVM::GetAlignment(Type t,
*p_native_bits = native_vector_bits_;
}

arith::ModularEntry me = arith::EvalModular(index, align_map_);
arith::ModularSet me = analyzer_->modular_set(index);
int64_t base = me->base;
int64_t coeff = me->coeff;

int align_bits = t.bits();
while (align_bits < max_align_bits &&
me.base % 2 == 0 &&
me.coeff % 2 == 0) {
me.base = me.base / 2;
me.coeff = me.coeff / 2;
base % 2 == 0 &&
coeff % 2 == 0) {
base = base / 2;
coeff = coeff / 2;
align_bits *= 2;
}
if (align_bits < 8) {
Expand Down Expand Up @@ -874,7 +875,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
}

Expand Down Expand Up @@ -998,6 +999,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) {

void CodeGenLLVM::VisitStmt_(const For* op) {
CHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
if (op->for_type == ForType::Unrolled) {
LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
<< " consider set unroll_explicit=True";
Expand Down Expand Up @@ -1078,6 +1080,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) {
var_map_[iv->var.get()] = GetThreadIndex(iv);
analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
}
}
} else if (op->attr_key == ir::attr::storage_scope) {
Expand All @@ -1099,21 +1102,19 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
}

void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
VisitAssert(op, &align_map_, [this](const Stmt& body) {
this->VisitStmt(body);
});
arith::ConstraintContext cctx(analyzer_.get(), op->condition);
this->VisitStmt(op->body);
}

void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
CHECK(!var_map_.count(op->var.get()));
CHECK(!align_map_.count(op->var.get()));
if (op->var.type().is_handle()) {
if (!is_restricted_) {
alias_var_set_.insert(op->var.get());
}
}
var_map_[op->var.get()] = MakeValue(op->value);
align_map_[op->var.get()] = EvalModular(op->value, align_map_);
analyzer_->Bind(op->var, op->value);
this->VisitStmt(op->body);
}

Expand Down
5 changes: 2 additions & 3 deletions src/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ namespace codegen {

using namespace ir;


/*!
* \brief A base class to generate a LLVM.
*/
Expand Down Expand Up @@ -267,8 +266,8 @@ class CodeGenLLVM :
std::unordered_map<std::string, llvm::Constant*> str_map_;
// Whether current function is restricted
bool is_restricted_{true};
// The alignment information
std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
// The analyzer information
std::unique_ptr<arith::Analyzer> analyzer_;
// set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_;
// set of volatile buffer.
Expand Down
Loading

0 comments on commit 140ee63

Please sign in to comment.