Skip to content

Commit

Permalink
[ARITH] Modular Analysis to check if index can be divided by certain …
Browse files Browse the repository at this point in the history
…value.
  • Loading branch information
tqchen committed Feb 28, 2017
1 parent e438794 commit 898c997
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ._ctypes._node import NodeBase, register_node
from . import _api_internal

@register_node
class IntSet(NodeBase):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
Expand Down Expand Up @@ -33,3 +32,8 @@ def max(self):
class StrideSet(IntSet):
"""Represent set of strided integers"""
pass

@register_node
class ModularSet(IntSet):
"""Represent range of (coeff * x + base) for x \in Z """
pass
6 changes: 6 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include "../arithmetic/int_set.h"
#include "../arithmetic/modular.h"

namespace tvm {
namespace arith {
Expand All @@ -21,6 +22,11 @@ TVM_REGISTER_API(_arith_intset_interval)
*ret = IntSet::interval(args[0], args[1]);
});

TVM_REGISTER_API(_arith_EvalModular)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EvalModular(args[0], Map<Var, IntSet>());
});

TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
Expand Down
18 changes: 18 additions & 0 deletions src/arithmetic/int_set_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./int_set.h"
#include "./modular.h"

namespace tvm {
namespace arith {
Expand Down Expand Up @@ -54,6 +55,23 @@ struct StrideSet : public IntSetNode {
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
};

/*!
* \brief Set represented by range of ModularEntry.
* Used for front-end modular analysis.
*/
struct ModularSet : public IntSetNode {
/*! \brief Internal modular entry */
ModularEntry e;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("base", &(e.base));
v->Visit("coeff", &(e.coeff));
}
static constexpr const char* _type_key = "ModularSet";
TVM_DECLARE_NODE_TYPE_INFO(ModularSet, IntSetNode);
};


} // namespace arith
} // namespace tvm

Expand Down
150 changes: 150 additions & 0 deletions src/arithmetic/modular.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*!
* Copyright (c) 2017 by Contributors
* \file modular.cc
* \brief Modular analysis
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <limits>
#include "./modular.h"
#include "./int_set_internal.h"

