Skip to content

Commit

Permalink
Add support for module almost_eq check for f16 type (#5261)
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasRaoux authored Mar 31, 2021
1 parent 1c59bd1 commit 2e05313
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 2 deletions.
2 changes: 2 additions & 0 deletions iree/modules/check/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ cc_test(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@half//:includes",
],
)

Expand All @@ -55,5 +56,6 @@ cc_library(
"//iree/vm:cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@half//:includes",
],
)
2 changes: 2 additions & 0 deletions iree/modules/check/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ if(${IREE_HAL_DRIVER_VMLA})
::native_module
absl::core_headers
absl::strings
half::includes
iree::base::api
iree::base::logging
iree::base::status
Expand All @@ -50,6 +51,7 @@ iree_cc_library(
DEPS
absl::inlined_vector
absl::strings
half::includes
iree::base::api
iree::base::status
iree::hal::api
Expand Down
69 changes: 69 additions & 0 deletions iree/modules/check/check_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "iree/vm/api.h"
#include "iree/vm/bytecode_module.h"
#include "iree/vm/ref_cc.h"
#include "third_party/half/half.hpp"

namespace iree {
namespace {
Expand Down Expand Up @@ -101,6 +102,29 @@ class CheckTest : public ::testing::Test {
&*out_buffer_view));
}

void CreateFloat16BufferView(absl::Span<const uint16_t> contents,
absl::Span<const int32_t> shape,
iree_hal_buffer_view_t** out_buffer_view) {
size_t num_elements = 1;
for (int32_t dim : shape) {
num_elements *= dim;
}
ASSERT_EQ(contents.size(), num_elements);
vm::ref<iree_hal_buffer_t> buffer;
IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
allocator_,
static_cast<iree_hal_memory_type_t>(
IREE_HAL_MEMORY_TYPE_HOST_LOCAL |
IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE),
IREE_HAL_BUFFER_USAGE_ALL, contents.size() * sizeof(uint16_t),
&buffer));
IREE_ASSERT_OK(iree_hal_buffer_write_data(
buffer.get(), 0, contents.data(), contents.size() * sizeof(uint16_t)));
IREE_ASSERT_OK(iree_hal_buffer_view_create(
buffer.get(), IREE_HAL_ELEMENT_TYPE_FLOAT_16, shape.data(),
shape.size(), &*out_buffer_view));
}

void CreateFloat32BufferView(absl::Span<const float> contents,
absl::Span<const int32_t> shape,
iree_hal_buffer_view_t** out_buffer_view) {
Expand Down Expand Up @@ -528,5 +552,50 @@ TEST_F(CheckTest, ExpectAlmostEqDifferentContents3DFullMessageFailure) {
" rhs:\n"
" 2x2x2xf32=[[1 2][3 42]][[5 6][7 8]]");
}

TEST_F(CheckTest, ExpectAlmostEqIdenticalBufferF16Success) {
vm::ref<iree_hal_buffer_view_t> lhs;
vm::ref<iree_hal_buffer_view_t> rhs;
uint16_t contents[] = {
half_float::detail::float2half<std::round_to_nearest>(1.f)};
int32_t shape[] = {1};
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(contents, shape, &lhs));
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(contents, shape, &rhs));
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs}));
}

TEST_F(CheckTest, ExpectAlmostEqNearIdenticalBufferF16Success) {
vm::ref<iree_hal_buffer_view_t> lhs;
vm::ref<iree_hal_buffer_view_t> rhs;
uint16_t lhs_contents[] = {
half_float::detail::float2half<std::round_to_nearest>(1.0f),
half_float::detail::float2half<std::round_to_nearest>(1.99999f),
half_float::detail::float2half<std::round_to_nearest>(0.00001f),
half_float::detail::float2half<std::round_to_nearest>(4.0f)};
uint16_t rhs_contents[] = {
half_float::detail::float2half<std::round_to_nearest>(1.00001f),
half_float::detail::float2half<std::round_to_nearest>(2.0f),
half_float::detail::float2half<std::round_to_nearest>(0.0f),
half_float::detail::float2half<std::round_to_nearest>(4.0f)};
int32_t shape[] = {4};
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(lhs_contents, shape, &lhs));
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(rhs_contents, shape, &rhs));
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs}));
}

