Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReduceMod optimization #75

Merged
merged 18 commits into from
Oct 9, 2021
Merged
121 changes: 114 additions & 7 deletions benchmark/bench-eltwise-reduce-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ static void BM_EltwiseReduceModInPlace(benchmark::State& state) { // NOLINT

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = 0;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 1;
for (auto _ : state) {
EltwiseReduceMod(input1.data(), input1.data(), input_size, modulus,
Expand All @@ -46,7 +46,7 @@ static void BM_EltwiseReduceModCopy(benchmark::State& state) { // NOLINT

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = 0;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 1;
AlignedVector64<uint64_t> output(input_size, 0);

Expand All @@ -71,7 +71,7 @@ static void BM_EltwiseReduceModNative(benchmark::State& state) { // NOLINT

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = 0;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 1;
AlignedVector64<uint64_t> output(input_size, 0);

Expand All @@ -93,17 +93,17 @@ BENCHMARK(BM_EltwiseReduceModNative)
// state[0] is the degree
static void BM_EltwiseReduceModAVX512(benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
size_t modulus = 1152921504606877697;
size_t modulus = 0xffffffffffc0001ULL;

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = 0;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 1;
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseReduceModAVX512(output.data(), input1.data(), input_size, modulus,
input_mod_factor, output_mod_factor);
EltwiseReduceModAVX512<64>(output.data(), input1.data(), input_size,
modulus, input_mod_factor, output_mod_factor);
}
}

Expand All @@ -116,5 +116,112 @@ BENCHMARK(BM_EltwiseReduceModAVX512)

//=================================================================

#ifdef HEXL_HAS_AVX512DQ
// Benchmark: times EltwiseReduceModAVX512<64> with output_mod_factor = 2.
// state.range(0) supplies the number of elements processed per call.
static void BM_EltwiseReduceModAVX512BitShift64(
    benchmark::State& state) { // NOLINT
  const size_t num_elements = state.range(0);
  const size_t reduction_modulus = 0xffffffffffc0001ULL;  // 2^60 - 2^18 + 1

  // Factors handed to the kernel unchanged: the input factor equals the
  // modulus itself, the output factor is 2.
  const uint64_t factor_in = reduction_modulus;
  const uint64_t factor_out = 2;

  // NOTE(review): 100 * reduction_modulus does not fit in 64 bits for this
  // modulus, so the third argument wraps around -- confirm the intended bound.
  auto source = GenerateInsecureUniformRandomValues(num_elements, 0,
                                                    100 * reduction_modulus);
  AlignedVector64<uint64_t> reduced(num_elements, 0);

  // Only the kernel call sits inside the timed loop; set-up happens once above.
  for (auto _ : state) {
    EltwiseReduceModAVX512<64>(reduced.data(), source.data(), num_elements,
                               reduction_modulus, factor_in, factor_out);
  }
}

BENCHMARK(BM_EltwiseReduceModAVX512BitShift64)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});
#endif

////=================================================================

#ifdef HEXL_HAS_AVX512IFMA
// Benchmark: times EltwiseReduceModAVX512<52> with output_mod_factor = 2.
// state.range(0) supplies the number of elements processed per call.
static void BM_EltwiseReduceModAVX512BitShift52(
    benchmark::State& state) { // NOLINT
  const size_t num_elements = state.range(0);
  // NOTE(review): this modulus is 2^60 - 2^18 + 1, i.e. wider than 52 bits,
  // while the kernel is instantiated with BitShift = 52 -- confirm that this
  // combination is what the benchmark is meant to measure.
  const size_t reduction_modulus = 0xffffffffffc0001ULL;

  // Factors handed to the kernel unchanged: the input factor equals the
  // modulus itself, the output factor is 2.
  const uint64_t factor_in = reduction_modulus;
  const uint64_t factor_out = 2;

  // NOTE(review): 100 * reduction_modulus does not fit in 64 bits for this
  // modulus, so the third argument wraps around -- confirm the intended bound.
  auto source = GenerateInsecureUniformRandomValues(num_elements, 0,
                                                    100 * reduction_modulus);
  AlignedVector64<uint64_t> reduced(num_elements, 0);

  // Only the kernel call sits inside the timed loop; set-up happens once above.
  for (auto _ : state) {
    EltwiseReduceModAVX512<52>(reduced.data(), source.data(), num_elements,
                               reduction_modulus, factor_in, factor_out);
  }
}

BENCHMARK(BM_EltwiseReduceModAVX512BitShift52)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});
#endif

////=================================================================

