Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Fix intersect of modular set #2904

Merged
merged 1 commit into from
Mar 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 96 additions & 62 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ struct ModularSetAnalyzer::Entry {
int64_t coeff{1};
int64_t base{0};

Entry() = default;

Entry(int64_t coeff, int64_t base) {
CHECK_GE(coeff, 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For one case in test_arith_rewrite_simplify.py, coeff can be smaller than 0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this constructor is too "smart", we can create a static function.
see #2726 (comment)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, coeff should not be smaller than 0 (see #2726 (comment)). I have no idea about the test cases in test_arith_rewrite_simplify.py but I do pass these tests. Further more, if this case exists in unit test, probably we can try to fix it.

this->coeff = coeff;
if (coeff != 0) {
base = base % coeff;
if (base < 0) base += coeff;
}
this->base = base;
}

bool is_const() const {
return coeff == 0;
}
Expand All @@ -53,10 +65,7 @@ class ModularSetAnalyzer::Impl :
if (!override) {
CHECK(!var_map_.count(var));
}
Entry e;
e.coeff = info->coeff;
e.base = info->base;
var_map_[var] = e;
var_map_[var] = Entry(info->coeff, info->base);
}

// Detect useful constraints and use them in the analysis scope.
Expand All @@ -65,9 +74,7 @@ class ModularSetAnalyzer::Impl :
PVar<Integer> coeff, base;
// pattern match interesting constraints
if (((var % coeff) == base).Match(constraint)) {
Entry entry;
entry.coeff = coeff.Eval()->value;
entry.base = base.Eval()->value;
Entry entry(coeff.Eval()->value, base.Eval()->value);
return UpdateByIntersect(var.Eval(), entry);
}
return nullptr;
Expand All @@ -83,18 +90,12 @@ class ModularSetAnalyzer::Impl :
}

Entry VisitExpr_(const IntImm* op) final {
Entry ret;
ret.base = op->value;
ret.coeff = 0;
return ret;
return Entry(0, op->value);
}

Entry VisitExpr_(const UIntImm* op) final {
if (op->value < std::numeric_limits<int64_t>::max()) {
Entry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
return Entry(0, static_cast<int>(op->value));
} else {
return Everything();
}
Expand All @@ -103,19 +104,15 @@ class ModularSetAnalyzer::Impl :
Entry VisitExpr_(const Add* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base + b.base);
}

Entry VisitExpr_(const Sub* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
return Entry(coeff, a.base - b.base);
}

Entry VisitExpr_(const Mul* op) final {
Expand All @@ -128,10 +125,8 @@ class ModularSetAnalyzer::Impl :
int64_t pq = a.coeff * b.coeff;
int64_t pm = a.coeff * b.base;
int64_t qn = a.base * b.coeff;
Entry ret;
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;
int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
return Entry(coeff, a.base * b.base);
}

Entry DivByConst(const Expr& lhs,
Expand All @@ -140,20 +135,15 @@ class ModularSetAnalyzer::Impl :
Entry a = VisitExpr(lhs);
CHECK_NE(val, 0);
if (a.coeff % val == 0) {
Entry ret;
if (a.base == 0) {
// a c x / c -> a x
ret.coeff = std::abs(a.coeff / val);
ret.base = 0;
return ret;
return Entry(std::abs(a.coeff / val), 0);
}
// positive division have a clear rounding mode.
// Only handle case where we clearly know we need to round down.
if (a.base > 0 && val > 0 &&
(round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
ret.coeff = a.coeff / val;
ret.base = a.base / val;
return ret;
return Entry(a.coeff / val, a.base / val);
}
}
return Everything();
Expand Down Expand Up @@ -251,41 +241,80 @@ class ModularSetAnalyzer::Impl :
}
int64_t base0 = a.base % coeff;
int64_t base1 = b.base % coeff;
Entry ret;
if (base0 == base1) {
ret.coeff = coeff;
ret.base = base0;
return ret;
return Entry(coeff, base0);
} else {
ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff);
ret.base = 0;
return ret;
return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
}
}
/*!
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
* \param a The first coefficient.
* \param b The second coefficient.
* \param x The solution of x.
* \param y The solution of y.
* \return The GCD of a and b.
*/
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
// Extended Euclidean algorithm
// if a < 0, the problem can be convert into
// |a|* (-x) + b * y = gcd(|a|, b)
//
// initial condition:
// a * 0 + b * 1 = b
// a * 1 + b * 0 = a
int64_t s = 0, old_s = 1;
int64_t r = b, old_r = a >= 0 ? a : -a;
// Iteration (r2 < r1):
// a * x1 + b * y1 = r1
// a * x2 + b * y2 = r2
// The above two eqs can derive the following eq (q = r1 / r2)
// a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
// Because r3 < r2, the iteration can eventually terminate
while (r != 0) {
int64_t q = old_r / r;
int64_t tmp = old_r;
old_r = r;
r = tmp - q * r;
tmp = old_s;
old_s = s;
s = tmp - q * s;
}

