Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve some SSE and SSE2 tests #1466

Merged
merged 5 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions crates/core_arch/src/x86/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2055,6 +2055,17 @@ mod tests {
let b = _mm_setr_ps(-100.0, 20.0, 0.0, -5.0);
let r = _mm_max_ps(a, b);
assert_eq_m128(r, _mm_setr_ps(-1.0, 20.0, 0.0, -5.0));

// Check SSE-specific semantics for -0.0 handling.
let a = _mm_setr_ps(-0.0, 0.0, 0.0, 0.0);
let b = _mm_setr_ps(0.0, 0.0, 0.0, 0.0);
let r1: [u8; 16] = transmute(_mm_max_ps(a, b));
let r2: [u8; 16] = transmute(_mm_max_ps(b, a));
let a: [u8; 16] = transmute(a);
let b: [u8; 16] = transmute(b);
assert_eq!(r1, b);
assert_eq!(r2, a);
assert_ne!(a, b); // sanity check that -0.0 is actually present
}

#[simd_test(enable = "sse")]
Expand Down
212 changes: 148 additions & 64 deletions crates/core_arch/src/x86/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3343,83 +3343,124 @@ mod tests {

#[simd_test(enable = "sse2")]
unsafe fn test_mm_slli_epi16() {
#[rustfmt::skip]
let a = _mm_setr_epi16(
0xFFFF as u16 as i16, 0x0FFF, 0x00FF, 0x000F, 0, 0, 0, 0,
);
let a = _mm_setr_epi16(0xCC, -0xCC, 0xDD, -0xDD, 0xEE, -0xEE, 0xFF, -0xFF);
let r = _mm_slli_epi16::<4>(a);

#[rustfmt::skip]
let e = _mm_setr_epi16(
0xFFF0 as u16 as i16, 0xFFF0 as u16 as i16, 0x0FF0, 0x00F0,
0, 0, 0, 0,
assert_eq_m128i(
r,
_mm_setr_epi16(0xCC0, -0xCC0, 0xDD0, -0xDD0, 0xEE0, -0xEE0, 0xFF0, -0xFF0),
);
assert_eq_m128i(r, e);
let r = _mm_slli_epi16::<16>(a);
assert_eq_m128i(r, _mm_set1_epi16(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_sll_epi16() {
let a = _mm_setr_epi16(0xFF, 0, 0, 0, 0, 0, 0, 0);
let r = _mm_sll_epi16(a, _mm_setr_epi16(4, 0, 0, 0, 0, 0, 0, 0));
assert_eq_m128i(r, _mm_setr_epi16(0xFF0, 0, 0, 0, 0, 0, 0, 0));
let r = _mm_sll_epi16(a, _mm_setr_epi16(0, 0, 0, 0, 4, 0, 0, 0));
assert_eq_m128i(r, _mm_setr_epi16(0xFF, 0, 0, 0, 0, 0, 0, 0));
let a = _mm_setr_epi16(0xCC, -0xCC, 0xDD, -0xDD, 0xEE, -0xEE, 0xFF, -0xFF);
let r = _mm_sll_epi16(a, _mm_set_epi64x(0, 4));
assert_eq_m128i(
r,
_mm_setr_epi16(0xCC0, -0xCC0, 0xDD0, -0xDD0, 0xEE0, -0xEE0, 0xFF0, -0xFF0),
);
let r = _mm_sll_epi16(a, _mm_set_epi64x(4, 0));
assert_eq_m128i(r, a);
let r = _mm_sll_epi16(a, _mm_set_epi64x(0, 16));
assert_eq_m128i(r, _mm_set1_epi16(0));
let r = _mm_sll_epi16(a, _mm_set_epi64x(0, i64::MAX));
assert_eq_m128i(r, _mm_set1_epi16(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_slli_epi32() {
let r = _mm_slli_epi32::<4>(_mm_set1_epi32(0xFFFF));
assert_eq_m128i(r, _mm_set1_epi32(0xFFFF0));
let a = _mm_setr_epi32(0xEEEE, -0xEEEE, 0xFFFF, -0xFFFF);
let r = _mm_slli_epi32::<4>(a);
assert_eq_m128i(r, _mm_setr_epi32(0xEEEE0, -0xEEEE0, 0xFFFF0, -0xFFFF0));
let r = _mm_slli_epi32::<32>(a);
assert_eq_m128i(r, _mm_set1_epi32(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_sll_epi32() {
let a = _mm_set1_epi32(0xFFFF);
let b = _mm_setr_epi32(4, 0, 0, 0);
let r = _mm_sll_epi32(a, b);
assert_eq_m128i(r, _mm_set1_epi32(0xFFFF0));
let a = _mm_setr_epi32(0xEEEE, -0xEEEE, 0xFFFF, -0xFFFF);
let r = _mm_sll_epi32(a, _mm_set_epi64x(0, 4));
assert_eq_m128i(r, _mm_setr_epi32(0xEEEE0, -0xEEEE0, 0xFFFF0, -0xFFFF0));
let r = _mm_sll_epi32(a, _mm_set_epi64x(4, 0));
assert_eq_m128i(r, a);
let r = _mm_sll_epi32(a, _mm_set_epi64x(0, 32));
assert_eq_m128i(r, _mm_set1_epi32(0));
let r = _mm_sll_epi32(a, _mm_set_epi64x(0, i64::MAX));
assert_eq_m128i(r, _mm_set1_epi32(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_slli_epi64() {
let r = _mm_slli_epi64::<4>(_mm_set1_epi64x(0xFFFFFFFF));
assert_eq_m128i(r, _mm_set1_epi64x(0xFFFFFFFF0));
let a = _mm_set_epi64x(0xFFFFFFFF, -0xFFFFFFFF);
let r = _mm_slli_epi64::<4>(a);
assert_eq_m128i(r, _mm_set_epi64x(0xFFFFFFFF0, -0xFFFFFFFF0));
let r = _mm_slli_epi64::<64>(a);
assert_eq_m128i(r, _mm_set1_epi64x(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_sll_epi64() {
let a = _mm_set1_epi64x(0xFFFFFFFF);
let b = _mm_setr_epi64x(4, 0);
let r = _mm_sll_epi64(a, b);
assert_eq_m128i(r, _mm_set1_epi64x(0xFFFFFFFF0));
let a = _mm_set_epi64x(0xFFFFFFFF, -0xFFFFFFFF);
let r = _mm_sll_epi64(a, _mm_set_epi64x(0, 4));
assert_eq_m128i(r, _mm_set_epi64x(0xFFFFFFFF0, -0xFFFFFFFF0));
let r = _mm_sll_epi64(a, _mm_set_epi64x(4, 0));
assert_eq_m128i(r, a);
let r = _mm_sll_epi64(a, _mm_set_epi64x(0, 64));
assert_eq_m128i(r, _mm_set1_epi64x(0));
let r = _mm_sll_epi64(a, _mm_set_epi64x(0, i64::MAX));
assert_eq_m128i(r, _mm_set1_epi64x(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srai_epi16() {
let r = _mm_srai_epi16::<1>(_mm_set1_epi16(-1));
assert_eq_m128i(r, _mm_set1_epi16(-1));
let a = _mm_setr_epi16(0xCC, -0xCC, 0xDD, -0xDD, 0xEE, -0xEE, 0xFF, -0xFF);
let r = _mm_srai_epi16::<4>(a);
assert_eq_m128i(
r,
_mm_setr_epi16(0xC, -0xD, 0xD, -0xE, 0xE, -0xF, 0xF, -0x10),
);
let r = _mm_srai_epi16::<16>(a);
assert_eq_m128i(r, _mm_setr_epi16(0, -1, 0, -1, 0, -1, 0, -1));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_sra_epi16() {
let a = _mm_set1_epi16(-1);
let b = _mm_setr_epi16(1, 0, 0, 0, 0, 0, 0, 0);
let r = _mm_sra_epi16(a, b);
assert_eq_m128i(r, _mm_set1_epi16(-1));
let a = _mm_setr_epi16(0xCC, -0xCC, 0xDD, -0xDD, 0xEE, -0xEE, 0xFF, -0xFF);
let r = _mm_sra_epi16(a, _mm_set_epi64x(0, 4));
assert_eq_m128i(
r,
_mm_setr_epi16(0xC, -0xD, 0xD, -0xE, 0xE, -0xF, 0xF, -0x10),
);
let r = _mm_sra_epi16(a, _mm_set_epi64x(4, 0));
assert_eq_m128i(r, a);
let r = _mm_sra_epi16(a, _mm_set_epi64x(0, 16));
assert_eq_m128i(r, _mm_setr_epi16(0, -1, 0, -1, 0, -1, 0, -1));
let r = _mm_sra_epi16(a, _mm_set_epi64x(0, i64::MAX));
assert_eq_m128i(r, _mm_setr_epi16(0, -1, 0, -1, 0, -1, 0, -1));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srai_epi32() {
let r = _mm_srai_epi32::<1>(_mm_set1_epi32(-1));
assert_eq_m128i(r, _mm_set1_epi32(-1));
let a = _mm_setr_epi32(0xEEEE, -0xEEEE, 0xFFFF, -0xFFFF);
let r = _mm_srai_epi32::<4>(a);
assert_eq_m128i(r, _mm_setr_epi32(0xEEE, -0xEEF, 0xFFF, -0x1000));
let r = _mm_srai_epi32::<32>(a);
assert_eq_m128i(r, _mm_setr_epi32(0, -1, 0, -1));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_sra_epi32() {
let a = _mm_set1_epi32(-1);
let b = _mm_setr_epi32(1, 0, 0, 0);
let r = _mm_sra_epi32(a, b);
assert_eq_m128i(r, _mm_set1_epi32(-1));
let a = _mm_setr_epi32(0xEEEE, -0xEEEE, 0xFFFF, -0xFFFF);
let r = _mm_sra_epi32(a, _mm_set_epi64x(0, 4));
assert_eq_m128i(r, _mm_setr_epi32(0xEEE, -0xEEF, 0xFFF, -0x1000));
let r = _mm_sra_epi32(a, _mm_set_epi64x(4, 0));
assert_eq_m128i(r, a);
let r = _mm_sra_epi32(a, _mm_set_epi64x(0, 32));
assert_eq_m128i(r, _mm_setr_epi32(0, -1, 0, -1));
let r = _mm_sra_epi32(a, _mm_set_epi64x(0, i64::MAX));
assert_eq_m128i(r, _mm_setr_epi32(0, -1, 0, -1));
}

#[simd_test(enable = "sse2")]
Expand Down Expand Up @@ -3453,53 +3494,74 @@ mod tests {

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srli_epi16() {
#[rustfmt::skip]
let a = _mm_setr_epi16(
0xFFFF as u16 as i16, 0x0FFF, 0x00FF, 0x000F, 0, 0, 0, 0,
);
let a = _mm_setr_epi16(0xCC, -0xCC, 0xDD, -0xDD, 0xEE, -0xEE, 0xFF, -0xFF);
let r = _mm_srli_epi16::<4>(a);
#[rustfmt::skip]
let e = _mm_setr_epi16(
0xFFF as u16 as i16, 0xFF as u16 as i16, 0xF, 0, 0, 0, 0, 0,
assert_eq_m128i(
r,
_mm_setr_epi16(0xC, 0xFF3, 0xD, 0xFF2, 0xE, 0xFF1, 0xF, 0xFF0),
);
assert_eq_m128i(r, e);
let r = _mm_srli_epi16::<16>(a);
assert_eq_m128i(r, _mm_set1_epi16(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srl_epi16() {
let a = _mm_setr_epi16(0xFF, 0, 0, 0, 0, 0, 0, 0);
let r = _mm_srl_epi16(a, _mm_setr_epi16(4, 0, 0, 0, 0, 0, 0, 0));
assert_eq_m128i(r, _mm_setr_epi16(0xF, 0, 0, 0, 0, 0, 0, 0));
let r = _mm_srl_epi16(a, _mm_setr_epi16(0, 0, 0, 0, 4, 0, 0, 0));
assert_eq_m128i(r, _mm_setr_epi16(0xFF, 0, 0, 0, 0, 0, 0, 0));
let a = _mm_setr_epi16(0xCC, -0xCC, 0xDD, -0xDD, 0xEE, -0xEE, 0xFF, -0xFF);
let r = _mm_srl_epi16(a, _mm_set_epi64x(0, 4));
assert_eq_m128i(
r,
_mm_setr_epi16(0xC, 0xFF3, 0xD, 0xFF2, 0xE, 0xFF1, 0xF, 0xFF0),
);
let r = _mm_srl_epi16(a, _mm_set_epi64x(4, 0));
assert_eq_m128i(r, a);
let r = _mm_srl_epi16(a, _mm_set_epi64x(0, 16));
assert_eq_m128i(r, _mm_set1_epi16(0));
let r = _mm_srl_epi16(a, _mm_set_epi64x(0, i64::MAX));
assert_eq_m128i(r, _mm_set1_epi16(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srli_epi32() {
let r = _mm_srli_epi32::<4>(_mm_set1_epi32(0xFFFF));
assert_eq_m128i(r, _mm_set1_epi32(0xFFF));
let a = _mm_setr_epi32(0xEEEE, -0xEEEE, 0xFFFF, -0xFFFF);
let r = _mm_srli_epi32::<4>(a);
assert_eq_m128i(r, _mm_setr_epi32(0xEEE, 0xFFFF111, 0xFFF, 0xFFFF000));
let r = _mm_srli_epi32::<32>(a);
assert_eq_m128i(r, _mm_set1_epi32(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srl_epi32() {
let a = _mm_set1_epi32(0xFFFF);
let b = _mm_setr_epi32(4, 0, 0, 0);
let r = _mm_srl_epi32(a, b);
assert_eq_m128i(r, _mm_set1_epi32(0xFFF));
let a = _mm_setr_epi32(0xEEEE, -0xEEEE, 0xFFFF, -0xFFFF);
let r = _mm_srl_epi32(a, _mm_set_epi64x(0, 4));
assert_eq_m128i(r, _mm_setr_epi32(0xEEE, 0xFFFF111, 0xFFF, 0xFFFF000));
let r = _mm_srl_epi32(a, _mm_set_epi64x(4, 0));
assert_eq_m128i(r, a);
let r = _mm_srl_epi32(a, _mm_set_epi64x(0, 32));
assert_eq_m128i(r, _mm_set1_epi32(0));
let r = _mm_srl_epi32(a, _mm_set_epi64x(0, i64::MAX));
assert_eq_m128i(r, _mm_set1_epi32(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srli_epi64() {
let r = _mm_srli_epi64::<4>(_mm_set1_epi64x(0xFFFFFFFF));
assert_eq_m128i(r, _mm_set1_epi64x(0xFFFFFFF));
let a = _mm_set_epi64x(0xFFFFFFFF, -0xFFFFFFFF);
let r = _mm_srli_epi64::<4>(a);
assert_eq_m128i(r, _mm_set_epi64x(0xFFFFFFF, 0xFFFFFFFF0000000));
let r = _mm_srli_epi64::<64>(a);
assert_eq_m128i(r, _mm_set1_epi64x(0));
}

#[simd_test(enable = "sse2")]
unsafe fn test_mm_srl_epi64() {
let a = _mm_set1_epi64x(0xFFFFFFFF);
let b = _mm_setr_epi64x(4, 0);
let r = _mm_srl_epi64(a, b);
assert_eq_m128i(r, _mm_set1_epi64x(0xFFFFFFF));
let a = _mm_set_epi64x(0xFFFFFFFF, -0xFFFFFFFF);
let r = _mm_srl_epi64(a, _mm_set_epi64x(0, 4));
assert_eq_m128i(r, _mm_set_epi64x(0xFFFFFFF, 0xFFFFFFFF0000000));
let r = _mm_srl_epi64(a, _mm_set_epi64x(4, 0));
assert_eq_m128i(r, a);
let r = _mm_srl_epi64(a, _mm_set_epi64x(0, 64));
assert_eq_m128i(r, _mm_set1_epi64x(0));
let r = _mm_srl_epi64(a, _mm_set_epi64x(0, i64::MAX));
assert_eq_m128i(r, _mm_set1_epi64x(0));
}

#[simd_test(enable = "sse2")]
Expand Down Expand Up @@ -4055,6 +4117,17 @@ mod tests {
let b = _mm_setr_pd(5.0, 10.0);
let r = _mm_max_pd(a, b);
assert_eq_m128d(r, _mm_setr_pd(5.0, 10.0));

// Check SSE(2)-specific semantics for -0.0 handling.
let a = _mm_setr_pd(-0.0, 0.0);
let b = _mm_setr_pd(0.0, 0.0);
let r1: [u8; 16] = transmute(_mm_max_pd(a, b));
let r2: [u8; 16] = transmute(_mm_max_pd(b, a));
let a: [u8; 16] = transmute(a);
let b: [u8; 16] = transmute(b);
assert_eq!(r1, b);
assert_eq!(r2, a);
assert_ne!(a, b); // sanity check that -0.0 is actually present
}

#[simd_test(enable = "sse2")]
Expand All @@ -4071,6 +4144,17 @@ mod tests {
let b = _mm_setr_pd(5.0, 10.0);
let r = _mm_min_pd(a, b);
assert_eq_m128d(r, _mm_setr_pd(1.0, 2.0));

// Check SSE(2)-specific semantics for -0.0 handling.
let a = _mm_setr_pd(-0.0, 0.0);
let b = _mm_setr_pd(0.0, 0.0);
let r1: [u8; 16] = transmute(_mm_min_pd(a, b));
let r2: [u8; 16] = transmute(_mm_min_pd(b, a));
let a: [u8; 16] = transmute(a);
let b: [u8; 16] = transmute(b);
assert_eq!(r1, b);
assert_eq!(r2, a);
assert_ne!(a, b); // sanity check that -0.0 is actually present
}

#[simd_test(enable = "sse2")]
Expand Down
9 changes: 9 additions & 0 deletions crates/core_arch/src/x86/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
use crate::core_arch::x86::*;
use std::mem::transmute;

#[track_caller]
#[target_feature(enable = "sse2")]
pub unsafe fn assert_eq_m128i(a: __m128i, b: __m128i) {
assert_eq!(transmute::<_, [u64; 2]>(a), transmute::<_, [u64; 2]>(b))
}

#[track_caller]
#[target_feature(enable = "sse2")]
pub unsafe fn assert_eq_m128d(a: __m128d, b: __m128d) {
if _mm_movemask_pd(_mm_cmpeq_pd(a, b)) != 0b11 {
Expand All @@ -20,6 +22,7 @@ pub unsafe fn get_m128d(a: __m128d, idx: usize) -> f64 {
transmute::<_, [f64; 2]>(a)[idx]
}

#[track_caller]
#[target_feature(enable = "sse")]
pub unsafe fn assert_eq_m128(a: __m128, b: __m128) {
let r = _mm_cmpeq_ps(a, b);
Expand All @@ -40,11 +43,13 @@ pub unsafe fn _mm_setr_epi64x(a: i64, b: i64) -> __m128i {
_mm_set_epi64x(b, a)
}

#[track_caller]
#[target_feature(enable = "avx")]
pub unsafe fn assert_eq_m256i(a: __m256i, b: __m256i) {
assert_eq!(transmute::<_, [u64; 4]>(a), transmute::<_, [u64; 4]>(b))
}

#[track_caller]
#[target_feature(enable = "avx")]
pub unsafe fn assert_eq_m256d(a: __m256d, b: __m256d) {
let cmp = _mm256_cmp_pd::<_CMP_EQ_OQ>(a, b);
Expand All @@ -58,6 +63,7 @@ pub unsafe fn get_m256d(a: __m256d, idx: usize) -> f64 {
transmute::<_, [f64; 4]>(a)[idx]
}

#[track_caller]
#[target_feature(enable = "avx")]
pub unsafe fn assert_eq_m256(a: __m256, b: __m256) {
let cmp = _mm256_cmp_ps::<_CMP_EQ_OQ>(a, b);
Expand Down Expand Up @@ -125,17 +131,20 @@ mod x86_polyfill {
}
pub use self::x86_polyfill::*;

#[track_caller]
pub unsafe fn assert_eq_m512i(a: __m512i, b: __m512i) {
assert_eq!(transmute::<_, [i32; 16]>(a), transmute::<_, [i32; 16]>(b))
}

#[track_caller]
pub unsafe fn assert_eq_m512(a: __m512, b: __m512) {
let cmp = _mm512_cmp_ps_mask::<_CMP_EQ_OQ>(a, b);
if cmp != 0b11111111_11111111 {
panic!("{:?} != {:?}", a, b);
}
}

#[track_caller]
pub unsafe fn assert_eq_m512d(a: __m512d, b: __m512d) {
let cmp = _mm512_cmp_pd_mask::<_CMP_EQ_OQ>(a, b);
if cmp != 0b11111111 {
Expand Down