#ifdef HEXL_HAS_AVX512IFMA
// Benchmark: times EltwiseReduceModAVX512<52> with output_mod_factor = 2 on
// random inputs generated with 2^52 as the second generator argument
// (presumably the lower bound -- verify against the generator's signature).
// state.range(0) supplies the number of elements processed per call.
static void BM_EltwiseReduceModAVX512BitShift52GT(
    benchmark::State& state) { // NOLINT
  const size_t num_elements = state.range(0);
  const size_t reduction_modulus = 0xffffffffffc0001ULL;  // 2^60 - 2^18 + 1

  // Factors handed to the kernel unchanged: the input factor equals the
  // modulus itself, the output factor is 2.
  const uint64_t factor_in = reduction_modulus;
  const uint64_t factor_out = 2;

  // 4503599627370496 == 2^52.
  // NOTE(review): 100 * reduction_modulus does not fit in 64 bits for this
  // modulus, so the third argument wraps around -- confirm the intended bound.
  auto source = GenerateInsecureUniformRandomValues(
      num_elements, 4503599627370496, 100 * reduction_modulus);
  AlignedVector64<uint64_t> reduced(num_elements, 0);

  // Only the kernel call sits inside the timed loop; set-up happens once above.
  for (auto _ : state) {
    EltwiseReduceModAVX512<52>(reduced.data(), source.data(), num_elements,
                               reduction_modulus, factor_in, factor_out);
  }
}

BENCHMARK(BM_EltwiseReduceModAVX512BitShift52GT)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});

// Benchmark: times EltwiseReduceModAVX512<52> with output_mod_factor = 2 on
// random inputs generated with 0 and 2^51 as the generator's range arguments
// (presumably lower/upper bound -- verify against the generator's signature).
// state.range(0) supplies the number of elements processed per call.
static void BM_EltwiseReduceModAVX512BitShift52LT(
    benchmark::State& state) { // NOLINT
  const size_t num_elements = state.range(0);
  const size_t reduction_modulus = 0xffffffffffc0001ULL;  // 2^60 - 2^18 + 1

  // Factors handed to the kernel unchanged: the input factor equals the
  // modulus itself, the output factor is 2.
  const uint64_t factor_in = reduction_modulus;
  const uint64_t factor_out = 2;

  // 2251799813685248 == 2^51.
  auto source =
      GenerateInsecureUniformRandomValues(num_elements, 0, 2251799813685248);
  AlignedVector64<uint64_t> reduced(num_elements, 0);

  // Only the kernel call sits inside the timed loop; set-up happens once above.
  for (auto _ : state) {
    EltwiseReduceModAVX512<52>(reduced.data(), source.data(), num_elements,
                               reduction_modulus, factor_in, factor_out);
  }
}

BENCHMARK(BM_EltwiseReduceModAVX512BitShift52LT)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});
#endif

//=================================================================

} // namespace hexl
} // namespace intel
71 changes: 22 additions & 49 deletions hexl/eltwise/eltwise-cmp-sub-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,36 @@

#include "eltwise/eltwise-cmp-sub-mod-avx512.hpp"

#include <immintrin.h>
#include <stdint.h>

#include "eltwise/eltwise-cmp-sub-mod-internal.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/check.hpp"
#include "util/avx512-util.hpp"
#include "hexl/util/util.hpp"

namespace intel {
namespace hexl {

#ifdef HEXL_HAS_AVX512DQ
void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
uint64_t n, uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff) {
HEXL_CHECK(result != nullptr, "Require result != nullptr");
HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr");
HEXL_CHECK(n != 0, "Require n != 0")
HEXL_CHECK(modulus > 1, "Require modulus > 1");
HEXL_CHECK(diff != 0, "Require diff != 0");
HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

uint64_t n_mod_8 = n % 8;
if (n_mod_8 != 0) {
EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound,
diff);
operand1 += n_mod_8;
result += n_mod_8;
n -= n_mod_8;
}
HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

const __m512i* v_op_ptr = reinterpret_cast<const __m512i*>(operand1);
__m512i* v_result_ptr = reinterpret_cast<__m512i*>(result);
__m512i v_bound = _mm512_set1_epi64(static_cast<int64_t>(bound));
__m512i v_diff = _mm512_set1_epi64(static_cast<int64_t>(diff));
__m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));

uint64_t mu = MultiplyFactor(1, 64, modulus).BarrettFactor();
__m512i v_mu = _mm512_set1_epi64(static_cast<int64_t>(mu));
/// @brief Computes element-wise conditional modular subtraction.
/// @param[out] result Stores the result
/// @param[in] operand1 Vector of elements to compare
/// @param[in] n Number of elements in \p operand1
/// @param[in] modulus Modulus to reduce by
/// @param[in] cmp Comparison function
/// @param[in] bound Scalar to compare against
/// @param[in] diff Scalar to subtract by
/// @details Computes \p result[i] = (\p cmp(\p operand1[i], \p bound)) ? (\p
/// operand1[i] - \p diff) mod \p modulus : \p operand1[i] mod \p modulus for
/// all i=0, ..., n-1

