Skip to content

Commit

Permalink
Rework the simplifier to use ConstantInterval for bounds (#8222)
Browse files Browse the repository at this point in the history
* Update the simplifier to use ConstantInterval

and track the bounds through more types

* Move the simplify fuzzer back to a correctness test

* Make debug_indent not static

Otherwise it causes a race condition in any parallel tests

* Track expr info on non-overflowing casts to int

* Delete commented-out code

* clang-tidy

* Delete unused member

* Fix cmakelists for the fuzzer removal

* Handle contradictions more gracefully in learn_true

The contradiction was arising from:

if (extent > 0) {
...
} else {
  for (x = 0; x < extent; x++) {
In here we can assume extent > 0, but we also know from the if
statement that extent <= 0
  }
}

* Better comments

* Address review comments

* Fix failure to pop loop var info
  • Loading branch information
abadams authored Jun 2, 2024
1 parent 35143d2 commit a9b8fbf
Show file tree
Hide file tree
Showing 28 changed files with 679 additions and 851 deletions.
70 changes: 33 additions & 37 deletions src/Simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,29 @@ using std::pair;
using std::string;
using std::vector;

#if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS)
int Simplify::debug_indent = 0;
#endif

Simplify::Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai)
: remove_dead_code(r) {

// Only respect the constant bounds from the containing scope.
for (auto iter = bi->cbegin(); iter != bi->cend(); ++iter) {
ExprInfo bounds;
ExprInfo info;
if (const int64_t *i_min = as_const_int(iter.value().min)) {
bounds.min_defined = true;
bounds.min = *i_min;
info.bounds.min_defined = true;
info.bounds.min = *i_min;
}
if (const int64_t *i_max = as_const_int(iter.value().max)) {
bounds.max_defined = true;
bounds.max = *i_max;
info.bounds.max_defined = true;
info.bounds.max = *i_max;
}

if (const auto *a = ai->find(iter.name())) {
bounds.alignment = *a;
info.alignment = *a;
}

if (bounds.min_defined || bounds.max_defined || bounds.alignment.modulus != 1) {
bounds_and_alignment_info.push(iter.name(), bounds);
if (info.bounds.min_defined ||
info.bounds.max_defined ||
info.alignment.modulus != 1) {
bounds_and_alignment_info.push(iter.name(), info);
}
}

Expand All @@ -48,20 +46,20 @@ Simplify::Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemaind
// Already handled
continue;
}
ExprInfo bounds;
bounds.alignment = iter.value();
bounds_and_alignment_info.push(iter.name(), bounds);
ExprInfo info;
info.alignment = iter.value();
bounds_and_alignment_info.push(iter.name(), info);
}
}

