Skip to content

Commit

Permalink
Merge pull request #82 from llogiq/aarch64
Browse files Browse the repository at this point in the history
Add aarch64 SIMD specialization
  • Loading branch information
llogiq committed Oct 1, 2023
2 parents fbad8d4 + ffd810a commit b375732
Show file tree
Hide file tree
Showing 9 changed files with 248 additions and 57 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
arch:
- i686
- x86_64
- aarch64
features:
- default
- runtime-dispatch-simd
Expand Down
20 changes: 15 additions & 5 deletions src/integer_simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ unsafe fn usize_load_unchecked(bytes: &[u8], offset: usize) -> usize {
ptr::copy_nonoverlapping(
bytes.as_ptr().add(offset),
&mut output as *mut usize as *mut u8,
mem::size_of::<usize>()
mem::size_of::<usize>(),
);
output
}
Expand Down Expand Up @@ -65,11 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize {
// 8
let mut counts = 0;
for i in 0..(haystack.len() - offset) / chunksize {
counts += bytewise_equal(usize_load_unchecked(haystack, offset + i * chunksize), needles);
counts += bytewise_equal(
usize_load_unchecked(haystack, offset + i * chunksize),
needles,
);
}
if haystack.len() % 8 != 0 {
let mask = usize::from_le(!(!0 >> ((haystack.len() % chunksize) * 8)));
counts += bytewise_equal(usize_load_unchecked(haystack, haystack.len() - chunksize), needles) & mask;
counts += bytewise_equal(
usize_load_unchecked(haystack, haystack.len() - chunksize),
needles,
) & mask;
}
count += sum_usize(counts);

Expand Down Expand Up @@ -98,11 +104,15 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
// 8
let mut counts = 0;
for i in 0..(utf8_chars.len() - offset) / chunksize {
counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize));
counts +=
is_leading_utf8_byte(usize_load_unchecked(utf8_chars, offset + i * chunksize));
}
if utf8_chars.len() % 8 != 0 {
let mask = usize::from_le(!(!0 >> ((utf8_chars.len() % chunksize) * 8)));
counts += is_leading_utf8_byte(usize_load_unchecked(utf8_chars, utf8_chars.len() - chunksize)) & mask;
counts += is_leading_utf8_byte(usize_load_unchecked(
utf8_chars,
utf8_chars.len() - chunksize,
)) & mask;
}
count += sum_usize(counts);

Expand Down
35 changes: 29 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
//! still on small strings.

#![deny(missing_docs)]

#![cfg_attr(not(feature = "runtime-dispatch-simd"), no_std)]

#[cfg(not(feature = "runtime-dispatch-simd"))]
Expand All @@ -45,7 +44,11 @@ pub use naive::*;
mod integer_simd;

#[cfg(any(
all(feature = "runtime-dispatch-simd", any(target_arch = "x86", target_arch = "x86_64")),
all(
feature = "runtime-dispatch-simd",
any(target_arch = "x86", target_arch = "x86_64")
),
target_arch = "aarch64",
feature = "generic-simd"
))]
mod simd;
Expand All @@ -64,7 +67,9 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { return simd::x86_avx2::chunk_count(haystack, needle); }
unsafe {
return simd::x86_avx2::chunk_count(haystack, needle);
}
}
}

Expand All @@ -80,7 +85,15 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
))]
{
if is_x86_feature_detected!("sse2") {
unsafe { return simd::x86_sse2::chunk_count(haystack, needle); }
unsafe {
return simd::x86_sse2::chunk_count(haystack, needle);
}
}
}
#[cfg(all(target_arch = "aarch64", not(feature = "generic_simd")))]
{
unsafe {
return simd::aarch64::chunk_count(haystack, needle);
}
}
}
Expand Down Expand Up @@ -109,7 +122,9 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { return simd::x86_avx2::chunk_num_chars(utf8_chars); }
unsafe {
return simd::x86_avx2::chunk_num_chars(utf8_chars);
}
}
}