*x = a >= 0 ? old_s : -old_s;
if (b != 0) {
*y = (old_r - (*x) * a) / b;
} else {
*y = 1;
}

return old_r;
}
/*!
* \brief Create interect of two sets.
* \param a The left operand.
* \param b the right operand.
*/
static Entry Intersect(Entry a, Entry b) {
// simple rule for now: pick higher constraints.
// TODO(team-team): Use extended euclidean algorithm.
if (a.coeff == 0) return a;
if (b.coeff == 0) return b;
if (a.coeff >= b.coeff) return a;
return b;
}
/*!
* \brief Simplify base so that it is in [0, coeff) when coeff != 0.
* \param base The base value.
* \param coeff The coeff value.
* \return The simplified base.
*/
static int64_t BaseSimplify(int64_t base, int64_t coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
int64_t x, y;
int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base;
// z = c1 * p + b1
// z = c2 * q + b2
// c1 * x + c2 * y = gcd(c1, c2)
// -> c1 * p - c2 * q = b2 - b1
// -> p = (b2 - b1) / gcd * x
// -> q = (b2 - b1) / gcd * (-y)
// -> z = LCM(x, y) * k + (c1 * p + b1)
int64_t gcd = ExtendedEuclidean(c1, c2, &x, &y);
int64_t v = b2 - b1;
if (v % gcd == 0) {
x = v / gcd * x;
y = v / gcd * (-y);
int64_t coeff = c1 / gcd * c2;
return Entry(coeff, x * c1 + b1);
} else {
return Nothing();
}
}
/*!
* \brief Take GCD of a and b.
Expand All @@ -311,9 +340,14 @@ class ModularSetAnalyzer::Impl :
* \return Bound that represent everything dtype can represent.
*/
static Entry Everything() {
Entry ret;
ret.coeff = 1; ret.base = 0;
return ret;
return Entry(1, 0);
}
/*!
* \brief return an empty set
* \return Bound that represent everything dtype can represent.
*/
static Entry Nothing() {
return Entry(0, 1);
}
};

Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_arith_modular_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ def test_constraint_scope():
assert m.coeff == 1
assert m.base == 0

def test_intersect():
a = tvm.var("a")
analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(a % 4 == 1):
with analyzer.constraint_scope(a % 3 == 1):
m = analyzer.modular_set(a)
assert m.coeff == 12
assert m.base == 1

with analyzer.constraint_scope(a % 3 == 2):
with analyzer.constraint_scope(a % 5 == 3):
with analyzer.constraint_scope(a % 7 == 2):
m = analyzer.modular_set(a)
assert m.coeff == 105
assert m.base == 23


if __name__ == "__main__":
test_cast()
Expand All @@ -126,3 +142,4 @@ def test_constraint_scope():
test_min_max_select()
test_mix_index()
test_constraint_scope()
test_intersect()