Skip to content

Commit

Permalink
Fix for the crash from #6367 (#6375)
Browse files Browse the repository at this point in the history
* Skip empty boxes

* Address the comments

(cherry picked from commit 5b8f473)
  • Loading branch information
vksnk authored and abadams committed Jan 20, 2022
1 parent d04ca94 commit 3460008
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,10 @@ class BoundsInference : public IRMutator {
auto boxes = boxes_provided(body, empty_scope, func_bounds);
for (const auto &fused : fused_group) {
string fused_stage_name = fused.first + ".s" + std::to_string(fused.second);
boxes_for_fused_group[fused_stage_name] = boxes[fused.first];
auto it = boxes.find(fused.first);
if (it != boxes.end()) {
boxes_for_fused_group[fused_stage_name] = it->second;
}
for (const auto &fn : funcs) {
if (fn.name() == fused.first) {
stage_name_to_func[fused_stage_name] = fn;
Expand Down
92 changes: 92 additions & 0 deletions test/correctness/compute_with.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2110,6 +2110,93 @@ int rvar_bounds_test() {
return 0;
}

// Test for the issue described in https://github.com/halide/Halide/issues/6367.
int two_compute_at_test() {
ImageParam input1(Int(16), 2, "input1");
Func output1("output1"), output2("output2"), output3("output3");
Var k{"k"};

Func intermediate{"intermediate"};
Func output1_value{"output1_value"};
Func output3_value{"output3_value"};

intermediate(k) = input1(k, 0) * input1(k, 1);
output1_value(k) = intermediate(k) * intermediate(k);
output1(k) = output1_value(k);
output2(k) = output1_value(k) + output1_value(k);
output3_value(k) = input1(k, 0) + 2;
output3(k) = output3_value(k);

Expr num = input1.dim(0).extent();
input1.dim(0).set_bounds(0, num);
input1.dim(1).set_bounds(0, 2);
output1.output_buffer().dim(0).set_bounds(0, num);
output2.output_buffer().dim(0).set_bounds(0, num);
output3.output_buffer().dim(0).set_bounds(0, num);

intermediate
.vectorize(k, 8)
.compute_at(output1_value, k)
.bound_storage(k, 8)
.store_in(MemoryType::Register);

output1_value
.vectorize(k, 8)
.compute_at(output2, k)
.bound_storage(k, 8)
.store_in(MemoryType::Register);

output1
.vectorize(k, 8)
.compute_with(output2, k);

output2
.vectorize(k, 8);

output3_value
.vectorize(k, 8)
.compute_at(output3, k)
.bound_storage(k, 8)
.store_in(MemoryType::Register);

output3
.vectorize(k, 8)
.compute_with(output2, k);

Pipeline p({output1, output2, output3});
p.compile_jit();

Buffer<int16_t> in(8, 2);
Buffer<int16_t> o1(8), o2(8), o3(8);
for (int iy = 0; iy < in.height(); iy++) {
for (int ix = 0; ix < in.width(); ix++) {
in(ix, iy) = ix + iy;
}
}
input1.set(in);
p.realize({o1, o2, o3});

for (int x = 0; x < 8; x++) {
int val = (x * (x + 1)) * (x * (x + 1));
if (o1(x) != val) {
printf("o1(%d) = %d instead of %d\n",
x, o1(x), val);
return -1;
}
if (o2(x) != 2 * val) {
printf("o2(%d) = %d instead of %d\n",
x, o2(x), 2 * val);
return -1;
}
if (o3(x) != x + 2) {
printf("o2(%d) = %d instead of %d\n",
x, o3(x), x + 2);
return -1;
}
}
return 0;
}

} // namespace

int main(int argc, char **argv) {
Expand Down Expand Up @@ -2256,6 +2343,11 @@ int main(int argc, char **argv) {
return -1;
}

printf("Running two_compute_at test\n");
if (two_compute_at_test() != 0) {
return -1;
}

printf("Success!\n");
return 0;
}

0 comments on commit 3460008

Please sign in to comment.