diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 395abd8c93d6..298ad885fb10 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -1133,7 +1133,10 @@ class BoundsInference : public IRMutator { auto boxes = boxes_provided(body, empty_scope, func_bounds); for (const auto &fused : fused_group) { string fused_stage_name = fused.first + ".s" + std::to_string(fused.second); - boxes_for_fused_group[fused_stage_name] = boxes[fused.first]; + auto it = boxes.find(fused.first); + if (it != boxes.end()) { + boxes_for_fused_group[fused_stage_name] = it->second; + } for (const auto &fn : funcs) { if (fn.name() == fused.first) { stage_name_to_func[fused_stage_name] = fn; diff --git a/test/correctness/compute_with.cpp b/test/correctness/compute_with.cpp index 5a5e0d25f0a8..b978d39b4f71 100644 --- a/test/correctness/compute_with.cpp +++ b/test/correctness/compute_with.cpp @@ -2110,6 +2110,93 @@ int rvar_bounds_test() { return 0; } +// Test for the issue described in https://github.com/halide/Halide/issues/6367. +int two_compute_at_test() { + ImageParam input1(Int(16), 2, "input1"); + Func output1("output1"), output2("output2"), output3("output3"); + Var k{"k"}; + + Func intermediate{"intermediate"}; + Func output1_value{"output1_value"}; + Func output3_value{"output3_value"}; + + intermediate(k) = input1(k, 0) * input1(k, 1); + output1_value(k) = intermediate(k) * intermediate(k); + output1(k) = output1_value(k); + output2(k) = output1_value(k) + output1_value(k); + output3_value(k) = input1(k, 0) + 2; + output3(k) = output3_value(k); + + Expr num = input1.dim(0).extent(); + input1.dim(0).set_bounds(0, num); + input1.dim(1).set_bounds(0, 2); + output1.output_buffer().dim(0).set_bounds(0, num); + output2.output_buffer().dim(0).set_bounds(0, num); + output3.output_buffer().dim(0).set_bounds(0, num); + + intermediate + .vectorize(k, 8) + .compute_at(output1_value, k) + .bound_storage(k, 8) + .store_in(MemoryType::Register); + + output1_value + .vectorize(k, 8) + .compute_at(output2, k) + .bound_storage(k, 8) + .store_in(MemoryType::Register); + + output1 + .vectorize(k, 8) + .compute_with(output2, k); + + output2 + .vectorize(k, 8); + + output3_value + .vectorize(k, 8) + .compute_at(output3, k) + .bound_storage(k, 8) + .store_in(MemoryType::Register); + + output3 + .vectorize(k, 8) + .compute_with(output2, k); + + Pipeline p({output1, output2, output3}); + p.compile_jit(); + + Buffer in(8, 2); + Buffer o1(8), o2(8), o3(8); + for (int iy = 0; iy < in.height(); iy++) { + for (int ix = 0; ix < in.width(); ix++) { + in(ix, iy) = ix + iy; + } + } + input1.set(in); + p.realize({o1, o2, o3}); + + for (int x = 0; x < 8; x++) { + int val = (x * (x + 1)) * (x * (x + 1)); + if (o1(x) != val) { + printf("o1(%d) = %d instead of %d\n", + x, o1(x), val); + return -1; + } + if (o2(x) != 2 * val) { + printf("o2(%d) = %d instead of %d\n", + x, o2(x), 2 * val); + return -1; + } + if (o3(x) != x + 2) { + printf("o2(%d) = %d instead of %d\n", + x, o3(x), x + 2); + return -1; + } + } + return 0; +} + } // namespace int main(int argc, char **argv) { @@ -2256,6 +2343,11 @@ int main(int argc, char **argv) { return -1; } + printf("Running two_compute_at test\n"); + if (two_compute_at_test() != 0) { + return -1; + } + printf("Success!\n"); return 0; }