Skip to content

Commit

Permalink
FIX Check alignment before binaryOp dispatch (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher authored Nov 24, 2020
1 parent eefa69f commit b7c989a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
## Bug Fixes
- PR #77: Fixing CUB include for CUDA < 11
- PR #86: Missing headers for newly moved prims
- PR #102: Check alignment before binaryOp dispatch

# RAFT 0.16.0 (Date TBD)

Expand Down
23 changes: 19 additions & 4 deletions cpp/include/raft/linalg/binary_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ void binaryOpImpl(OutType *out, const InType *in1, const InType *in2,
CUDA_CHECK(cudaPeekAtLastError());
}

/**
* @brief Checks if addresses are aligned on N bytes
*/
inline bool addressAligned(uint64_t addr1, uint64_t addr2, uint64_t addr3,
uint64_t N) {
return addr1 % N == 0 && addr2 % N == 0 && addr3 % N == 0;
}

/**
* @brief perform element-wise binary operation on the input arrays
* @tparam InType input data-type
Expand All @@ -76,16 +84,23 @@ void binaryOp(OutType *out, const InType *in1, const InType *in2, IdxType len,
constexpr auto maxSize =
sizeof(InType) > sizeof(OutType) ? sizeof(InType) : sizeof(OutType);
size_t bytes = len * maxSize;
if (16 / maxSize && bytes % 16 == 0) {
uint64_t in1Addr = uint64_t(in1);
uint64_t in2Addr = uint64_t(in2);
uint64_t outAddr = uint64_t(out);
if (16 / maxSize && bytes % 16 == 0 &&
addressAligned(in1Addr, in2Addr, outAddr, 16)) {
binaryOpImpl<InType, 16 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else if (8 / maxSize && bytes % 8 == 0) {
} else if (8 / maxSize && bytes % 8 == 0 &&
addressAligned(in1Addr, in2Addr, outAddr, 8)) {
binaryOpImpl<InType, 8 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else if (4 / maxSize && bytes % 4 == 0) {
} else if (4 / maxSize && bytes % 4 == 0 &&
addressAligned(in1Addr, in2Addr, outAddr, 4)) {
binaryOpImpl<InType, 4 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else if (2 / maxSize && bytes % 2 == 0) {
} else if (2 / maxSize && bytes % 2 == 0 &&
addressAligned(in1Addr, in2Addr, outAddr, 2)) {
binaryOpImpl<InType, 2 / maxSize, Lambda, IdxType, OutType, TPB>(
out, in1, in2, len, op, stream);
} else if (1 / maxSize) {
Expand Down
31 changes: 31 additions & 0 deletions cpp/test/linalg/binary_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <gtest/gtest.h>
#include <raft/cudart_utils.h>
#include <raft/linalg/binary_op.cuh>
#include <raft/mr/device/buffer.hpp>
#include <raft/random/rng.cuh>
#include "../test_utils.h"
#include "binary_op.cuh"
Expand Down Expand Up @@ -121,5 +122,35 @@ TEST_P(BinaryOpTestD_i64, Result) {
INSTANTIATE_TEST_SUITE_P(BinaryOpTests, BinaryOpTestD_i64,
::testing::ValuesIn(inputsd_i64));

template <typename math_t>
class BinaryOpAlignment : public ::testing::Test {
protected:
BinaryOpAlignment() {
CUDA_CHECK(cudaStreamCreate(&stream));
handle.set_stream(stream);
}
void TearDown() override { CUDA_CHECK(cudaStreamDestroy(stream)); }

public:
void Misaligned() {
// Test to trigger cudaErrorMisalignedAddress if veclen is incorrectly
// chosen.
int n = 1024;
mr::device::buffer<math_t> x(handle.get_device_allocator(), stream, n);
mr::device::buffer<math_t> y(handle.get_device_allocator(), stream, n);
mr::device::buffer<math_t> z(handle.get_device_allocator(), stream, n);
CUDA_CHECK(cudaMemsetAsync(x.data(), 0, n * sizeof(math_t), stream));
CUDA_CHECK(cudaMemsetAsync(y.data(), 0, n * sizeof(math_t), stream));
raft::linalg::binaryOp(
z.data() + 9, x.data() + 137, y.data() + 19, 256,
[] __device__(math_t x, math_t y) { return x + y; }, stream);
}

raft::handle_t handle;
cudaStream_t stream;
};
typedef ::testing::Types<float, double> FloatTypes;
TYPED_TEST_CASE(BinaryOpAlignment, FloatTypes);
TYPED_TEST(BinaryOpAlignment, Misaligned) { this->Misaligned(); }
} // namespace linalg
} // namespace raft

0 comments on commit b7c989a

Please sign in to comment.