Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some changes to simdlib #2885

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions faiss/utils/simdlib_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,6 @@ struct simd16uint16 : simd256bit {
return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i));
}

bool is_same(simd16uint16 other) const {
const __m256i pcmp = _mm256_cmpeq_epi16(i, other.i);
unsigned bitmask = _mm256_movemask_epi8(pcmp);
return (bitmask == 0xffffffffU);
}

simd16uint16 operator~() const {
return simd16uint16(_mm256_xor_si256(i, _mm256_set1_epi32(-1)));
}
Expand Down
149 changes: 72 additions & 77 deletions faiss/utils/simdlib_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,13 @@ struct simd16uint16 {
}

// Checks whether the other holds exactly the same bytes.
bool is_same_as(simd16uint16 other) const {
const bool equal0 =
(vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
0xffff);
const bool equal1 =
(vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
0xffff);

return equal0 && equal1;
template <typename T>
bool is_same_as(T other) const {
const auto o = detail::simdlib::reinterpret_u16(other.data);
const auto equals = detail::simdlib::binary_func(data, o)
.template call<&vceqq_u16>();
const auto equal = vandq_u16(equals.val[0], equals.val[1]);
return vminvq_u16(equal) == 0xffffu;
}

simd16uint16 operator~() const {
Expand Down Expand Up @@ -689,13 +687,12 @@ inline void cmplt_min_max_fast(
simd16uint16& minIndices,
simd16uint16& maxValues,
simd16uint16& maxIndices) {
const uint16x8x2_t comparison = uint16x8x2_t{
vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
const uint16x8x2_t comparison =
detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vcltq_u16>();

minValues.data = uint16x8x2_t{
vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
minValues = min(candidateValues, currentValues);
minIndices.data = uint16x8x2_t{
vbslq_u16(
comparison.val[0],
Expand All @@ -706,9 +703,7 @@ inline void cmplt_min_max_fast(
candidateIndices.data.val[1],
currentIndices.data.val[1])};

maxValues.data = uint16x8x2_t{
vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
maxValues = max(candidateValues, currentValues);
maxIndices.data = uint16x8x2_t{
vbslq_u16(
comparison.val[0],
Expand Down Expand Up @@ -869,13 +864,13 @@ struct simd32uint8 {
}

// Checks whether the other holds exactly the same bytes.
bool is_same_as(simd32uint8 other) const {
const bool equal0 =
(vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
const bool equal1 =
(vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);

return equal0 && equal1;
template <typename T>
bool is_same_as(T other) const {
const auto o = detail::simdlib::reinterpret_u8(other.data);
const auto equals = detail::simdlib::binary_func(data, o)
.template call<&vceqq_u8>();
const auto equal = vandq_u8(equals.val[0], equals.val[1]);
return vminvq_u8(equal) == 0xffu;
}
};

Expand Down Expand Up @@ -960,27 +955,28 @@ struct simd8uint32 {
return *this;
}

bool operator==(simd8uint32 other) const {
const auto equals = detail::simdlib::binary_func(data, other.data)
.call<&vceqq_u32>();
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
return vminvq_u32(equal) == 0xffffffff;
simd8uint32 operator==(simd8uint32 other) const {
return simd8uint32{detail::simdlib::binary_func(data, other.data)
.call<&vceqq_u32>()};
}

bool operator!=(simd8uint32 other) const {
return !(*this == other);
simd8uint32 operator~() const {
return simd8uint32{
detail::simdlib::unary_func(data).call<&vmvnq_u32>()};
}

// Checks whether the other holds exactly the same bytes.
bool is_same_as(simd8uint32 other) const {
const bool equal0 =
(vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
0xffffffff);
const bool equal1 =
(vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
0xffffffff);
simd8uint32 operator!=(simd8uint32 other) const {
return ~(*this == other);
}

return equal0 && equal1;
// Checks whether the other holds exactly the same bytes.
template <typename T>
bool is_same_as(T other) const {
const auto o = detail::simdlib::reinterpret_u32(other.data);
const auto equals = detail::simdlib::binary_func(data, o)
.template call<&vceqq_u32>();
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
return vminvq_u32(equal) == 0xffffffffu;
}

void clear() {
Expand Down Expand Up @@ -1053,13 +1049,14 @@ inline void cmplt_min_max_fast(
simd8uint32& minIndices,
simd8uint32& maxValues,
simd8uint32& maxIndices) {
const uint32x4x2_t comparison = uint32x4x2_t{
vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};

minValues.data = uint32x4x2_t{
vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
const uint32x4x2_t comparison =
detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vcltq_u32>();

minValues.data = detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vminq_u32>();
minIndices.data = uint32x4x2_t{
vbslq_u32(
comparison.val[0],
Expand All @@ -1070,9 +1067,9 @@ inline void cmplt_min_max_fast(
candidateIndices.data.val[1],
currentIndices.data.val[1])};

maxValues.data = uint32x4x2_t{
vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
maxValues.data = detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vmaxq_u32>();
maxIndices.data = uint32x4x2_t{
vbslq_u32(
comparison.val[0],
Expand Down Expand Up @@ -1167,28 +1164,25 @@ struct simd8float32 {
return *this;
}

bool operator==(simd8float32 other) const {
const auto equals =
simd8uint32 operator==(simd8float32 other) const {
return simd8uint32{
detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
.call<&vceqq_f32>();
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
return vminvq_u32(equal) == 0xffffffff;
.call<&vceqq_f32>()};
}

bool operator!=(simd8float32 other) const {
return !(*this == other);
simd8uint32 operator!=(simd8float32 other) const {
return ~(*this == other);
}

// Checks whether the other holds exactly the same bytes.
bool is_same_as(simd8float32 other) const {
const bool equal0 =
(vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
0xffffffff);
const bool equal1 =
(vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
0xffffffff);

return equal0 && equal1;
template <typename T>
bool is_same_as(T other) const {
const auto o = detail::simdlib::reinterpret_f32(other.data);
const auto equals =
detail::simdlib::binary_func<::uint32x4x2_t>(data, o)
.template call<&vceqq_f32>();
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
return vminvq_u32(equal) == 0xffffffffu;
}

std::string tostring() const {
Expand Down Expand Up @@ -1302,13 +1296,14 @@ inline void cmplt_min_max_fast(
simd8uint32& minIndices,
simd8float32& maxValues,
simd8uint32& maxIndices) {
const uint32x4x2_t comparison = uint32x4x2_t{
vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};

minValues.data = float32x4x2_t{
vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
const uint32x4x2_t comparison =
detail::simdlib::binary_func<::uint32x4x2_t>(
candidateValues.data, currentValues.data)
.call<&vcltq_f32>();

minValues.data = detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vminq_f32>();
minIndices.data = uint32x4x2_t{
vbslq_u32(
comparison.val[0],
Expand All @@ -1319,9 +1314,9 @@ inline void cmplt_min_max_fast(
candidateIndices.data.val[1],
currentIndices.data.val[1])};

maxValues.data = float32x4x2_t{
vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
maxValues.data = detail::simdlib::binary_func(
candidateValues.data, currentValues.data)
.call<&vmaxq_f32>();
maxIndices.data = uint32x4x2_t{
vbslq_u32(
comparison.val[0],
Expand Down