From 5c13966f0cd08c963fa729f2c26e9943a13ede5d Mon Sep 17 00:00:00 2001 From: Andre Bogus Date: Sun, 27 Aug 2023 02:02:23 +0200 Subject: [PATCH 1/3] add aarch64 --- src/lib.rs | 35 +++++++++-- src/simd/aarch64.rs | 139 ++++++++++++++++++++++++++++++++++++++++++++ src/simd/mod.rs | 4 ++ 3 files changed, 172 insertions(+), 6 deletions(-) create mode 100644 src/simd/aarch64.rs diff --git a/src/lib.rs b/src/lib.rs index ef4235c..24f4018 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,6 @@ //! still on small strings. #![deny(missing_docs)] - #![cfg_attr(not(feature = "runtime-dispatch-simd"), no_std)] #[cfg(not(feature = "runtime-dispatch-simd"))] @@ -45,7 +44,11 @@ pub use naive::*; mod integer_simd; #[cfg(any( - all(feature = "runtime-dispatch-simd", any(target_arch = "x86", target_arch = "x86_64")), + all( + feature = "runtime-dispatch-simd", + any(target_arch = "x86", target_arch = "x86_64") + ), + target_arch = "aarch64", feature = "generic-simd" ))] mod simd; @@ -64,7 +67,9 @@ pub fn count(haystack: &[u8], needle: u8) -> usize { #[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))] { if is_x86_feature_detected!("avx2") { - unsafe { return simd::x86_avx2::chunk_count(haystack, needle); } + unsafe { + return simd::x86_avx2::chunk_count(haystack, needle); + } } } @@ -80,7 +85,15 @@ pub fn count(haystack: &[u8], needle: u8) -> usize { ))] { if is_x86_feature_detected!("sse2") { - unsafe { return simd::x86_sse2::chunk_count(haystack, needle); } + unsafe { + return simd::x86_sse2::chunk_count(haystack, needle); + } + } + } + #[cfg(all(target_arch = "aarch64", not(feature = "generic_simd")))] + { + unsafe { + return simd::aarch64::chunk_count(haystack, needle); } } } @@ -109,7 +122,9 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize { #[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))] { if is_x86_feature_detected!("avx2") { - unsafe { return simd::x86_avx2::chunk_num_chars(utf8_chars); } + unsafe { + return simd::x86_avx2::chunk_num_chars(utf8_chars); + } } } @@ -125,7 +140,15 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize { ))] { if is_x86_feature_detected!("sse2") { - unsafe { return simd::x86_sse2::chunk_num_chars(utf8_chars); } + unsafe { + return simd::x86_sse2::chunk_num_chars(utf8_chars); + } + } + } + #[cfg(all(target_arch = "aarch64", not(feature = "generic_simd")))] + { + unsafe { + return simd::aarch64::chunk_num_chars(utf8_chars); } } } diff --git a/src/simd/aarch64.rs b/src/simd/aarch64.rs new file mode 100644 index 0000000..56e8b71 --- /dev/null +++ b/src/simd/aarch64.rs @@ -0,0 +1,139 @@ +use core::arch::aarch64::{ + uint8x16_t, vaddlvq_u8, vandq_u8, vceqq_u8, vcgtq_u8, vdupq_n_u8, vld1q_u8, vmvnq_u8, vsubq_u8, +}; + +const MASK: [u8; 32] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, +]; + +#[target_feature(enable = "neon")] +unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t { + vld1q_u8(slice.as_ptr().add(offset) as *const _) // TODO: does this need to be aligned? +} + +#[target_feature(enable = "neon")] +unsafe fn sum(u8s: &uint8x16_t) -> usize { + vaddlvq_u8(*u8s) as usize +} + +#[target_feature(enable = "neon")] +pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { + assert!(haystack.len() >= 16); + + let mut offset = 0; + let mut count = 0; + + let needles = vdupq_n_u8(needle); + + // 4080 + while haystack.len() >= offset + 16 * 255 { + let mut counts = vdupq_n_u8(0); + for _ in 0..255 { + counts = vsubq_u8( + counts, + vceqq_u8(u8x16_from_offset(haystack, offset), needles), + ); + offset += 16; + } + count += sum(&counts); + } + + // 2048 + if haystack.len() >= offset + 16 * 128 { + let mut counts = vdupq_n_u8(0); + for _ in 0..128 { + counts = vsubq_u8( + counts, + vceqq_u8(u8x16_from_offset(haystack, offset), needles), + ); + offset += 16; + } + count += sum(&counts); + } + + // 16 + let mut counts = vdupq_n_u8(0); + for i in 0..(haystack.len() - offset) / 16 { + counts = vsubq_u8( + counts, + vcgtq_u8(u8x16_from_offset(haystack, offset + i * 32), needles), + ); + } + if haystack.len() % 16 != 0 { + counts = vsubq_u8( + counts, + vandq_u8( + vceqq_u8(u8x16_from_offset(haystack, haystack.len() - 16), needles), + u8x16_from_offset(&MASK, haystack.len() % 16), + ), + ); + } + count += sum(&counts); + + count +} + +#[target_feature(enable = "neon")] +unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t { + vmvnq_u8(vceqq_u8( + vandq_u8(u8s, vdupq_n_u8(0b1100_0000)), + vdupq_n_u8(0b1000_0000), + )) +} + +#[target_feature(enable = "neon")] +pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { + assert!(utf8_chars.len() >= 16); + + let mut offset = 0; + let mut count = 0; + + // 4080 + while utf8_chars.len() >= offset + 16 * 255 { + let mut counts = vdupq_n_u8(0); + + for _ in 0..255 { + counts = vsubq_u8( + counts, + is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)), + ); + offset += 16; + } + count += sum(&counts); + } + + // 2048 + if utf8_chars.len() >= offset + 16 * 128 { + let mut counts = vdupq_n_u8(0); + for _ in 0..128 { + counts = vsubq_u8( + counts, + is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)), + ); + offset += 16; + } + count += sum(&counts); + } + + // 16 + let mut counts = vdupq_n_u8(0); + for i in 0..(utf8_chars.len() - offset) / 16 { + counts = vsubq_u8( + counts, + is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 32)), + ); + } + if utf8_chars.len() % 16 != 0 { + counts = vsubq_u8( + counts, + vandq_u8( + is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)), + u8x16_from_offset(&MASK, utf8_chars.len() % 16), + ), + ); + } + count += sum(&counts); + + count +} diff --git a/src/simd/mod.rs b/src/simd/mod.rs index d144e18..fa98575 100644 --- a/src/simd/mod.rs +++ b/src/simd/mod.rs @@ -15,3 +15,7 @@ pub mod x86_sse2; // Runtime feature detection is not available with no_std. #[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))] pub mod x86_avx2; + +/// Modern ARM machines are also quite capable thanks to NEON +#[cfg(target_arch = "aarch64")] +pub mod aarch64; From 5dc85ad00d2f66c2d33191cc5a5f7e1da6d392a7 Mon Sep 17 00:00:00 2001 From: Andre Bogus Date: Sun, 27 Aug 2023 02:02:39 +0200 Subject: [PATCH 2/3] rustfmt --- src/integer_simd.rs | 20 ++++++++++++----- src/naive.rs | 9 ++++++-- src/simd/generic.rs | 20 +++++++++-------- src/simd/x86_avx2.rs | 52 ++++++++++++++++++++------------------------ tests/check.rs | 7 +----- 5 files changed, 57 insertions(+), 51 deletions(-) diff --git a/src/integer_simd.rs b/src/integer_simd.rs index 48f2ee8..0604194 100644 --- a/src/integer_simd.rs +++ b/src/integer_simd.rs @@ -13,7 +13,7 @@ unsafe fn usize_load_unchecked(bytes: &[u8], offset: usize) -> usize { ptr::copy_nonoverlapping( bytes.as_ptr().add(offset), &mut output as *mut usize as *mut u8, - mem::size_of::() + mem::size_of::(), ); output } @@ -65,11 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize { // 8 let mut counts = 0; for i in 0..(haystack.len() - offset) / chunksize { - counts += bytewise_equal(usize_load_unchecked(haystack, offset + i * chunksize), needles); + counts += bytewise_equal( + usize_load_unchecked(haystack, offset + i * chunksize), + needles, + ); } if haystack.len() % 8 != 0 { let mask = usize::from_le(!(!0 >> ((haystack.len() % chunksize) * 8))); - counts += bytewise_equal(usize_load_unchecked(haystack, haystack.len() - chunksize), needles) & mask; + counts += bytewise_equal( + usize_load_unchecked(haystack, haystack.len() - chunksize), + needles, + ) & mask; } count += sum_usize(counts); @@ -98,11 +104,15 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize { // 8 let mut counts = 0; for i in 0..(utf8_chars.len() - offset) / chunksize { - counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize)); + counts += + is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize)); } if utf8_chars.len() % 8 != 0 { let mask = usize::from_le(!(!0 >> ((utf8_chars.len() % chunksize) * 8))); - counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, utf8_chars.len() - chunksize)) & mask; + counts += is_leading_utf8_byte(usize_load_unchecked( + utf8_chars, + utf8_chars.len() - chunksize, + )) & mask; } count += sum_usize(counts); diff --git a/src/naive.rs b/src/naive.rs index 315c4b6..e3f6cf6 100644 --- a/src/naive.rs +++ b/src/naive.rs @@ -22,7 +22,9 @@ pub fn naive_count_32(haystack: &[u8], needle: u8) -> usize { /// assert_eq!(number_of_spaces, 6); /// ``` pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize { - utf8_chars.iter().fold(0, |n, c| n + (*c == needle) as usize) + utf8_chars + .iter() + .fold(0, |n, c| n + (*c == needle) as usize) } /// Count the number of UTF-8 encoded Unicode codepoints in a slice of bytes, simple @@ -38,5 +40,8 @@ pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize { /// assert_eq!(char_count, 4); /// ``` pub fn naive_num_chars(utf8_chars: &[u8]) -> usize { - utf8_chars.iter().filter(|&&byte| (byte >> 6) != 0b10).count() + utf8_chars + .iter() + .filter(|&&byte| (byte >> 6) != 0b10) + .count() } diff --git a/src/simd/generic.rs b/src/simd/generic.rs index 2031e73..640ccd8 100644 --- a/src/simd/generic.rs +++ b/src/simd/generic.rs @@ -8,10 +8,9 @@ use std::mem; use self::packed_simd::{u8x32, u8x64, FromCast}; const MASK: [u8; 64] = [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]; unsafe fn u8x64_from_offset(slice: &[u8], offset: usize) -> u8x64 { @@ -66,15 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize { // 32 let mut counts = u8x32::splat(0); for i in 0..(haystack.len() - offset) / 32 { - counts -= u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32)); + counts -= + u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32)); } count += sum_x32(&counts); // Straggler; need to reset counts because prior loop can run 255 times counts = u8x32::splat(0); if haystack.len() % 32 != 0 { - counts -= u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32)) & - u8x32_from_offset(&MASK, haystack.len() % 32); + counts -= + u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32)) + & u8x32_from_offset(&MASK, haystack.len() % 32); } count += sum_x32(&counts); @@ -127,8 +128,9 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize { // Straggler; need to reset counts because prior loop can run 255 times counts = u8x32::splat(0); if utf8_chars.len() % 32 != 0 { - counts -= is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32)) & - u8x32_from_offset(&MASK, utf8_chars.len() % 32); + counts -= + is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32)) + & u8x32_from_offset(&MASK, utf8_chars.len() % 32); } count += sum_x32(&counts); diff --git a/src/simd/x86_avx2.rs b/src/simd/x86_avx2.rs index 90a55c0..ea191e2 100644 --- a/src/simd/x86_avx2.rs +++ b/src/simd/x86_avx2.rs @@ -1,14 +1,6 @@ use std::arch::x86_64::{ - __m256i, - _mm256_and_si256, - _mm256_cmpeq_epi8, - _mm256_extract_epi64, - _mm256_loadu_si256, - _mm256_sad_epu8, - _mm256_set1_epi8, - _mm256_setzero_si256, - _mm256_sub_epi8, - _mm256_xor_si256, + __m256i, _mm256_and_si256, _mm256_cmpeq_epi8, _mm256_extract_epi64, _mm256_loadu_si256, + _mm256_sad_epu8, _mm256_set1_epi8, _mm256_setzero_si256, _mm256_sub_epi8, _mm256_xor_si256, }; #[target_feature(enable = "avx2")] @@ -22,10 +14,9 @@ pub unsafe fn mm256_cmpneq_epi8(a: __m256i, b: __m256i) -> __m256i { } const MASK: [u8; 64] = [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]; #[target_feature(enable = "avx2")] @@ -36,10 +27,10 @@ unsafe fn mm256_from_offset(slice: &[u8], offset: usize) -> __m256i { #[target_feature(enable = "avx2")] unsafe fn sum(u8s: &__m256i) -> usize { let sums = _mm256_sad_epu8(*u8s, _mm256_setzero_si256()); - ( - _mm256_extract_epi64(sums, 0) + _mm256_extract_epi64(sums, 1) + - _mm256_extract_epi64(sums, 2) + _mm256_extract_epi64(sums, 3) - ) as usize + (_mm256_extract_epi64(sums, 0) + + _mm256_extract_epi64(sums, 1) + + _mm256_extract_epi64(sums, 2) + + _mm256_extract_epi64(sums, 3)) as usize } #[target_feature(enable = "avx2")] @@ -57,7 +48,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { for _ in 0..255 { counts = _mm256_sub_epi8( counts, - _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles) + _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles), ); offset += 32; } @@ -70,7 +61,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { for _ in 0..128 { counts = _mm256_sub_epi8( counts, - _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles) + _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles), ); offset += 32; } @@ -82,7 +73,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { for i in 0..(haystack.len() - offset) / 32 { counts = _mm256_sub_epi8( counts, - _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset + i * 32), needles) + _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset + i * 32), needles), ); } if haystack.len() % 32 != 0 { @@ -90,8 +81,8 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { counts, _mm256_and_si256( _mm256_cmpeq_epi8(mm256_from_offset(haystack, haystack.len() - 32), needles), - mm256_from_offset(&MASK, haystack.len() % 32) - ) + mm256_from_offset(&MASK, haystack.len() % 32), + ), ); } count += sum(&counts); @@ -101,7 +92,10 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { #[target_feature(enable = "avx2")] unsafe fn is_leading_utf8_byte(u8s: __m256i) -> __m256i { - mm256_cmpneq_epi8(_mm256_and_si256(u8s, _mm256_set1_epu8(0b1100_0000)), _mm256_set1_epu8(0b1000_0000)) + mm256_cmpneq_epi8( + _mm256_and_si256(u8s, _mm256_set1_epu8(0b1100_0000)), + _mm256_set1_epu8(0b1000_0000), + ) } #[target_feature(enable = "avx2")] @@ -118,7 +112,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { for _ in 0..255 { counts = _mm256_sub_epi8( counts, - is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)) + is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)), ); offset += 32; } @@ -131,7 +125,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { for _ in 0..128 { counts = _mm256_sub_epi8( counts, - is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)) + is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)), ); offset += 32; } @@ -143,7 +137,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { for i in 0..(utf8_chars.len() - offset) / 32 { counts = _mm256_sub_epi8( counts, - is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset + i * 32)) + is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset + i * 32)), ); } if utf8_chars.len() % 32 != 0 { @@ -151,8 +145,8 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { counts, _mm256_and_si256( is_leading_utf8_byte(mm256_from_offset(utf8_chars, utf8_chars.len() - 32)), - mm256_from_offset(&MASK, utf8_chars.len() % 32) - ) + mm256_from_offset(&MASK, utf8_chars.len() % 32), + ), ); } count += sum(&counts); diff --git a/tests/check.rs b/tests/check.rs index 147b466..5a99950 100644 --- a/tests/check.rs +++ b/tests/check.rs @@ -3,10 +3,7 @@ extern crate bytecount; extern crate quickcheck; extern crate rand; -use bytecount::{ - count, naive_count, - num_chars, naive_num_chars, -}; +use bytecount::{count, naive_count, naive_num_chars, num_chars}; use rand::RngCore; fn random_bytes(len: usize) -> Vec { @@ -59,8 +56,6 @@ fn check_count_overflow_many() { } } - - quickcheck! { fn check_num_chars_correct(haystack: Vec) -> bool { num_chars(&haystack) == naive_num_chars(&haystack) From ffd810aec2b0360b81d0e3c8be335ca37b92e2d6 Mon Sep 17 00:00:00 2001 From: Andre Bogus Date: Sun, 27 Aug 2023 02:06:19 +0200 Subject: [PATCH 3/3] add aarch64 to CI matrix --- .github/workflows/ci.yml | 1 + src/simd/aarch64.rs | 82 ++++++++++++++++++++++++---------------- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 25c8e55..678caf5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,6 +22,7 @@ jobs: arch: - i686 - x86_64 + - aarch64 features: - default - runtime-dispatch-simd diff --git a/src/simd/aarch64.rs b/src/simd/aarch64.rs index 56e8b71..6544355 100644 --- a/src/simd/aarch64.rs +++ b/src/simd/aarch64.rs @@ -1,5 +1,6 @@ use core::arch::aarch64::{ - uint8x16_t, vaddlvq_u8, vandq_u8, vceqq_u8, vcgtq_u8, vdupq_n_u8, vld1q_u8, vmvnq_u8, vsubq_u8, + uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4, + vmvnq_u8, vsubq_u8, }; const MASK: [u8; 32] = [ @@ -9,12 +10,29 @@ const MASK: [u8; 32] = [ #[target_feature(enable = "neon")] unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t { + debug_assert!( + offset + 16 <= slice.len(), + "{} + 16 ≥ {}", + offset, + slice.len() + ); vld1q_u8(slice.as_ptr().add(offset) as *const _) // TODO: does this need to be aligned? } #[target_feature(enable = "neon")] -unsafe fn sum(u8s: &uint8x16_t) -> usize { - vaddlvq_u8(*u8s) as usize +unsafe fn u8x16_x4_from_offset(slice: &[u8], offset: usize) -> uint8x16x4_t { + debug_assert!( + offset + 64 <= slice.len(), + "{} + 64 ≥ {}", + offset, + slice.len() + ); + vld1q_u8_x4(slice.as_ptr().add(offset) as *const _) +} + +#[target_feature(enable = "neon")] +unsafe fn sum(u8s: uint8x16_t) -> usize { + vaddlvq_u8(u8s) as usize } #[target_feature(enable = "neon")] @@ -26,38 +44,40 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { let needles = vdupq_n_u8(needle); - // 4080 - while haystack.len() >= offset + 16 * 255 { - let mut counts = vdupq_n_u8(0); + // 16320 + while haystack.len() >= offset + 64 * 255 { + let (mut count1, mut count2, mut count3, mut count4) = + (vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0)); for _ in 0..255 { - counts = vsubq_u8( - counts, - vceqq_u8(u8x16_from_offset(haystack, offset), needles), - ); - offset += 16; + let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset); + count1 = vsubq_u8(count1, vceqq_u8(h1, needles)); + count2 = vsubq_u8(count2, vceqq_u8(h2, needles)); + count3 = vsubq_u8(count3, vceqq_u8(h3, needles)); + count4 = vsubq_u8(count4, vceqq_u8(h4, needles)); + offset += 64; } - count += sum(&counts); + count += sum(count1) + sum(count2) + sum(count3) + sum(count4); } - // 2048 - if haystack.len() >= offset + 16 * 128 { - let mut counts = vdupq_n_u8(0); - for _ in 0..128 { - counts = vsubq_u8( - counts, - vceqq_u8(u8x16_from_offset(haystack, offset), needles), - ); - offset += 16; - } - count += sum(&counts); + // 64 + let (mut count1, mut count2, mut count3, mut count4) = + (vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0)); + for _ in 0..(haystack.len() - offset) / 64 { + let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset); + count1 = vsubq_u8(count1, vceqq_u8(h1, needles)); + count2 = vsubq_u8(count2, vceqq_u8(h2, needles)); + count3 = vsubq_u8(count3, vceqq_u8(h3, needles)); + count4 = vsubq_u8(count4, vceqq_u8(h4, needles)); + offset += 64; } + count += sum(count1) + sum(count2) + sum(count3) + sum(count4); - // 16 let mut counts = vdupq_n_u8(0); + // 16 for i in 0..(haystack.len() - offset) / 16 { counts = vsubq_u8( counts, - vcgtq_u8(u8x16_from_offset(haystack, offset + i * 32), needles), + vceqq_u8(u8x16_from_offset(haystack, offset + i * 16), needles), ); } if haystack.len() % 16 != 0 { @@ -69,9 +89,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { ), ); } - count += sum(&counts); - - count + count + sum(counts) } #[target_feature(enable = "neon")] @@ -100,7 +118,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { ); offset += 16; } - count += sum(&counts); + count += sum(counts); } // 2048 @@ -113,7 +131,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { ); offset += 16; } - count += sum(&counts); + count += sum(counts); } // 16 @@ -121,7 +139,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { for i in 0..(utf8_chars.len() - offset) / 16 { counts = vsubq_u8( counts, - is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 32)), + is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)), ); } if utf8_chars.len() % 16 != 0 { @@ -133,7 +151,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { ), ); } - count += sum(&counts); + count += sum(counts); count }