for (size_t i = n / 8; i > 0; --i) {
__m512i v_op = _mm512_loadu_si512(v_op_ptr);
__mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp));

v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_mu);

__m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus);
v_to_add = _mm512_sub_epi64(v_to_add, v_diff);
v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0);
#ifdef HEXL_HAS_AVX512DQ
template void EltwiseCmpSubModAVX512<64>(uint64_t* result,
const uint64_t* operand1, uint64_t n,
uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff);
#endif

v_op = _mm512_add_epi64(v_op, v_to_add);
_mm512_storeu_si512(v_result_ptr, v_op);
++v_op_ptr;
++v_result_ptr;
}
}
#ifdef HEXL_HAS_AVX512IFMA
template void EltwiseCmpSubModAVX512<52>(uint64_t* result,
const uint64_t* operand1, uint64_t n,
uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff);
#endif

} // namespace hexl
Expand Down
70 changes: 58 additions & 12 deletions hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,72 @@

#pragma once

#include <immintrin.h>
#include <stdint.h>

#include "hexl/util/util.hpp"
#include "eltwise/eltwise-cmp-sub-mod-internal.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/check.hpp"
#include "util/avx512-util.hpp"

namespace intel {
namespace hexl {

/// @brief Computes element-wise conditional modular subtraction.
/// @param[out] result Stores the result
/// @param[in] operand1 Vector of elements to compare
/// @param[in] n Number of elements in \p operand1
/// @param[in] modulus Modulus to reduce by
/// @param[in] cmp Comparison function
/// @param[in] bound Scalar to compare against
/// @param[in] diff Scalar to subtract by
/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? (\p
/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1
#ifdef HEXL_HAS_AVX512DQ
template <int BitShift = 64>
void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
uint64_t n, uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff);
uint64_t bound, uint64_t diff) {
HEXL_CHECK(result != nullptr, "Require result != nullptr");
HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr");
HEXL_CHECK(n != 0, "Require n != 0")
HEXL_CHECK(modulus > 1, "Require modulus > 1");
HEXL_CHECK(diff != 0, "Require diff != 0");
HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

uint64_t n_mod_8 = n % 8;
if (n_mod_8 != 0) {
EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound,
diff);
operand1 += n_mod_8;
result += n_mod_8;
n -= n_mod_8;
}
HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

const __m512i* v_op_ptr = reinterpret_cast<const __m512i*>(operand1);
__m512i* v_result_ptr = reinterpret_cast<__m512i*>(result);
__m512i v_bound = _mm512_set1_epi64(static_cast<int64_t>(bound));
__m512i v_diff = _mm512_set1_epi64(static_cast<int64_t>(diff));
__m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));

uint64_t mu = MultiplyFactor(1, BitShift, modulus).BarrettFactor();
__m512i v_mu = _mm512_set1_epi64(static_cast<int64_t>(mu));

// Multi-word Barrett reduction precomputation
constexpr int64_t beta = -2;
const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2
uint64_t prod_right_shift = ceil_log_mod + beta;
__m512i v_neg_mod = _mm512_set1_epi64(-static_cast<int64_t>(modulus));

for (size_t i = n / 8; i > 0; --i) {
__m512i v_op = _mm512_loadu_si512(v_op_ptr);
__mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp));

v_op = _mm512_hexl_barrett_reduce64<BitShift, 1>(
v_op, v_modulus, v_mu, v_mu, prod_right_shift, v_neg_mod);

__m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus);
v_to_add = _mm512_sub_epi64(v_to_add, v_diff);
v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0);

v_op = _mm512_add_epi64(v_op, v_to_add);
_mm512_storeu_si512(v_result_ptr, v_op);
++v_op_ptr;
++v_result_ptr;
}
}
#endif

} // namespace hexl
} // namespace intel
9 changes: 8 additions & 1 deletion hexl/eltwise/eltwise-cmp-sub-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,18 @@ void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n,

#ifdef HEXL_HAS_AVX512DQ
if (has_avx512dq) {
EltwiseCmpSubModAVX512(result, operand1, n, modulus, cmp, bound, diff);
if (modulus < (1ULL << 52)) {
EltwiseCmpSubModAVX512<52>(result, operand1, n, modulus, cmp, bound,
diff);
} else {
EltwiseCmpSubModAVX512<64>(result, operand1, n, modulus, cmp, bound,
diff);
}
return;
}
#endif
EltwiseCmpSubModNative(result, operand1, n, modulus, cmp, bound, diff);
return;
}

void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1,
Expand Down
2 changes: 2 additions & 0 deletions hexl/eltwise/eltwise-mult-mod-internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <stdint.h>

#include <cmath>

#include "eltwise/eltwise-mult-mod-internal.hpp"
Expand Down
Loading