diff --git a/lib/marisa/grimoire/intrin.h b/lib/marisa/grimoire/intrin.h index 77b4e99..7cea33b 100644 --- a/lib/marisa/grimoire/intrin.h +++ b/lib/marisa/grimoire/intrin.h @@ -135,4 +135,9 @@ #endif // MARISA_WORD_SIZE == 64 #endif // _MSC_VER +#if defined(__aarch64__) + #define MARISA_AARCH64 + #include +#endif + #endif // MARISA_GRIMOIRE_INTRIN_H_ diff --git a/lib/marisa/grimoire/vector/bit-vector.cc b/lib/marisa/grimoire/vector/bit-vector.cc index 3bb8b52..73f1c4c 100644 --- a/lib/marisa/grimoire/vector/bit-vector.cc +++ b/lib/marisa/grimoire/vector/bit-vector.cc @@ -173,9 +173,33 @@ const UInt64 MASK_0F = 0x0F0F0F0F0F0F0F0FULL; const UInt64 MASK_33 = 0x3333333333333333ULL; const UInt64 MASK_55 = 0x5555555555555555ULL; #endif // !defined(MARISA_X64) || !defined(MARISA_USE_SSSE3) - #if !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT) const UInt64 MASK_80 = 0x8080808080808080ULL; - #endif // !defined(MARISA_X64) || !defined(MARISA_USE_POPCNT) + +// Pre-computed lookup table trick from Gog, Simon and Matthias Petri. +// "Optimized succinct data structures for massive data." Software: +// Practice and Experience 44 (2014): 1287 - 1314. +// PREFIX_SUM_OVERFLOW[i] = (0x7F - i) * MASK_01. +const UInt64 PREFIX_SUM_OVERFLOW[64] = { + 0x7F * MASK_01, 0x7E * MASK_01, 0x7D * MASK_01, 0x7C * MASK_01, + 0x7B * MASK_01, 0x7A * MASK_01, 0x79 * MASK_01, 0x78 * MASK_01, + 0x77 * MASK_01, 0x76 * MASK_01, 0x75 * MASK_01, 0x74 * MASK_01, + 0x73 * MASK_01, 0x72 * MASK_01, 0x71 * MASK_01, 0x70 * MASK_01, + + 0x6F * MASK_01, 0x6E * MASK_01, 0x6D * MASK_01, 0x6C * MASK_01, + 0x6B * MASK_01, 0x6A * MASK_01, 0x69 * MASK_01, 0x68 * MASK_01, + 0x67 * MASK_01, 0x66 * MASK_01, 0x65 * MASK_01, 0x64 * MASK_01, + 0x63 * MASK_01, 0x62 * MASK_01, 0x61 * MASK_01, 0x60 * MASK_01, + + 0x5F * MASK_01, 0x5E * MASK_01, 0x5D * MASK_01, 0x5C * MASK_01, + 0x5B * MASK_01, 0x5A * MASK_01, 0x59 * MASK_01, 0x58 * MASK_01, + 0x57 * MASK_01, 0x56 * MASK_01, 0x55 * MASK_01, 0x54 * MASK_01, + 0x53 * MASK_01, 0x52 * MASK_01, 0x51 * MASK_01, 0x50 * MASK_01, + + 0x4F * MASK_01, 0x4E * MASK_01, 0x4D * MASK_01, 0x4C * MASK_01, + 0x4B * MASK_01, 0x4A * MASK_01, 0x49 * MASK_01, 0x48 * MASK_01, + 0x47 * MASK_01, 0x46 * MASK_01, 0x45 * MASK_01, 0x44 * MASK_01, + 0x43 * MASK_01, 0x42 * MASK_01, 0x41 * MASK_01, 0x40 * MASK_01 +}; std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) { UInt64 counts; @@ -196,11 +220,16 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) { counts = static_cast(_mm_cvtsi128_si64( _mm_add_epi8(lower_counts, upper_counts))); - #else // defined(MARISA_X64) && defined(MARISA_USE_SSSE3) + #elif defined(MARISA_AARCH64) + // Byte-wise popcount using CNT (plus a lot of conversion noise). + // This actually only requires NEON, not AArch64, but we are already + // in a 64-bit `#ifdef`. + counts = vget_lane_u64(vreinterpret_u64_u8(vcnt_u8(vcreate_u8(unit))), 0); + #else // defined(MARISA_AARCH64) counts = unit - ((unit >> 1) & MASK_55); counts = (counts & MASK_33) + ((counts >> 2) & MASK_33); counts = (counts + (counts >> 4)) & MASK_0F; - #endif // defined(MARISA_X64) && defined(MARISA_USE_SSSE3) + #endif // defined(MARISA_AARCH64) counts *= MASK_01; } @@ -213,12 +242,17 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) { skip = (UInt8)PopCount::count(static_cast(_mm_cvtsi128_si64(x))); } #else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) - const UInt64 x = (counts | MASK_80) - ((i + 1) * MASK_01); + const UInt64 x = (counts + PREFIX_SUM_OVERFLOW[i]) & MASK_80; + // We masked with `MASK_80`, so the first bit set is the high bit in the + // byte, therefore `num_trailing_zeros == 8 * byte_nr + 7` and the byte + // number is the number of trailing zeros divided by 8. We just shift off + // the low 7 bits, so `CTZ` gives us the `skip` value we want for the + // number of bits of `counts` to shift. #ifdef _MSC_VER unsigned long skip; - ::_BitScanForward64(&skip, (x & MASK_80) >> 7); + ::_BitScanForward64(&skip, x >> 7); #else // _MSC_VER - const int skip = ::__builtin_ctzll((x & MASK_80) >> 7); + const int skip = ::__builtin_ctzll(x >> 7); #endif // _MSC_VER #endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) @@ -230,7 +264,8 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt64 unit) { } #else // MARISA_WORD_SIZE == 64 #ifdef MARISA_USE_SSE2 -const UInt8 POPCNT_TABLE[256] = { +// Popcount of the byte times eight. +const UInt8 POPCNT_X8_TABLE[256] = { 0, 8, 8, 16, 8, 16, 16, 24, 8, 16, 16, 24, 16, 24, 24, 32, 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, 8, 16, 16, 24, 16, 24, 24, 32, 16, 24, 24, 32, 24, 32, 32, 40, @@ -315,7 +350,10 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id, { __m128i x = _mm_set1_epi8((UInt8)(i + 1)); x = _mm_cmpgt_epi8(x, accumulated_counts); - skip = POPCNT_TABLE[_mm_movemask_epi8(x)]; + // Since we use `_mm_movemask_epi8`, to move the top bit of every byte, + // popcount times eight gives the original popcount of `x` before the + // movemask. (`_mm_cmpgt_epi8` sets all bits in a byte to 0 or 1.) + skip = POPCNT_X8_TABLE[_mm_movemask_epi8(x)]; } UInt8 byte; @@ -340,33 +378,62 @@ std::size_t select_bit(std::size_t i, std::size_t bit_id, return bit_id + SELECT_TABLE[i][byte]; } #else // MARISA_USE_SSE2 +const UInt8 POPCNT_TABLE[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8 +}; + std::size_t select_bit(std::size_t i, std::size_t bit_id, UInt32 unit_lo, UInt32 unit_hi) { - UInt32 unit = unit_lo; - PopCount count(unit); - if (i >= count.lo32()) { - bit_id += 32; - i -= count.lo32(); - unit = unit_hi; - count = PopCount(unit); - } - - if (i < count.lo16()) { - if (i >= count.lo8()) { - bit_id += 8; - unit >>= 8; - i -= count.lo8(); - } - } else if (i < count.lo24()) { - bit_id += 16; - unit >>= 16; - i -= count.lo16(); - } else { - bit_id += 24; - unit >>= 24; - i -= count.lo24(); - } - return bit_id + SELECT_TABLE[i][unit & 0xFF]; + UInt32 next_byte = unit_lo & 0xFF; + UInt32 byte_popcount = POPCNT_TABLE[next_byte]; + // Assuming the desired bit is in a random byte, branches are not + // taken 7/8 of the time, so this is branch-predictor friendly, + // unlike binary search. + if (i < byte_popcount) return bit_id + SELECT_TABLE[i][next_byte]; + i -= byte_popcount; + next_byte = (unit_lo >> 8) & 0xFF; + byte_popcount = POPCNT_TABLE[next_byte]; + if (i < byte_popcount) return bit_id + 8 + SELECT_TABLE[i][next_byte]; + i -= byte_popcount; + next_byte = (unit_lo >> 16) & 0xFF; + byte_popcount = POPCNT_TABLE[next_byte]; + if (i < byte_popcount) return bit_id + 16 + SELECT_TABLE[i][next_byte]; + i -= byte_popcount; + next_byte = unit_lo >> 24; + byte_popcount = POPCNT_TABLE[next_byte]; + if (i < byte_popcount) return bit_id + 24 + SELECT_TABLE[i][next_byte]; + i -= byte_popcount; + + next_byte = unit_hi & 0xFF; + byte_popcount = POPCNT_TABLE[next_byte]; + if (i < byte_popcount) return bit_id + 32 + SELECT_TABLE[i][next_byte]; + i -= byte_popcount; + next_byte = (unit_hi >> 8) & 0xFF; + byte_popcount = POPCNT_TABLE[next_byte]; + if (i < byte_popcount) return bit_id + 40 + SELECT_TABLE[i][next_byte]; + i -= byte_popcount; + next_byte = (unit_hi >> 16) & 0xFF; + byte_popcount = POPCNT_TABLE[next_byte]; + if (i < byte_popcount) return bit_id + 48 + SELECT_TABLE[i][next_byte]; + i -= byte_popcount; + next_byte = unit_hi >> 24; + // Assume `i < POPCNT_TABLE[next_byte]`. + return bit_id + 56 + SELECT_TABLE[i][next_byte]; } #endif // MARISA_USE_SSE2 diff --git a/lib/marisa/grimoire/vector/pop-count.h b/lib/marisa/grimoire/vector/pop-count.h index 47f4b5d..8347bd4 100644 --- a/lib/marisa/grimoire/vector/pop-count.h +++ b/lib/marisa/grimoire/vector/pop-count.h @@ -51,9 +51,12 @@ class PopCount { #else // _MSC_VER return static_cast(_mm_popcnt_u64(x)); #endif // _MSC_VER -#else // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) +#elif defined(MARISA_AARCH64) + // Byte-wise popcount followed by horizontal add. + return vaddv_u8(vcnt_u8(vcreate_u8(x))); +#else // defined(MARISA_AARCH64) return PopCount(x).lo64(); -#endif // defined(MARISA_X64) && defined(MARISA_USE_POPCNT) +#endif // defined(MARISA_AARCH64) } private: