diff --git a/cpp/include/raft/linalg/binary_op.cuh b/cpp/include/raft/linalg/binary_op.cuh
index f8142d9a82..8d1508a3d8 100644
--- a/cpp/include/raft/linalg/binary_op.cuh
+++ b/cpp/include/raft/linalg/binary_op.cuh
@@ -53,6 +53,19 @@ void binaryOpImpl(OutType *out, const InType *in1, const InType *in2,
   CUDA_CHECK(cudaPeekAtLastError());
 }
 
+/**
+ * @brief check whether three addresses are all multiples of a given value
+ * @param addr1 first address, as an integer
+ * @param addr2 second address, as an integer
+ * @param addr3 third address, as an integer
+ * @param N required alignment in bytes; must be non-zero
+ * @return true if addr1, addr2 and addr3 are all divisible by N
+ */
+inline bool addressAligned(uint64_t addr1, uint64_t addr2, uint64_t addr3,
+                           uint64_t N) {
+  return addr1 % N == 0 && addr2 % N == 0 && addr3 % N == 0;
+}
+
 /**
  * @brief perform element-wise binary operation on the input arrays
  * @tparam InType input data-type
@@ -76,16 +89,23 @@ void binaryOp(OutType *out, const InType *in1, const InType *in2, IdxType len,
   constexpr auto maxSize =
     sizeof(InType) > sizeof(OutType) ? sizeof(InType) : sizeof(OutType);
   size_t bytes = len * maxSize;
-  if (16 / maxSize && bytes % 16 == 0) {
+  uint64_t in1Addr = uint64_t(in1);
+  uint64_t in2Addr = uint64_t(in2);
+  uint64_t outAddr = uint64_t(out);
+  if (16 / maxSize && bytes % 16 == 0 &&
+      addressAligned(in1Addr, in2Addr, outAddr, 16)) {
     binaryOpImpl(
       out, in1, in2, len, op, stream);
-  } else if (8 / maxSize && bytes % 8 == 0) {
+  } else if (8 / maxSize && bytes % 8 == 0 &&
+             addressAligned(in1Addr, in2Addr, outAddr, 8)) {
     binaryOpImpl(
       out, in1, in2, len, op, stream);
-  } else if (4 / maxSize && bytes % 4 == 0) {
+  } else if (4 / maxSize && bytes % 4 == 0 &&
+             addressAligned(in1Addr, in2Addr, outAddr, 4)) {
     binaryOpImpl(
       out, in1, in2, len, op, stream);
-  } else if (2 / maxSize && bytes % 2 == 0) {
+  } else if (2 / maxSize && bytes % 2 == 0 &&
+             addressAligned(in1Addr, in2Addr, outAddr, 2)) {
     binaryOpImpl(
       out, in1, in2, len, op, stream);
   } else if (1 / maxSize) {
diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu
index 357ade7388..3ae4f86066 100644
--- a/cpp/test/linalg/binary_op.cu
+++ b/cpp/test/linalg/binary_op.cu
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include
 #include
 #include "../test_utils.h"
 #include "binary_op.cuh"
@@ -121,5 +122,47 @@ TEST_P(BinaryOpTestD_i64, Result) {
 INSTANTIATE_TEST_SUITE_P(BinaryOpTests, BinaryOpTestD_i64,
                          ::testing::ValuesIn(inputsd_i64));
 
+/**
+ * @brief Fixture that runs binaryOp on pointers offset into device buffers,
+ *        so that a wrongly chosen vector length shows up as a CUDA error.
+ * @tparam math_t element type of the device buffers
+ */
+template <typename math_t>
+class BinaryOpAlignment : public ::testing::Test {
+ protected:
+  BinaryOpAlignment() {
+    CUDA_CHECK(cudaStreamCreate(&stream));
+    handle.set_stream(stream);
+  }
+  void TearDown() override { CUDA_CHECK(cudaStreamDestroy(stream)); }
+
+ public:
+  void Misaligned() {
+    // Test to trigger cudaErrorMisalignedAddress if veclen is incorrectly
+    // chosen.
+    int n = 1024;
+    mr::device::buffer<math_t> x(handle.get_device_allocator(), stream, n);
+    mr::device::buffer<math_t> y(handle.get_device_allocator(), stream, n);
+    mr::device::buffer<math_t> z(handle.get_device_allocator(), stream, n);
+    CUDA_CHECK(cudaMemsetAsync(x.data(), 0, n * sizeof(math_t), stream));
+    CUDA_CHECK(cudaMemsetAsync(y.data(), 0, n * sizeof(math_t), stream));
+    // All three operands start an odd number of elements (9, 137, 19) into
+    // their buffers; 256 elements are processed from each.
+    raft::linalg::binaryOp(
+      z.data() + 9, x.data() + 137, y.data() + 19, 256,
+      [] __device__(math_t x, math_t y) { return x + y; }, stream);
+    // A misaligned access is reported asynchronously; synchronize so that it
+    // surfaces through CUDA_CHECK inside this test.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+  }
+
+  raft::handle_t handle;
+  cudaStream_t stream;
+};
+// NOTE(review): the type list of ::testing::Types was lost in this copy of the
+// patch; float and double are assumed from the alias name - confirm.
+typedef ::testing::Types<float, double> FloatTypes;
+TYPED_TEST_SUITE(BinaryOpAlignment, FloatTypes);
+TYPED_TEST(BinaryOpAlignment, Misaligned) { this->Misaligned(); }
 }  // namespace linalg
 }  // namespace raft