Expand All @@ -125,7 +140,15 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
))]
{
if is_x86_feature_detected!("sse2") {
unsafe { return simd::x86_sse2::chunk_num_chars(utf8_chars); }
unsafe {
return simd::x86_sse2::chunk_num_chars(utf8_chars);
}
}
}
#[cfg(all(target_arch = "aarch64", not(feature = "generic_simd")))]
{
unsafe {
return simd::aarch64::chunk_num_chars(utf8_chars);
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions src/naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ pub fn naive_count_32(haystack: &[u8], needle: u8) -> usize {
/// assert_eq!(number_of_spaces, 6);
/// ```
pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize {
utf8_chars.iter().fold(0, |n, c| n + (*c == needle) as usize)
utf8_chars
.iter()
.fold(0, |n, c| n + (*c == needle) as usize)
}

/// Count the number of UTF-8 encoded Unicode codepoints in a slice of bytes, simple
Expand All @@ -38,5 +40,8 @@ pub fn naive_count(utf8_chars: &[u8], needle: u8) -> usize {
/// assert_eq!(char_count, 4);
/// ```
pub fn naive_num_chars(utf8_chars: &[u8]) -> usize {
utf8_chars.iter().filter(|&&byte| (byte >> 6) != 0b10).count()
utf8_chars
.iter()
.filter(|&&byte| (byte >> 6) != 0b10)
.count()
}
157 changes: 157 additions & 0 deletions src/simd/aarch64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
use core::arch::aarch64::{
uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4,
vmvnq_u8, vsubq_u8,
};

/// Sliding byte mask: sixteen `0` bytes followed by sixteen `255` bytes.
///
/// A 16-byte window read from this table starting at index `r` (with
/// `0 < r < 16`) holds `16 - r` zero bytes followed by `r` bytes of `255`.
/// The counting routines in this file read such a window at `len % 16` and
/// AND it with the result computed for the final 16 input bytes, so only the
/// trailing `len % 16` bytes of that final load contribute to the count.
const MASK: [u8; 32] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255,
];

/// Loads the 16 bytes `slice[offset..offset + 16]` into one NEON register.
///
/// # Safety
///
/// `offset + 16` must not exceed `slice.len()`. The bound is only verified by
/// a `debug_assert!`, so builds without debug assertions perform the read
/// unchecked. The executing CPU must support NEON.
#[target_feature(enable = "neon")]
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t {
    let available = slice.len();
    debug_assert!(
        offset + 16 <= available,
        "{} + 16 ≥ {}",
        offset,
        available
    );
    // NOTE(review): `vld1q_u8` takes a plain `*const u8`; it is believed not to
    // require 16-byte alignment — confirm against the intrinsic documentation.
    let start = slice.as_ptr().add(offset);
    vld1q_u8(start)
}

/// Loads the 64 bytes `slice[offset..offset + 64]` as four consecutive
/// 16-byte NEON registers.
///
/// # Safety
///
/// `offset + 64` must not exceed `slice.len()`. The bound is only verified by
/// a `debug_assert!`, so builds without debug assertions perform the read
/// unchecked. The executing CPU must support NEON.
#[target_feature(enable = "neon")]
unsafe fn u8x16_x4_from_offset(slice: &[u8], offset: usize) -> uint8x16x4_t {
    let available = slice.len();
    debug_assert!(
        offset + 64 <= available,
        "{} + 64 ≥ {}",
        offset,
        available
    );
    let start = slice.as_ptr().add(offset);
    vld1q_u8_x4(start)
}

/// Adds up all sixteen `u8` lanes of `u8s` and returns the total as `usize`.
///
/// The horizontal add widens to 16 bits, so the largest possible total
/// (16 * 255 = 4080) is represented exactly.
///
/// # Safety
///
/// The executing CPU must support NEON.
#[target_feature(enable = "neon")]
unsafe fn sum(u8s: uint8x16_t) -> usize {
    let widened_total = vaddlvq_u8(u8s);
    usize::from(widened_total)
}

/// Counts how many bytes of `haystack` are equal to `needle` using NEON.
///
/// Per-lane `u8` counters are incremented by subtracting the comparison mask
/// (`0xFF`, i.e. -1, on a match). Each lane grows by at most one per
/// iteration, so the counters are folded into the scalar total before they
/// could exceed 255.
///
/// # Panics
///
/// Panics if `haystack` is shorter than 16 bytes.
///
/// # Safety
///
/// The executing CPU must support NEON.
#[target_feature(enable = "neon")]
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
    assert!(haystack.len() >= 16);

    let len = haystack.len();
    let pattern = vdupq_n_u8(needle);
    let mut pos = 0;
    let mut total = 0;

    // Superblocks of 255 * 64 = 16320 bytes: exactly 255 increments per lane.
    while len >= pos + 64 * 255 {
        let mut acc_a = vdupq_n_u8(0);
        let mut acc_b = vdupq_n_u8(0);
        let mut acc_c = vdupq_n_u8(0);
        let mut acc_d = vdupq_n_u8(0);
        for _ in 0..255 {
            let uint8x16x4_t(reg_a, reg_b, reg_c, reg_d) = u8x16_x4_from_offset(haystack, pos);
            acc_a = vsubq_u8(acc_a, vceqq_u8(reg_a, pattern));
            acc_b = vsubq_u8(acc_b, vceqq_u8(reg_b, pattern));
            acc_c = vsubq_u8(acc_c, vceqq_u8(reg_c, pattern));
            acc_d = vsubq_u8(acc_d, vceqq_u8(reg_d, pattern));
            pos += 64;
        }
        total += sum(acc_a) + sum(acc_b) + sum(acc_c) + sum(acc_d);
    }

    // Remaining whole 64-byte blocks; fewer than 255 are left at this point.
    let mut acc_a = vdupq_n_u8(0);
    let mut acc_b = vdupq_n_u8(0);
    let mut acc_c = vdupq_n_u8(0);
    let mut acc_d = vdupq_n_u8(0);
    while len - pos >= 64 {
        let uint8x16x4_t(reg_a, reg_b, reg_c, reg_d) = u8x16_x4_from_offset(haystack, pos);
        acc_a = vsubq_u8(acc_a, vceqq_u8(reg_a, pattern));
        acc_b = vsubq_u8(acc_b, vceqq_u8(reg_b, pattern));
        acc_c = vsubq_u8(acc_c, vceqq_u8(reg_c, pattern));
        acc_d = vsubq_u8(acc_d, vceqq_u8(reg_d, pattern));
        pos += 64;
    }
    total += sum(acc_a) + sum(acc_b) + sum(acc_c) + sum(acc_d);

    // Remaining whole 16-byte blocks (at most three).
    let mut lanes = vdupq_n_u8(0);
    while len - pos >= 16 {
        let matches = vceqq_u8(u8x16_from_offset(haystack, pos), pattern);
        lanes = vsubq_u8(lanes, matches);
        pos += 16;
    }

    // Straggler: re-read the final 16 bytes and keep only the trailing
    // `len % 16` lanes, which are the ones not yet counted.
    let tail = len % 16;
    if tail != 0 {
        let matches = vceqq_u8(u8x16_from_offset(haystack, len - 16), pattern);
        let keep = u8x16_from_offset(&MASK, tail);
        lanes = vsubq_u8(lanes, vandq_u8(matches, keep));
    }

    total + sum(lanes)
}

/// Returns a per-lane mask that is `0xFF` where the input byte is NOT a UTF-8
/// continuation byte (top two bits `10`) and `0x00` where it is.
///
/// # Safety
///
/// The executing CPU must support NEON.
#[target_feature(enable = "neon")]
unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
    // Isolate the two most significant bits of every byte.
    let top_bits = vandq_u8(u8s, vdupq_n_u8(0b1100_0000));
    // 0xFF in every lane holding a continuation byte (0b10xx_xxxx).
    let continuation = vceqq_u8(top_bits, vdupq_n_u8(0b1000_0000));
    // Invert: everything that is not a continuation byte starts a code point.
    vmvnq_u8(continuation)
}

/// Counts the bytes of `utf8_chars` that are not UTF-8 continuation bytes,
/// i.e. the number of code points when the input is valid UTF-8, using NEON.
///
/// Per-lane `u8` counters are incremented by subtracting the `0xFF` (-1)
/// flags produced by `is_leading_utf8_byte`. Each lane grows by at most one
/// per iteration, so the counters are folded into the scalar total before
/// they could exceed 255.
///
/// # Panics
///
/// Panics if `utf8_chars` is shorter than 16 bytes.
///
/// # Safety
///
/// The executing CPU must support NEON.
#[target_feature(enable = "neon")]
pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
    assert!(utf8_chars.len() >= 16);

    let len = utf8_chars.len();
    let mut pos = 0;
    let mut total = 0;

    // Blocks of 255 * 16 = 4080 bytes: exactly 255 increments per lane.
    while len >= pos + 16 * 255 {
        let mut lanes = vdupq_n_u8(0);
        for _ in 0..255 {
            let flags = is_leading_utf8_byte(u8x16_from_offset(utf8_chars, pos));
            lanes = vsubq_u8(lanes, flags);
            pos += 16;
        }
        total += sum(lanes);
    }

    // One block of 128 * 16 = 2048 bytes, if that much is left.
    if len >= pos + 16 * 128 {
        let mut lanes = vdupq_n_u8(0);
        for _ in 0..128 {
            let flags = is_leading_utf8_byte(u8x16_from_offset(utf8_chars, pos));
            lanes = vsubq_u8(lanes, flags);
            pos += 16;
        }
        total += sum(lanes);
    }

    // Remaining whole 16-byte blocks; fewer than 128 are left at this point.
    let mut lanes = vdupq_n_u8(0);
    while len - pos >= 16 {
        let flags = is_leading_utf8_byte(u8x16_from_offset(utf8_chars, pos));
        lanes = vsubq_u8(lanes, flags);
        pos += 16;
    }

    // Straggler: re-read the final 16 bytes and keep only the trailing
    // `len % 16` lanes, which are the ones not yet counted.
    let tail = len % 16;
    if tail != 0 {
        let flags = is_leading_utf8_byte(u8x16_from_offset(utf8_chars, len - 16));
        let keep = u8x16_from_offset(&MASK, tail);
        lanes = vsubq_u8(lanes, vandq_u8(flags, keep));
    }

    total + sum(lanes)
}
20 changes: 11 additions & 9 deletions src/simd/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use std::mem;
use self::packed_simd::{u8x32, u8x64, FromCast};