TEST_F(CheckTest, ExpectAlmostEqDifferentContentsF16Failure) {
vm::ref<iree_hal_buffer_view_t> lhs;
vm::ref<iree_hal_buffer_view_t> rhs;
uint16_t lhs_contents[] = {
half_float::detail::float2half<std::round_to_nearest>(1.f)};
uint16_t rhs_contents[] = {
half_float::detail::float2half<std::round_to_nearest>(2.f)};
int32_t shape[] = {1};
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(lhs_contents, shape, &lhs));
ASSERT_NO_FATAL_FAILURE(CreateFloat16BufferView(rhs_contents, shape, &rhs));
EXPECT_NONFATAL_FAILURE(
IREE_ASSERT_OK(Invoke("expect_almost_eq", {lhs, rhs})),
"Contents does not match");
}
} // namespace
} // namespace iree
23 changes: 21 additions & 2 deletions iree/modules/check/native_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "iree/modules/hal/hal_module.h"
#include "iree/testing/gtest.h"
#include "iree/vm/native_module_cc.h"
#include "third_party/half/half.hpp"

//===----------------------------------------------------------------------===//
// VM module interface implementation
Expand Down Expand Up @@ -70,13 +71,30 @@ bool EqByteSpan(iree_byte_span_t lhs_bytes, iree_byte_span_t rhs_bytes) {
return AbslSpan<uint8_t>(lhs_bytes) == AbslSpan<uint8_t>(rhs_bytes);
}

static constexpr float floatPrecisionThreshold = 0.0001f;

template <typename T>
bool AlmostEqByteSpan(iree_byte_span_t lhs_bytes, iree_byte_span_t rhs_bytes) {
auto lhs_span = AbslSpan<T>(lhs_bytes);
auto rhs_span = AbslSpan<T>(rhs_bytes);
assert(lhs_span.size() == rhs_span.size());
for (int i = 0; i < lhs_span.size(); ++i) {
if (fabs(lhs_span[i] - rhs_span[i]) > 0.0001) {
if (fabs(lhs_span[i] - rhs_span[i]) > floatPrecisionThreshold) {
return false;
}
}
return true;
}

bool AlmostEqByteSpanF16(iree_byte_span_t lhs_bytes,
iree_byte_span_t rhs_bytes) {
auto lhs_span = AbslSpan<uint16_t>(lhs_bytes);
auto rhs_span = AbslSpan<uint16_t>(rhs_bytes);
assert(lhs_span.size() == rhs_span.size());
for (int i = 0; i < lhs_span.size(); ++i) {
if (fabs(half_float::detail::half2float<float>(lhs_span[i]) -
half_float::detail::half2float<float>(rhs_span[i])) >
floatPrecisionThreshold) {
return false;
}
}
Expand All @@ -91,6 +109,8 @@ StatusOr<bool> AlmostEqByteSpan(iree_byte_span_t lhs_bytes,
return AlmostEqByteSpan<float>(lhs_bytes, rhs_bytes);
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
return AlmostEqByteSpan<double>(lhs_bytes, rhs_bytes);
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
return AlmostEqByteSpanF16(lhs_bytes, rhs_bytes);
case IREE_HAL_ELEMENT_TYPE_SINT_8:
case IREE_HAL_ELEMENT_TYPE_UINT_8:
case IREE_HAL_ELEMENT_TYPE_SINT_16:
Expand All @@ -104,7 +124,6 @@ StatusOr<bool> AlmostEqByteSpan(iree_byte_span_t lhs_bytes,
case IREE_HAL_ELEMENT_TYPE_OPAQUE_16:
case IREE_HAL_ELEMENT_TYPE_OPAQUE_32:
case IREE_HAL_ELEMENT_TYPE_OPAQUE_64:
case IREE_HAL_ELEMENT_TYPE_FLOAT_16:
case IREE_HAL_ELEMENT_TYPE_NONE: {
break;
}
Expand Down

0 comments on commit 2e05313

Please sign in to comment.