diff --git a/lib/cpp/CMakeLists.txt b/lib/cpp/CMakeLists.txt
index 6a66e5ad1a7..6a2ecfc6c7d 100644
--- a/lib/cpp/CMakeLists.txt
+++ b/lib/cpp/CMakeLists.txt
@@ -124,6 +124,58 @@ if(UNIX)
   endif()
 endif()
 
+if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
+  set(PREV_CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS})
+  set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -mbmi2 -mbmi -mlzcnt -msse3 -mavx512bw -mavx512vl")
+  check_cxx_source_compiles(
+    "
+    #include <immintrin.h>
+    int main(){unsigned int a,b;_pdep_u32(a,b); return 0;}
+    "
+    HAVE_BMI2)
+  check_cxx_source_compiles(
+    "
+    #include <immintrin.h>
+    int main(){unsigned int a; _tzcnt_u32(a); return 0;}
+    "
+    HAVE_BMI)
+  check_cxx_source_compiles(
+    "
+    #include <immintrin.h>
+    int main(){unsigned int c;_lzcnt_u32(c); return 0;}
+    "
+    HAVE_LZCNT)
+  check_cxx_source_compiles(
+    "
+    #include <immintrin.h>
+    int main(){const __m128i* p;_mm_lddqu_si128(p); return 0;}
+    "
+    HAVE_SSE3)
+  check_cxx_source_compiles(
+    "
+    #include <immintrin.h>
+    int main(){__m128i a,b;_mm_mask_cmp_epi8_mask(0x3ff,a,b,_MM_CMPINT_NLT); return 0;}
+    "
+    HAVE_AVX512BW_AVX512VL)
+
+  if (HAVE_BMI2)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mbmi2")
+  endif()
+  if (HAVE_BMI)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mbmi")
+  endif()
+  if (HAVE_LZCNT)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlzcnt")
+  endif()
+  if (HAVE_SSE3)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse3")
+  endif()
+  if (HAVE_AVX512BW_AVX512VL)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512bw -mavx512vl")
+  endif()
+  set(CMAKE_REQUIRED_FLAGS ${PREV_CMAKE_REQUIRED_FLAGS})
+endif ()
+
 set(thriftcpp_threads_SOURCES
     src/thrift/concurrency/ThreadFactory.cpp
     src/thrift/concurrency/Thread.cpp
diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.h b/lib/cpp/src/thrift/protocol/TCompactProtocol.h
index 792a2d89e3b..7ae19251a5a 100644
--- a/lib/cpp/src/thrift/protocol/TCompactProtocol.h
+++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.h
@@ -167,7 +167,14 @@ class TCompactProtocolT : public TVirtualProtocol<TCompactProtocolT<Transport_> > {
   uint32_t writeListEnd() { return 0; }
   uint32_t writeSetEnd() { return 0; }
   uint32_t writeFieldEnd() { return 0; }
+private:
+  template <bool needConsume>
+  inline __attribute__((always_inline)) uint32_t writeVarint64NoneBMI2(uint64_t n);
+#if defined(__BMI2__) && defined(__LZCNT__)
+  template <bool needConsume>
+  inline __attribute__((always_inline)) uint32_t writeVarint64BMI2(uint64_t n);
+#endif
 
 protected:
   int32_t writeFieldBeginInternal(const char* name,
                                   const TType fieldType,
@@ -223,6 +230,19 @@ class TCompactProtocolT : public TVirtualProtocol<TCompactProtocolT<Transport_> > {
   uint32_t readListEnd() { return 0; }
   uint32_t readSetEnd() { return 0; }
 
+private:
+  template <bool needConsume>
+  inline __attribute__((always_inline)) uint32_t readVarint64FastPathNoneAVX(const uint8_t* buf, const std::size_t bufsz, int64_t& i64);
+  template <bool needConsume>
+  inline __attribute__((always_inline)) uint32_t readVarint64SlowPathNoneAVX(uint8_t* buf, const std::size_t bufsz, int64_t& i64);
+#if defined(__SSE3__) && defined(__AVX512BW__) && defined(__AVX512VL__) && \
+    defined(__BMI2__) && defined(__BMI__)
+  template <bool needConsume>
+  inline __attribute__((always_inline)) uint32_t readVarint64FastPathAVX(const uint8_t* buf, const std::size_t bufsz, int64_t& i64);
+  template <bool needConsume>
+  inline __attribute__((always_inline)) uint32_t readVarint64SlowPathAVX(uint8_t* buf, const std::size_t bufsz, int64_t& i64);
+#endif
+
 protected:
   uint32_t readVarint32(int32_t& i32);
   uint32_t readVarint64(int64_t& i64);
diff --git a/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc b/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
index 9270ab89902..4fa6b03730a 100644
--- a/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
+++ b/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc
@@ -23,7 +23,7 @@
 #include <limits>
 
 #include "thrift/config.h"
-
+#include "immintrin.h"
 /*
  * TCompactProtocol::i*ToZigzag depend on the fact that the right shift
  * operator on a signed integer is an arithmetic (sign-extending) shift.
@@ -40,13 +40,66 @@
 #ifdef __GNUC__
 #define UNLIKELY(val) (__builtin_expect((val), 0))
+#define LIKELY(val) (__builtin_expect((val), 1))
 #else
 #define UNLIKELY(val) (val)
+#define LIKELY(val) (val)
 #endif
 
 namespace apache {
 namespace thrift {
 namespace protocol {
 namespace detail {
 namespace compact {
+#if defined(__BMI2__) && defined(__LZCNT__)
+
+#ifdef __cpp_lib_hardware_interference_size
+  using std::hardware_constructive_interference_size;
+#else
+  // 64 bytes on x86-64 │ L1_CACHE_BYTES │ L1_CACHE_SHIFT │ __cacheline_aligned
+  // │ ...
+  constexpr static const std::size_t hardware_constructive_interference_size = 64;
+#endif
+
+  // Varint size in bytes for a value with the given number of significant
+  // bits (index 0..64).
+  constexpr const static uint8_t Number7Bits alignas(
+      alignof(hardware_constructive_interference_size))[]{
+      1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3,
+      4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7,
+      7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 10};
+
+  // Continuation (MSB) bits for every byte except the last, indexed by the
+  // varint size in bytes.
+  constexpr const static uint64_t ExtendedBitsMask alignas(
+      alignof(hardware_constructive_interference_size))[]{
+      0x0,
+      0x0,
+      0x80,
+      0x8080,
+      0x808080,
+      0x80808080,
+      0x8080808080,
+      0x808080808080,
+      0x80808080808080,
+      0x8080808080808080};
+
+  constexpr const static uint16_t ShiftMasks alignas(
+      alignof(hardware_constructive_interference_size))[]{
+      0b1, 0b1, 0b11, 0b111, 0b1111, 0b11111,
+      0b111111, 0b1111111, 0b11111111, 0b111111111, 0b1111111111};
+#endif
+
+
+#if defined(__SSE3__) && defined(__AVX512BW__) && defined(__AVX512VL__) && \
+    defined(__BMI2__)
+
+// AvxByteMask[k] keeps the low k bytes of a 64-bit load.
+constexpr static const uint64_t AvxByteMask[] = {0x0,
+                                                 0xff,
+                                                 0xffff,
+                                                 0xffffff,
+                                                 0xffffffff,
+                                                 0xffffffffff,
+                                                 0xffffffffffff,
+                                                 0xffffffffffffff,
+                                                 0xffffffffffffffff};
+static const __m128i Zero128I = _mm_setzero_si128();
+#endif
 
 enum Types {
   CT_STOP = 0x00,
@@ -362,10 +415,11 @@ uint32_t TCompactProtocolT<Transport_>::writeVarint32(uint32_t n) {
 }
 
 /**
- * Write an i64 as a varint. Results in 1-10 bytes on the wire.
+ * Write an i64 as a varint without using BMI2 instructions. Results in 1-10 bytes on the wire.
  */
 template <class Transport_>
-uint32_t TCompactProtocolT<Transport_>::writeVarint64(uint64_t n) {
+template <bool needConsume>
+uint32_t TCompactProtocolT<Transport_>::writeVarint64NoneBMI2(uint64_t n) {
   uint8_t buf[10];
   uint32_t wsize = 0;
 
@@ -378,10 +432,87 @@ uint32_t TCompactProtocolT<Transport_>::writeVarint64(uint64_t n) {
       n >>= 7;
     }
   }
-  trans_->write(buf, wsize);
+  if constexpr(needConsume)
+    trans_->write(buf, wsize);
   return wsize;
 }
 
+#if defined(__BMI2__) && defined(__LZCNT__)
+/**
+ * Write an i64 as a varint using BMI2 instructions. Results in 1-10 bytes on the wire.
+ */
+template <class Transport_>
+template <bool needConsume>
+uint32_t TCompactProtocolT<Transport_>::writeVarint64BMI2(uint64_t n) {
+  uint8_t buf[10];
+
+  if (n < 0x80) {
+    buf[0] = static_cast<uint8_t>(n);
+    if constexpr(needConsume)
+      trans_->write(buf, 1);
+    return 1;
+  }
+
+  if (n < 0x4000) {
+    buf[0] = static_cast<uint8_t>(n) | 0x80;
+    buf[1] = static_cast<uint8_t>(n >> 7);
+    if constexpr(needConsume)
+      trans_->write(buf, 2);
+    return 2;
+  }
+
+  if (n < 0x200000) {
+    buf[0] = static_cast<uint8_t>(n) | 0x80;
+    buf[1] = static_cast<uint8_t>(n >> 7) | 0x80;
+    buf[2] = static_cast<uint8_t>(n >> 14);
+    if constexpr(needConsume)
+      trans_->write(buf, 3);
+    return 3;
+  }
+
+  using apache::thrift::protocol::detail::compact::ExtendedBitsMask;
+  using apache::thrift::protocol::detail::compact::Number7Bits;
+  size_t efftiveBitCounts = 64 - _lzcnt_u64(n);
+  uint64_t* pdata = reinterpret_cast<uint64_t*>(buf);
+  constexpr const uint64_t mask = 0x7f7f7f7f7f7f7f7f;
+  if (efftiveBitCounts <= 56) {
+    uint32_t wsize = 0;
+    wsize = Number7Bits[efftiveBitCounts];
+    *pdata = _pdep_u64(n, mask);
+    *pdata |= ExtendedBitsMask[wsize];
+    if constexpr(needConsume)
+      trans_->write(buf, wsize);
+    return wsize;
+  }
+
+  *pdata = _pdep_u64(n & 0xffffffffffffff, mask);
+  *pdata |= 0x8080808080808080;
+  if (n < (1ull << 63)) {
+    buf[8] = static_cast<uint8_t>(n >> 56);
+    if constexpr(needConsume)
+      trans_->write(buf, 9);
+    return 9;
+  }
+
+  buf[8] = static_cast<uint8_t>(n >> 56) | 0x80;
+  buf[9] = 1;
+  if constexpr(needConsume)
+    trans_->write(buf, 10);
+  return 10;
+}
+#endif
+
+/**
+ * Write an i64 as a varint. Results in 1-10 bytes on the wire.
+ */
+template <class Transport_>
+uint32_t TCompactProtocolT<Transport_>::writeVarint64(uint64_t n) {
+#if defined(__BMI2__) && defined(__LZCNT__)
+  return writeVarint64BMI2<true>(n);
+#else
+  return writeVarint64NoneBMI2<true>(n);
+#endif
+}
+
 /**
  * Convert l into a zigzag long. This allows negative numbers to be
  * represented compactly as a varint.
@@ -718,6 +849,7 @@ uint32_t TCompactProtocolT<Transport_>::readBinary(std::string& str) {
   return rsize + (uint32_t)size;
 }
 
+
 /**
  * Read an i32 from the wire as a varint. The MSB of each byte is set
  * if there is another byte to follow. This can read up to 5 bytes.
@@ -730,6 +862,142 @@ uint32_t TCompactProtocolT<Transport_>::readVarint32(int32_t& i32) {
   return rsize;
 }
 
+/**
+ * Read an i64 from the wire as a proper varint. The MSB of each byte is set
+ * if there is another byte to follow. This can read up to 10 bytes.
+ */
+template <class Transport_>
+template <bool needConsume>
+uint32_t TCompactProtocolT<Transport_>::readVarint64FastPathNoneAVX(const uint8_t* buf, const std::size_t bufsz, int64_t& i64) {
+  uint32_t rsize = 0;
+  int shift = 0;
+  uint64_t val = 0;
+  while (LIKELY(rsize < bufsz)) {
+    uint8_t byte = buf[rsize];
+    rsize++;
+    val |= (uint64_t)(byte & 0x7f) << shift;
+    shift += 7;
+    if (!(byte & 0x80)) {
+      i64 = val;
+      if constexpr(needConsume)
+        trans_->consume(rsize);
+      return rsize;
+    }
+    // Have to check for invalid data so we don't crash.
+    if (UNLIKELY(rsize == 10)) {
+      throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
+    }
+  }
+  return 0;
+}
+
+template <class Transport_>
+template <bool needConsume>
+uint32_t TCompactProtocolT<Transport_>::readVarint64SlowPathNoneAVX(uint8_t* buf, const std::size_t bufsz, int64_t& i64) {
+  uint32_t rsize = 0;
+  int shift = 0;
+  uint64_t val = 0;
+  while (LIKELY(rsize < bufsz)) {
+    uint8_t byte;
+    rsize += trans_->readAll(&byte, 1);
+    val |= (uint64_t)(byte & 0x7f) << shift;
+    shift += 7;
+    if (!(byte & 0x80)) {
+      i64 = val;
+      if constexpr(needConsume)
+        trans_->consume(rsize);
+      return rsize;
+    }
+    // Might as well check for invalid data on the slow path too.
+    if (UNLIKELY(rsize == 10)) {
+      throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
+    }
+  }
+  return 0;
+}
+
+#if defined(__SSE3__) && defined(__AVX512BW__) && defined(__AVX512VL__) && \
+    defined(__BMI2__) && defined(__BMI__)
+/**
+ * Read an i64 from the wire as a proper varint. The MSB of each byte is set
+ * if there is another byte to follow. This can read up to 10 bytes.
+ */
+template <class Transport_>
+template <bool needConsume>
+uint32_t TCompactProtocolT<Transport_>::readVarint64FastPathAVX(const uint8_t* buf, const std::size_t bufsz, int64_t& i64) {
+  if (!(buf[0] & 0x80)) {
+    i64 = buf[0];
+    if constexpr(needConsume)
+      trans_->consume(1);
+    return 1;
+  }
+
+  if (!(buf[1] & 0x80)) {
+    i64 = ((buf[1] - 1) << 7) + buf[0];
+    if constexpr(needConsume)
+      trans_->consume(2);
+    return 2;
+  }
+
+  using apache::thrift::protocol::detail::compact::Zero128I;
+  using apache::thrift::protocol::detail::compact::AvxByteMask;
+  __m128i data128i = _mm_lddqu_si128(reinterpret_cast<const __m128i*>(buf));
+  const uint8_t* dptr = reinterpret_cast<const uint8_t*>(&data128i);
+  // "0x3ff" = 10 bytes
+  __mmask16 msb =
+      _mm_mask_cmp_epi8_mask(0x3ff, data128i, Zero128I, _MM_CMPINT_NLT);
+  uint32_t cnt = _tzcnt_u32(msb);
+  if (UNLIKELY(cnt >= 10)) {
+    throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
+    return 0;
+  }
+  auto sz = cnt + 1;
+  if (cnt < 8) {
+    uint32_t maskIndex = cnt + 1;
+    uint64_t data64 =
+        (*(reinterpret_cast<const uint64_t*>(dptr))) & AvxByteMask[maskIndex];
+    i64 = _pext_u64(data64, 0x7F7F7F7F7F7F7F7F);
+    if constexpr(needConsume)
+      trans_->consume(sz);
+    return sz;
+  }
+
+  uint32_t maskIndex = cnt + 1 - 8;
+  uint64_t tempValueLo = _pext_u64(*(reinterpret_cast<const uint64_t*>(dptr)),
+                                   0x7F7F7F7F7F7F7F7F);
+  uint64_t tempValueHi =
+      _pext_u64(*(reinterpret_cast<const uint64_t*>(dptr + 8)) &
+                    AvxByteMask[maskIndex],
+                0x7F7F);
+  i64 = (tempValueHi << 56) | tempValueLo;
+  if constexpr(needConsume)
+    trans_->consume(sz);
+  return sz;
+}
+
+template <class Transport_>
+template <bool needConsume>
+uint32_t TCompactProtocolT<Transport_>::readVarint64SlowPathAVX(uint8_t* buf, const std::size_t bufsz, int64_t& i64) {
+  uint32_t rsize = 0;
+  while (LIKELY(rsize < bufsz)) {
+    trans_->readAll(&buf[rsize], 1);
+    if (!(buf[rsize++] & 0x80))
+      return readVarint64FastPathAVX<false>(buf, bufsz, i64);
+  }
+
+  // Might as well check for invalid data on the slow path too.
+  throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
+  return 0;
+}
+#endif
+
 /**
  * Read an i64 from the wire as a proper varint. The MSB of each byte is set
  * if there is another byte to follow. This can read up to 10 bytes.
@@ -742,43 +1010,24 @@ uint32_t TCompactProtocolT<Transport_>::readVarint64(int64_t& i64) {
   uint8_t buf[10];  // 64 bits / (7 bits/byte) = 10 bytes.
   uint32_t buf_size = sizeof(buf);
   const uint8_t* borrowed = trans_->borrow(buf, &buf_size);
 
-  // Fast path.
   if (borrowed != nullptr) {
-    while (true) {
-      uint8_t byte = borrowed[rsize];
-      rsize++;
-      val |= (uint64_t)(byte & 0x7f) << shift;
-      shift += 7;
-      if (!(byte & 0x80)) {
-        i64 = val;
-        trans_->consume(rsize);
-        return rsize;
-      }
-      // Have to check for invalid data so we don't crash.
-      if (UNLIKELY(rsize == sizeof(buf))) {
-        throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
-      }
-    }
-  }
+#if defined(__SSE3__) && defined(__AVX512BW__) && defined(__AVX512VL__) && \
+    defined(__BMI2__) && defined(__BMI__)
+    return TCompactProtocolT::readVarint64FastPathAVX<true>(borrowed, 10, i64);
+#else
+    return TCompactProtocolT::readVarint64FastPathNoneAVX<true>(borrowed, 10, i64);
+#endif
 
-  // Slow path.
-  else {
-    while (true) {
-      uint8_t byte;
-      rsize += trans_->readAll(&byte, 1);
-      val |= (uint64_t)(byte & 0x7f) << shift;
-      shift += 7;
-      if (!(byte & 0x80)) {
-        i64 = val;
-        return rsize;
-      }
-      // Might as well check for invalid data on the slow path too.
-      if (UNLIKELY(rsize >= sizeof(buf))) {
-        throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes.");
-      }
-    }
   }
+
+
+#if defined(__SSE3__) && defined(__AVX512BW__) && defined(__AVX512VL__) && \
+    defined(__BMI2__) && defined(__BMI__)
+  return TCompactProtocolT::readVarint64SlowPathAVX<false>(buf, 10, i64);
+#else
+  return TCompactProtocolT::readVarint64SlowPathNoneAVX<false>(buf, 10, i64);
+#endif
 }
 
 /**
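
Reviewer note (not part of the patch): the trick shared by writeVarint64BMI2 and readVarint64FastPathAVX is that _pdep_u64/_pext_u64 scatter and gather the 7-bit varint groups to and from byte lanes in one instruction, so the usual shift-and-mask loop disappears for values of up to 56 significant bits. The standalone sketch below round-trips a value through that layout to make the idea concrete. It is a minimal illustration with made-up helper names (encode_varint_bmi2/decode_varint_bmi2), assuming a CPU with BMI2 and LZCNT and a build with -mbmi2 -mlzcnt; it is not the patch's actual code path.

// sketch.cpp -- build with: g++ -O2 -mbmi -mbmi2 -mlzcnt sketch.cpp
#include <immintrin.h>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <cstdio>

// Encode n as a varint (7 payload bits per byte, MSB = continuation).
// Writes into `out`, which must hold at least 10 bytes; returns the size.
static uint32_t encode_varint_bmi2(uint64_t n, uint8_t* out) {
  const uint64_t payload = 0x7f7f7f7f7f7f7f7fULL;  // low 7 bits of each byte
  unsigned bits = 64 - static_cast<unsigned>(_lzcnt_u64(n | 1));  // significant bits, >= 1
  uint32_t size = (bits + 6) / 7;                                 // bytes needed
  if (size <= 8) {
    uint64_t scattered = _pdep_u64(n, payload);      // one 7-bit group per byte
    // Continuation bit on every byte except the last one.
    scattered |= 0x8080808080808080ULL & ((1ULL << (8 * (size - 1))) - 1);
    std::memcpy(out, &scattered, 8);                 // only the first `size` bytes matter
    return size;
  }
  uint64_t scattered = _pdep_u64(n, payload) | 0x8080808080808080ULL;
  std::memcpy(out, &scattered, 8);
  out[8] = static_cast<uint8_t>(n >> 56);            // bits 56..62 (and 63 if set)
  if (size == 9) return 9;
  out[8] |= 0x80;
  out[9] = static_cast<uint8_t>(n >> 63);            // bit 63 only
  return 10;
}

// Decode a varint of `size` bytes (size <= 8 kept short for the sketch).
static uint64_t decode_varint_bmi2(const uint8_t* in, uint32_t size) {
  assert(size >= 1 && size <= 8);
  uint64_t raw = 0;
  std::memcpy(&raw, in, 8);
  raw &= ~0ULL >> (8 * (8 - size));                  // keep only `size` bytes
  return _pext_u64(raw, 0x7f7f7f7f7f7f7f7fULL);      // gather the 7-bit groups
}

int main() {
  uint8_t buf[16] = {0};
  uint64_t n = 0x123456789abcULL;                    // 45 significant bits -> 7-byte varint
  uint32_t size = encode_varint_bmi2(n, buf);
  assert(decode_varint_bmi2(buf, size) == n);
  std::printf("encoded in %u bytes\n", size);
  return 0;
}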