const MASK: [u8; 64] = [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
];

unsafe fn u8x64_from_offset(slice: &[u8], offset: usize) -> u8x64 {
Expand Down Expand Up @@ -66,15 +65,17 @@ pub fn chunk_count(haystack: &[u8], needle: u8) -> usize {
// 32
let mut counts = u8x32::splat(0);
for i in 0..(haystack.len() - offset) / 32 {
counts -= u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32));
counts -=
u8x32::from_cast(u8x32_from_offset(haystack, offset + i * 32).eq(needles_x32));
}
count += sum_x32(&counts);

// Straggler; need to reset counts because prior loop can run 255 times
counts = u8x32::splat(0);
if haystack.len() % 32 != 0 {
counts -= u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32)) &
u8x32_from_offset(&MASK, haystack.len() % 32);
counts -=
u8x32::from_cast(u8x32_from_offset(haystack, haystack.len() - 32).eq(needles_x32))
& u8x32_from_offset(&MASK, haystack.len() % 32);
}
count += sum_x32(&counts);

Expand Down Expand Up @@ -127,8 +128,9 @@ pub fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
// Straggler; need to reset counts because prior loop can run 255 times
counts = u8x32::splat(0);
if utf8_chars.len() % 32 != 0 {
counts -= is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32)) &
u8x32_from_offset(&MASK, utf8_chars.len() % 32);
counts -=
is_leading_utf8_byte_x32(u8x32_from_offset(utf8_chars, utf8_chars.len() - 32))
& u8x32_from_offset(&MASK, utf8_chars.len() % 32);
}
count += sum_x32(&counts);

Expand Down
4 changes: 4 additions & 0 deletions src/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ pub mod x86_sse2;
// Runtime feature detection is not available with no_std.
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
pub mod x86_avx2;

/// Modern ARM machines are also quite capable thanks to NEON
#[cfg(target_arch = "aarch64")]
pub mod aarch64;
Loading

0 comments on commit b375732

Please sign in to comment.