-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ARITH] Modular Analysis to check if index can be divided by certain …
…value.
- Loading branch information
Showing
6 changed files
with
256 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file modular.cc | ||
* \brief Modular analysis | ||
*/ | ||
#include <tvm/ir.h> | ||
#include <tvm/ir_visitor.h> | ||
#include <limits> | ||
#include "./modular.h" | ||
#include "./int_set_internal.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
// greatest common divisor | ||
using namespace ir; | ||
|
||
class ModularEvaluator : public IRVisitor { | ||
public: | ||
explicit ModularEvaluator( | ||
const std::unordered_map< | ||
const Variable*, ModularEntry>& mod_map) | ||
: mod_map_(mod_map) { | ||
} | ||
// evaluation. | ||
ModularEntry Eval(const Expr& e) { | ||
// always safe to set 0 + x, so it can be everything. | ||
ret_.base = 0; | ||
ret_.coeff = 1; | ||
this->Visit(e); | ||
return ret_; | ||
} | ||
// override combination rules. | ||
void Visit_(const IntImm* op) final { | ||
if (op->value < std::numeric_limits<int>::max()) { | ||
ret_.coeff = 0; | ||
ret_.base = static_cast<int>(op->value); | ||
} | ||
} | ||
void Visit_(const UIntImm* op) final { | ||
if (op->value < static_cast<uint64_t>( | ||
std::numeric_limits<int>::max())) { | ||
ret_.coeff = 0; | ||
ret_.base = static_cast<int>(op->value); | ||
} | ||
} | ||
void Visit_(const Cast* op) final { | ||
// simply use everything. | ||
return; | ||
} | ||
void Visit_(const Variable* op) final { | ||
auto it = mod_map_.find(op); | ||
if (it != mod_map_.end()) { | ||
ret_ = it->second; | ||
} | ||
} | ||
void Visit_(const Add* op) final { | ||
ModularEntry a = Eval(op->a); | ||
ModularEntry b = Eval(op->b); | ||
ret_.coeff = ZeroAwareGCD(a.coeff, b.coeff); | ||
ret_.base = BaseSimplify(a.base + b.base, ret_.coeff); | ||
} | ||
void Visit_(const Sub* op) final { | ||
ModularEntry a = Eval(op->a); | ||
ModularEntry b = Eval(op->b); | ||
ret_.coeff = ZeroAwareGCD(a.coeff, b.coeff); | ||
ret_.base = BaseSimplify(a.base - b.base, ret_.coeff); | ||
} | ||
void Visit_(const Mul* op) final { | ||
ModularEntry a = Eval(op->a); | ||
ModularEntry b = Eval(op->b); | ||
// Simplification rule, x, y, z are in Z | ||
// (p x + n) (q y + m) | ||
// -> pq xy + pm x + qn y + mn | ||
// -> pq z + pm x + qn y + mn | ||
int pq = a.coeff * b.coeff; | ||
int pm = a.coeff * b.base; | ||
int qn = a.base * b.coeff; | ||
ret_.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); | ||
ret_.base = BaseSimplify(a.base * b.base, ret_.coeff); | ||
} | ||
void Visit_(const Div* op) final { | ||
// a c x / c -> a x | ||
// We cannot do cases where offset is non-zero | ||
// because of different integer rounding in pos/neg | ||
ModularEntry a = Eval(op->a); | ||
ModularEntry b = Eval(op->b); | ||
if (b.coeff == 0 && | ||
a.base == 0) { | ||
CHECK_NE(b.base, 0); | ||
if (a.coeff % b.base == 0) { | ||
ret_.coeff = a.coeff / b.base; | ||
ret_.base = 0; | ||
return; | ||
} | ||
} | ||
// default case | ||
ret_.coeff = 1; | ||
ret_.base = 0; | ||
} | ||
|
||
private: | ||
// return value | ||
ModularEntry ret_; | ||
const std::unordered_map< | ||
const Variable*, ModularEntry>& mod_map_; | ||
|
||
// simplify the base by putting it in range. | ||
static int BaseSimplify(int base, int coeff) { | ||
if (coeff == 0) return base; | ||
base = base % coeff; | ||
if (base < 0) base += coeff; | ||
return base; | ||
} | ||
static int ZeroAwareGCD(int a, int b) { | ||
CHECK_GE(a, 0); | ||
CHECK_GE(b, 0); | ||
if (a < b) std::swap(a, b); | ||
if (b == 0) return a; | ||
// perform GCD | ||
// ax + by = gcd(a, b) z if a != 0, b != 0 | ||
while (a % b != 0) { | ||
a = a % b; | ||
std::swap(a, b); | ||
} | ||
return b; | ||
} | ||
}; | ||
|
||
ModularEntry EvalModular( | ||
const Expr& e, | ||
const std::unordered_map<const Variable*, ModularEntry>& mod_map) { | ||
return ModularEvaluator(mod_map).Eval(e); | ||
} | ||
|
||
IntSet EvalModular(const Expr& e, | ||
const Map<Var, IntSet>& mod_map) { | ||
std::unordered_map<const Variable*, ModularEntry> mmap; | ||
for (auto& kv : mod_map) { | ||
const ModularSet* m = kv.second.as<ModularSet>(); | ||
CHECK(m) << "Need to pass ModularSet for Modular Analysis"; | ||
mmap[kv.first.get()] = m->e; | ||
} | ||
std::shared_ptr<ModularSet> n = std::make_shared<ModularSet>(); | ||
n->e = ModularEvaluator(mmap).Eval(e); | ||
return IntSet(n); | ||
} | ||
|
||
} // namespace arith | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file modular.h | ||
* \brief Modular integer set analysis | ||
*/ | ||
#ifndef TVM_ARITHMETIC_MODULAR_H_ | ||
#define TVM_ARITHMETIC_MODULAR_H_ | ||
|
||
#include <tvm/expr.h> | ||
#include "./int_set.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
/*! | ||
* \brief Range of a linear integer function. | ||
* Use to do specify the possible index values. | ||
* | ||
* set = { base + coeff * x | x \in Z } | ||
* | ||
* When coeff != 0, it can also be written as | ||
* set = { n | n % coeff == base } | ||
*/ | ||
struct ModularEntry { | ||
/*! \brief The base */ | ||
int base; | ||
/*! \brief linear co-efficient */ | ||
int coeff; | ||
}; | ||
|
||
/*! | ||
* \brief Evaluate the expression with modular analysis | ||
* \param e The expression to be evaluated. | ||
* \param mod_map Map of modular statistics of known variables. | ||
* \return The ModularEntry covering all possible value of e. | ||
*/ | ||
ModularEntry EvalModular( | ||
const Expr& e, | ||
const std::unordered_map<const Variable*, ModularEntry>& mod_map); | ||
/*! | ||
* \brief Same as EvalModular, used by front-end. | ||
* \param e The expression to be evaluated. | ||
* \param mod_map Map of modular statistics of known variables. | ||
* \return A ModularSet covering all possible value of e. | ||
*/ | ||
IntSet EvalModular(const Expr& e, | ||
const Map<Var, IntSet>& mod_map); | ||
} // namespace arith | ||
} // namespace tvm | ||
#endif // TVM_ARITHMETIC_MODULAR_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import tvm | ||
|
||
def test_basic(): | ||
a = tvm.Var() | ||
b = tvm.Var() | ||
m = tvm.arith.EvalModular(a * 4 + b * 6 + 7) | ||
assert m.coeff == 2 | ||
assert m.base == 1 | ||
|
||
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 + 3)) | ||
assert m.coeff == 4 | ||
assert m.base == 3 | ||
|
||
m = tvm.arith.EvalModular((a * 4 + 1) / (b * 8 + 3)) | ||
assert m.coeff == 1 | ||
assert m.base == 0 | ||
|
||
m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 / 4)) | ||
assert m.coeff == 2 | ||
assert m.base == 0 | ||
|
||
m = tvm.arith.EvalModular((a * 12 + 1) - (b * 3 * 7 + 2)) | ||
assert m.coeff == 3 | ||
assert m.base == 2 | ||
|
||
if __name__ == "__main__": | ||
test_basic() |