Skip to content

Commit

Permalink
ReduceMod optimization (#75)
Browse files Browse the repository at this point in the history
* optimize modular reduction

* remove MinTime

* update condition

* add unit test

* test algorithm 2 barrett reduction

* use different optimization for reduce mod based on input size

* add more test and benchmark

* update include

* fix formatting

* fix formatting

* rename arg, better perf

* fix windows build

* fix IFMA52 unit-tests

* fix debug benchmark

* avoid unused error

* update modulo

* update test

* simplify test
  • Loading branch information
GelilaSeifu authored and fboemer committed Nov 8, 2021
1 parent a0b9029 commit b394857
Show file tree
Hide file tree
Showing 16 changed files with 545 additions and 243 deletions.
121 changes: 114 additions & 7 deletions benchmark/bench-eltwise-reduce-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ static void BM_EltwiseReduceModInPlace(benchmark::State& state) { // NOLINT

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = 0;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 1;
for (auto _ : state) {
EltwiseReduceMod(input1.data(), input1.data(), input_size, modulus,
Expand All @@ -46,7 +46,7 @@ static void BM_EltwiseReduceModCopy(benchmark::State& state) { // NOLINT

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = 0;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 1;
AlignedVector64<uint64_t> output(input_size, 0);

Expand All @@ -71,7 +71,7 @@ static void BM_EltwiseReduceModNative(benchmark::State& state) { // NOLINT

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = 0;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 1;
AlignedVector64<uint64_t> output(input_size, 0);

Expand All @@ -93,17 +93,17 @@ BENCHMARK(BM_EltwiseReduceModNative)
// state[0] is the degree
static void BM_EltwiseReduceModAVX512(benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
size_t modulus = 1152921504606877697;
size_t modulus = 0xffffffffffc0001ULL;

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = 0;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 1;
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseReduceModAVX512(output.data(), input1.data(), input_size, modulus,
input_mod_factor, output_mod_factor);
EltwiseReduceModAVX512<64>(output.data(), input1.data(), input_size,
modulus, input_mod_factor, output_mod_factor);
}
}

Expand All @@ -116,5 +116,112 @@ BENCHMARK(BM_EltwiseReduceModAVX512)

//=================================================================

#ifdef HEXL_HAS_AVX512DQ
// state[0] is the degree
static void BM_EltwiseReduceModAVX512BitShift64(
benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
size_t modulus = 0xffffffffffc0001ULL;

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 2;
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseReduceModAVX512<64>(output.data(), input1.data(), input_size,
modulus, input_mod_factor, output_mod_factor);
}
}

BENCHMARK(BM_EltwiseReduceModAVX512BitShift64)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});
#endif

////=================================================================

#ifdef HEXL_HAS_AVX512IFMA
// state[0] is the degree
static void BM_EltwiseReduceModAVX512BitShift52(
benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
size_t modulus = 0xffffffffffc0001ULL;

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus);
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 2;
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseReduceModAVX512<52>(output.data(), input1.data(), input_size,
modulus, input_mod_factor, output_mod_factor);
}
}

BENCHMARK(BM_EltwiseReduceModAVX512BitShift52)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});
#endif

////=================================================================

#ifdef HEXL_HAS_AVX512IFMA
// state[0] is the degree
static void BM_EltwiseReduceModAVX512BitShift52GT(
benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
size_t modulus = 0xffffffffffc0001ULL;

auto input1 = GenerateInsecureUniformRandomValues(
input_size, 4503599627370496, 100 * modulus);
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 2;
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseReduceModAVX512<52>(output.data(), input1.data(), input_size,
modulus, input_mod_factor, output_mod_factor);
}
}

BENCHMARK(BM_EltwiseReduceModAVX512BitShift52GT)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});

static void BM_EltwiseReduceModAVX512BitShift52LT(
benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
size_t modulus = 0xffffffffffc0001ULL;

auto input1 =
GenerateInsecureUniformRandomValues(input_size, 0, 2251799813685248);
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 2;
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseReduceModAVX512<52>(output.data(), input1.data(), input_size,
modulus, input_mod_factor, output_mod_factor);
}
}

BENCHMARK(BM_EltwiseReduceModAVX512BitShift52LT)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});
#endif

//=================================================================

} // namespace hexl
} // namespace intel
71 changes: 22 additions & 49 deletions hexl/eltwise/eltwise-cmp-sub-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,36 @@

#include "eltwise/eltwise-cmp-sub-mod-avx512.hpp"

#include <immintrin.h>
#include <stdint.h>

#include "eltwise/eltwise-cmp-sub-mod-internal.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/check.hpp"
#include "util/avx512-util.hpp"
#include "hexl/util/util.hpp"

namespace intel {
namespace hexl {

#ifdef HEXL_HAS_AVX512DQ
void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
uint64_t n, uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff) {
HEXL_CHECK(result != nullptr, "Require result != nullptr");
HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr");
HEXL_CHECK(n != 0, "Require n != 0")
HEXL_CHECK(modulus > 1, "Require modulus > 1");
HEXL_CHECK(diff != 0, "Require diff != 0");
HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

uint64_t n_mod_8 = n % 8;
if (n_mod_8 != 0) {
EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound,
diff);
operand1 += n_mod_8;
result += n_mod_8;
n -= n_mod_8;
}
HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

const __m512i* v_op_ptr = reinterpret_cast<const __m512i*>(operand1);
__m512i* v_result_ptr = reinterpret_cast<__m512i*>(result);
__m512i v_bound = _mm512_set1_epi64(static_cast<int64_t>(bound));
__m512i v_diff = _mm512_set1_epi64(static_cast<int64_t>(diff));
__m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));