std::pair<std::vector<Expr>, bool> Simplify::mutate_with_changes(const std::vector<Expr> &old_exprs, ExprInfo *bounds) {
std::pair<std::vector<Expr>, bool> Simplify::mutate_with_changes(const std::vector<Expr> &old_exprs) {
vector<Expr> new_exprs(old_exprs.size());
bool changed = false;

// Mutate the args
for (size_t i = 0; i < old_exprs.size(); i++) {
const Expr &old_e = old_exprs[i];
Expr new_e = mutate(old_e, bounds);
Expr new_e = mutate(old_e, nullptr);
if (!new_e.same_as(old_e)) {
changed = true;
}
Expand Down Expand Up @@ -135,35 +133,35 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) {
Simplify::ExprInfo i;
if (v) {
simplify->mutate(lt->b, &i);
if (i.min_defined) {
if (i.bounds.min_defined) {
// !(v < i)
learn_lower_bound(v, i.min);
learn_lower_bound(v, i.bounds.min);
}
}
v = lt->b.as<Variable>();
if (v) {
simplify->mutate(lt->a, &i);
if (i.max_defined) {
if (i.bounds.max_defined) {
// !(i < v)
learn_upper_bound(v, i.max);
learn_upper_bound(v, i.bounds.max);
}
}
} else if (const LE *le = fact.as<LE>()) {
const Variable *v = le->a.as<Variable>();
Simplify::ExprInfo i;
if (v && v->type.is_int() && v->type.bits() >= 32) {
simplify->mutate(le->b, &i);
if (i.min_defined) {
if (i.bounds.min_defined) {
// !(v <= i)
learn_lower_bound(v, i.min + 1);
learn_lower_bound(v, i.bounds.min + 1);
}
}
v = le->b.as<Variable>();
if (v && v->type.is_int() && v->type.bits() >= 32) {
simplify->mutate(le->a, &i);
if (i.max_defined) {
if (i.bounds.max_defined) {
// !(i <= v)
learn_upper_bound(v, i.max - 1);
learn_upper_bound(v, i.bounds.max - 1);
}
}
} else if (const Call *c = Call::as_tag(fact)) {
Expand All @@ -185,8 +183,7 @@ void Simplify::ScopedFact::learn_false(const Expr &fact) {

void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) {
ExprInfo b;
b.max_defined = true;
b.max = val;
b.bounds = ConstantInterval::bounded_above(val);
if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) {
b.intersect(*info);
}
Expand All @@ -196,8 +193,7 @@ void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) {

void Simplify::ScopedFact::learn_lower_bound(const Variable *v, int64_t val) {
ExprInfo b;
b.min_defined = true;
b.min = val;
b.bounds = ConstantInterval::bounded_below(val);
if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) {
b.intersect(*info);
}
Expand Down Expand Up @@ -267,35 +263,35 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) {
Simplify::ExprInfo i;
if (v && v->type.is_int() && v->type.bits() >= 32) {
simplify->mutate(lt->b, &i);
if (i.max_defined) {
if (i.bounds.max_defined) {
// v < i
learn_upper_bound(v, i.max - 1);
learn_upper_bound(v, i.bounds.max - 1);
}
}
v = lt->b.as<Variable>();
if (v && v->type.is_int() && v->type.bits() >= 32) {
simplify->mutate(lt->a, &i);
if (i.min_defined) {
if (i.bounds.min_defined) {
// i < v
learn_lower_bound(v, i.min + 1);
learn_lower_bound(v, i.bounds.min + 1);
}
}
} else if (const LE *le = fact.as<LE>()) {
const Variable *v = le->a.as<Variable>();
Simplify::ExprInfo i;
if (v) {
simplify->mutate(le->b, &i);
if (i.max_defined) {
if (i.bounds.max_defined) {
// v <= i
learn_upper_bound(v, i.max);
learn_upper_bound(v, i.bounds.max);
}
}
v = le->b.as<Variable>();
if (v) {
simplify->mutate(le->a, &i);
if (i.min_defined) {
if (i.bounds.min_defined) {
// i <= v
learn_lower_bound(v, i.min);
learn_lower_bound(v, i.bounds.min);
}
}
} else if (const Call *c = Call::as_tag(fact)) {
Expand Down
6 changes: 4 additions & 2 deletions src/Simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ namespace Internal {
* Exprs that should be assumed to be true.
*/
// @{
Stmt simplify(const Stmt &, bool remove_dead_code = true,
Stmt simplify(const Stmt &,
bool remove_dead_code = true,
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
const std::vector<Expr> &assumptions = std::vector<Expr>());
Expr simplify(const Expr &, bool remove_dead_code = true,
Expr simplify(const Expr &,
bool remove_dead_code = true,
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
const std::vector<Expr> &assumptions = std::vector<Expr>());
Expand Down
28 changes: 12 additions & 16 deletions src/Simplify_Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,24 @@
namespace Halide {
namespace Internal {

Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
ExprInfo a_bounds, b_bounds;
Expr a = mutate(op->a, &a_bounds);
Expr b = mutate(op->b, &b_bounds);

if (bounds && no_overflow_int(op->type)) {
bounds->min_defined = a_bounds.min_defined &&
b_bounds.min_defined &&
add_with_overflow(64, a_bounds.min, b_bounds.min, &(bounds->min));
bounds->max_defined = a_bounds.max_defined &&
b_bounds.max_defined &&
add_with_overflow(64, a_bounds.max, b_bounds.max, &(bounds->max));
bounds->alignment = a_bounds.alignment + b_bounds.alignment;
bounds->trim_bounds_using_alignment();
Expr Simplify::visit(const Add *op, ExprInfo *info) {
ExprInfo a_info, b_info;
Expr a = mutate(op->a, &a_info);
Expr b = mutate(op->b, &b_info);

if (info) {
info->bounds = a_info.bounds + b_info.bounds;
info->alignment = a_info.alignment + b_info.alignment;
info->trim_bounds_using_alignment();
info->cast_to(op->type);
}

if (may_simplify(op->type)) {

// Order commutative operations by node type
if (should_commute(a, b)) {
std::swap(a, b);
std::swap(a_bounds, b_bounds);
std::swap(a_info, b_info);
}

auto rewrite = IRMatcher::rewriter(IRMatcher::add(a, b), op->type);
Expand Down Expand Up @@ -194,7 +190,7 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
rewrite(x + (y + (c0 - x)/c1)*c1, y * c1 - ((c0 - x) % c1) + c0, c1 > 0) ||

false)))) {
return mutate(rewrite.result, bounds);
return mutate(rewrite.result, info);
}
// clang-format on
}
Expand Down
4 changes: 2 additions & 2 deletions src/Simplify_And.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace Halide {
namespace Internal {

Expr Simplify::visit(const And *op, ExprInfo *bounds) {
Expr Simplify::visit(const And *op, ExprInfo *info) {
if (falsehoods.count(op)) {
return const_false(op->type.lanes());
}
Expand Down Expand Up @@ -109,7 +109,7 @@ Expr Simplify::visit(const And *op, ExprInfo *bounds) {
rewrite(x <= y && x <= z, x <= min(y, z)) ||
rewrite(y <= x && z <= x, max(y, z) <= x)) {

return mutate(rewrite.result, bounds);
return mutate(rewrite.result, info);
}

if (a.same_as(op->a) &&
Expand Down
Loading

0 comments on commit a9b8fbf

Please sign in to comment.