Skip to content

Commit

Permalink
Fix bool conversion bug in Vulkan code generator (halide#8067)
Browse files Browse the repository at this point in the history
* Fix bug in Vulkan code generator that was incorrectly passing the
address of a byte vector, instead of its contents to
builder.declare_constant()

* Add bool_predicate_cast correctness test to verify bool conversion for
Vulkan codegen works as expected

---------

Co-authored-by: Derek Gerstmann <[email protected]>
  • Loading branch information
2 people authored and ardier committed Mar 3, 2024
1 parent e63aeff commit 2185f4e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/CodeGen_Vulkan_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,9 @@ void fill_bytes_with_value(uint8_t *bytes, int count, int value) {
}

SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(Type target_type, Type value_type, SpvId value_id) {
debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(): casting from value type '"
<< value_type << "' to target type '" << target_type << "' for value id '" << value_id << "' !\n";

if (!value_type.is_bool()) {
value_id = cast_type(Bool(), value_type, value_id);
}
Expand Down Expand Up @@ -590,8 +593,8 @@ SpvId CodeGen_Vulkan_Dev::SPIRV_Emitter::convert_to_bool(Type target_type, Type

SpvId result_id = builder.reserve_id(SpvResultId);
SpvId target_type_id = builder.declare_type(target_type);
SpvId true_value_id = builder.declare_constant(target_type, &true_data);
SpvId false_value_id = builder.declare_constant(target_type, &false_data);
SpvId true_value_id = builder.declare_constant(target_type, &true_data[0]);
SpvId false_value_id = builder.declare_constant(target_type, &false_data[0]);
builder.append(SpvFactory::select(target_type_id, result_id, value_id, true_value_id, false_value_id));
return result_id;
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ tests(GROUPS correctness
bit_counting.cpp
bitwise_ops.cpp
bool_compute_root_vectorize.cpp
bool_predicate_cast.cpp
bound.cpp
bound_small_allocations.cpp
bound_storage.cpp
Expand Down
39 changes: 39 additions & 0 deletions test/correctness/bool_predicate_cast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {

// Test explicit casting of a predicate to an integer as part of a reduction
// NOTE: triggers a convert_to_bool in Vulkan for a SelectOp
Target target = get_jit_target_from_environment();
Var x("x"), y("y");

Func input("input");
input(x, y) = cast<uint8_t>(x + y);

Func test("test");
test(x, y) = cast(UInt(8), input(x, y) >= 32);

if (target.has_gpu_feature()) {
Var xi("xi"), yi("yi");
test.gpu_tile(x, y, xi, yi, 8, 8);
}

Realization result = test.realize({96, 96});
Buffer<uint8_t> a = result[0];
for (int y = 0; y < a.height(); y++) {
for (int x = 0; x < a.width(); x++) {
uint8_t correct_a = ((x + y) >= 32) ? 1 : 0;
if (a(x, y) != correct_a) {
printf("result(%d, %d) = (%d) instead of (%d)\n",
x, y, a(x, y), correct_a);
return 1;
}
}
}

printf("Success!\n");
return 0;
}

0 comments on commit 2185f4e

Please sign in to comment.