uint64_t mu = MultiplyFactor(1, 64, modulus).BarrettFactor();
__m512i v_mu = _mm512_set1_epi64(static_cast<int64_t>(mu));
/// @brief Computes element-wise conditional modular subtraction.
/// @param[out] result Stores the result
/// @param[in] operand1 Vector of elements to compare
/// @param[in] n Number of elements in \p operand1
/// @param[in] modulus Modulus to reduce by
/// @param[in] cmp Comparison function
/// @param[in] bound Scalar to compare against
/// @param[in] diff Scalar to subtract by
/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? (\p
/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1

for (size_t i = n / 8; i > 0; --i) {
__m512i v_op = _mm512_loadu_si512(v_op_ptr);
__mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp));

v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_mu);

__m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus);
v_to_add = _mm512_sub_epi64(v_to_add, v_diff);
v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0);
#ifdef HEXL_HAS_AVX512DQ
template void EltwiseCmpSubModAVX512<64>(uint64_t* result,
const uint64_t* operand1, uint64_t n,
uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff);
#endif

v_op = _mm512_add_epi64(v_op, v_to_add);
_mm512_storeu_si512(v_result_ptr, v_op);
++v_op_ptr;
++v_result_ptr;
}
}
#ifdef HEXL_HAS_AVX512IFMA
template void EltwiseCmpSubModAVX512<52>(uint64_t* result,
const uint64_t* operand1, uint64_t n,
uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff);
#endif

} // namespace hexl
Expand Down
70 changes: 58 additions & 12 deletions hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,72 @@

#pragma once

#include <immintrin.h>
#include <stdint.h>

#include "hexl/util/util.hpp"
#include "eltwise/eltwise-cmp-sub-mod-internal.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/check.hpp"
#include "util/avx512-util.hpp"

namespace intel {
namespace hexl {

/// @brief Computes element-wise conditional modular subtraction.
/// @param[out] result Stores the result
/// @param[in] operand1 Vector of elements to compare
/// @param[in] n Number of elements in \p operand1
/// @param[in] modulus Modulus to reduce by
/// @param[in] cmp Comparison function
/// @param[in] bound Scalar to compare against
/// @param[in] diff Scalar to subtract by
/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? (\p
/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1
#ifdef HEXL_HAS_AVX512DQ
template <int BitShift = 64>
void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1,
uint64_t n, uint64_t modulus, CMPINT cmp,
uint64_t bound, uint64_t diff);
uint64_t bound, uint64_t diff) {
HEXL_CHECK(result != nullptr, "Require result != nullptr");
HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr");
HEXL_CHECK(n != 0, "Require n != 0")
HEXL_CHECK(modulus > 1, "Require modulus > 1");
HEXL_CHECK(diff != 0, "Require diff != 0");
HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

uint64_t n_mod_8 = n % 8;
if (n_mod_8 != 0) {
EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound,
diff);
operand1 += n_mod_8;
result += n_mod_8;
n -= n_mod_8;
}
HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus);

const __m512i* v_op_ptr = reinterpret_cast<const __m512i*>(operand1);
__m512i* v_result_ptr = reinterpret_cast<__m512i*>(result);
__m512i v_bound = _mm512_set1_epi64(static_cast<int64_t>(bound));
__m512i v_diff = _mm512_set1_epi64(static_cast<int64_t>(diff));
__m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));

uint64_t mu = MultiplyFactor(1, BitShift, modulus).BarrettFactor();
__m512i v_mu = _mm512_set1_epi64(static_cast<int64_t>(mu));

// Multi-word Barrett reduction precomputation
constexpr int64_t beta = -2;
const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2
uint64_t prod_right_shift = ceil_log_mod + beta;
__m512i v_neg_mod = _mm512_set1_epi64(-static_cast<int64_t>(modulus));

for (size_t i = n / 8; i > 0; --i) {
__m512i v_op = _mm512_loadu_si512(v_op_ptr);
__mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp));

v_op = _mm512_hexl_barrett_reduce64<BitShift, 1>(
v_op, v_modulus, v_mu, v_mu, prod_right_shift, v_neg_mod);

__m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus);
v_to_add = _mm512_sub_epi64(v_to_add, v_diff);
v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0);

v_op = _mm512_add_epi64(v_op, v_to_add);
_mm512_storeu_si512(v_result_ptr, v_op);
++v_op_ptr;
++v_result_ptr;
}
}
#endif

} // namespace hexl
} // namespace intel
9 changes: 8 additions & 1 deletion hexl/eltwise/eltwise-cmp-sub-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,18 @@ void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n,

#ifdef HEXL_HAS_AVX512DQ
if (has_avx512dq) {
EltwiseCmpSubModAVX512(result, operand1, n, modulus, cmp, bound, diff);
if (modulus < (1ULL << 52)) {
EltwiseCmpSubModAVX512<52>(result, operand1, n, modulus, cmp, bound,
diff);
} else {
EltwiseCmpSubModAVX512<64>(result, operand1, n, modulus, cmp, bound,
diff);
}
return;
}
#endif
EltwiseCmpSubModNative(result, operand1, n, modulus, cmp, bound, diff);
return;
}

void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1,
Expand Down
2 changes: 2 additions & 0 deletions hexl/eltwise/eltwise-mult-mod-internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <stdint.h>

#include <cmath>

#include "eltwise/eltwise-mult-mod-internal.hpp"
Expand Down
Loading

0 comments on commit b394857

Please sign in to comment.