namespace tvm {
namespace arith {

// greatest common divisor
using namespace ir;

class ModularEvaluator : public IRVisitor {
public:
explicit ModularEvaluator(
const std::unordered_map<
const Variable*, ModularEntry>& mod_map)
: mod_map_(mod_map) {
}
// evaluation.
ModularEntry Eval(const Expr& e) {
// always safe to set 0 + x, so it can be everything.
ret_.base = 0;
ret_.coeff = 1;
this->Visit(e);
return ret_;
}
// override combination rules.
void Visit_(const IntImm* op) final {
if (op->value < std::numeric_limits<int>::max()) {
ret_.coeff = 0;
ret_.base = static_cast<int>(op->value);
}
}
void Visit_(const UIntImm* op) final {
if (op->value < static_cast<uint64_t>(
std::numeric_limits<int>::max())) {
ret_.coeff = 0;
ret_.base = static_cast<int>(op->value);
}
}
void Visit_(const Cast* op) final {
// simply use everything.
return;
}
void Visit_(const Variable* op) final {
auto it = mod_map_.find(op);
if (it != mod_map_.end()) {
ret_ = it->second;
}
}
void Visit_(const Add* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
ret_.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret_.base = BaseSimplify(a.base + b.base, ret_.coeff);
}
void Visit_(const Sub* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
ret_.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret_.base = BaseSimplify(a.base - b.base, ret_.coeff);
}
void Visit_(const Mul* op) final {
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
// Simplification rule, x, y, z are in Z
// (p x + n) (q y + m)
// -> pq xy + pm x + qn y + mn
// -> pq z + pm x + qn y + mn
int pq = a.coeff * b.coeff;
int pm = a.coeff * b.base;
int qn = a.base * b.coeff;
ret_.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret_.base = BaseSimplify(a.base * b.base, ret_.coeff);
}
void Visit_(const Div* op) final {
// a c x / c -> a x
// We cannot do cases where offset is non-zero
// because of different integer rounding in pos/neg
ModularEntry a = Eval(op->a);
ModularEntry b = Eval(op->b);
if (b.coeff == 0 &&
a.base == 0) {
CHECK_NE(b.base, 0);
if (a.coeff % b.base == 0) {
ret_.coeff = a.coeff / b.base;
ret_.base = 0;
return;
}
}
// default case
ret_.coeff = 1;
ret_.base = 0;
}

private:
// return value
ModularEntry ret_;
const std::unordered_map<
const Variable*, ModularEntry>& mod_map_;

// simplify the base by putting it in range.
static int BaseSimplify(int base, int coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
}
static int ZeroAwareGCD(int a, int b) {
CHECK_GE(a, 0);
CHECK_GE(b, 0);
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD
// ax + by = gcd(a, b) z if a != 0, b != 0
while (a % b != 0) {
a = a % b;
std::swap(a, b);
}
return b;
}
};

ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
return ModularEvaluator(mod_map).Eval(e);
}

IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map) {
std::unordered_map<const Variable*, ModularEntry> mmap;
for (auto& kv : mod_map) {
const ModularSet* m = kv.second.as<ModularSet>();
CHECK(m) << "Need to pass ModularSet for Modular Analysis";
mmap[kv.first.get()] = m->e;
}
std::shared_ptr<ModularSet> n = std::make_shared<ModularSet>();
n->e = ModularEvaluator(mmap).Eval(e);
return IntSet(n);
}

} // namespace arith
} // namespace tvm
50 changes: 50 additions & 0 deletions src/arithmetic/modular.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*!
* Copyright (c) 2017 by Contributors
* \file modular.h
* \brief Modular integer set analysis
*/
#ifndef TVM_ARITHMETIC_MODULAR_H_
#define TVM_ARITHMETIC_MODULAR_H_

#include <tvm/expr.h>
#include "./int_set.h"

namespace tvm {
namespace arith {

/*!
* \brief Range of a linear integer function.
* Use to do specify the possible index values.
*
* set = { base + coeff * x | x \in Z }
*
* When coeff != 0, it can also be written as
* set = { n | n % coeff == base }
*/
struct ModularEntry {
/*! \brief The base */
int base;
/*! \brief linear co-efficient */
int coeff;
};

/*!
* \brief Evaluate the expression with modular analysis
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return The ModularEntry covering all possible value of e.
*/
ModularEntry EvalModular(
const Expr& e,
const std::unordered_map<const Variable*, ModularEntry>& mod_map);
/*!
* \brief Same as EvalModular, used by front-end.
* \param e The expression to be evaluated.
* \param mod_map Map of modular statistics of known variables.
* \return A ModularSet covering all possible value of e.
*/
IntSet EvalModular(const Expr& e,
const Map<Var, IntSet>& mod_map);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITHMETIC_MODULAR_H_
27 changes: 27 additions & 0 deletions tests/python/unittest/test_arith_modular.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import tvm

def test_basic():
a = tvm.Var()
b = tvm.Var()
m = tvm.arith.EvalModular(a * 4 + b * 6 + 7)
assert m.coeff == 2
assert m.base == 1

m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 + 3))
assert m.coeff == 4
assert m.base == 3

m = tvm.arith.EvalModular((a * 4 + 1) / (b * 8 + 3))
assert m.coeff == 1
assert m.base == 0

m = tvm.arith.EvalModular((a * 4 + 1) * (b * 8 / 4))
assert m.coeff == 2
assert m.base == 0

m = tvm.arith.EvalModular((a * 12 + 1) - (b * 3 * 7 + 2))
assert m.coeff == 3
assert m.base == 2

if __name__ == "__main__":
test_basic()

0 comments on commit 898c997

Please sign in to comment.