From e4c137516d9549b69c6f763dc03a541e6cbdb355 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Tue, 14 Sep 2021 16:35:13 -0700 Subject: [PATCH 01/18] optimize modular reduction --- benchmark/bench-eltwise-reduce-mod.cpp | 64 +++++++++- hexl/eltwise/eltwise-cmp-sub-mod-avx512.cpp | 71 ++++------- hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp | 63 ++++++++-- hexl/eltwise/eltwise-cmp-sub-mod.cpp | 9 +- hexl/eltwise/eltwise-reduce-mod-avx512.cpp | 109 ++--------------- hexl/eltwise/eltwise-reduce-mod-avx512.hpp | 112 +++++++++++++++++- hexl/eltwise/eltwise-reduce-mod-internal.hpp | 6 +- hexl/eltwise/eltwise-reduce-mod.cpp | 79 +++++++----- .../hexl/eltwise/eltwise-reduce-mod.hpp | 6 +- .../hexl/number-theory/number-theory.hpp | 10 +- hexl/include/hexl/util/msvc.hpp | 6 +- hexl/util/avx512-util.hpp | 34 ++++-- test/test-eltwise-reduce-mod-avx512.cpp | 10 +- test/test-eltwise-reduce-mod.cpp | 2 +- 14 files changed, 361 insertions(+), 220 deletions(-) diff --git a/benchmark/bench-eltwise-reduce-mod.cpp b/benchmark/bench-eltwise-reduce-mod.cpp index 83420a91..a6554471 100644 --- a/benchmark/bench-eltwise-reduce-mod.cpp +++ b/benchmark/bench-eltwise-reduce-mod.cpp @@ -23,7 +23,7 @@ static void BM_EltwiseReduceModInPlace(benchmark::State& state) { // NOLINT auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); - const uint64_t input_mod_factor = 0; + const uint64_t input_mod_factor = modulus; const uint64_t output_mod_factor = 1; for (auto _ : state) { EltwiseReduceMod(input1.data(), input1.data(), input_size, modulus, @@ -46,7 +46,7 @@ static void BM_EltwiseReduceModCopy(benchmark::State& state) { // NOLINT auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); - const uint64_t input_mod_factor = 0; + const uint64_t input_mod_factor = modulus; const uint64_t output_mod_factor = 1; AlignedVector64 output(input_size, 0); @@ -71,7 +71,7 @@ static void BM_EltwiseReduceModNative(benchmark::State& state) { // NOLINT auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); - const uint64_t input_mod_factor = 0; + const uint64_t input_mod_factor = modulus; const uint64_t output_mod_factor = 1; AlignedVector64 output(input_size, 0); @@ -97,7 +97,7 @@ static void BM_EltwiseReduceModAVX512(benchmark::State& state) { // NOLINT auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); - const uint64_t input_mod_factor = 0; + const uint64_t input_mod_factor = modulus; const uint64_t output_mod_factor = 1; AlignedVector64 output(input_size, 0); @@ -116,5 +116,61 @@ BENCHMARK(BM_EltwiseReduceModAVX512) //================================================================= +#ifdef HEXL_HAS_AVX512DQ +// state[0] is the degree +static void BM_EltwiseReduceModAVX512BitShift64( + benchmark::State& state) { // NOLINT + size_t input_size = state.range(0); + size_t modulus = 1152921504606877697; + + AlignedVector64 input1(input_size, 1); + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + AlignedVector64 output(input_size, 0); + + for (auto _ : state) { + EltwiseReduceModAVX512<64>(output.data(), input1.data(), input_size, + modulus, input_mod_factor, output_mod_factor); + } +} + +BENCHMARK(BM_EltwiseReduceModAVX512BitShift64) + ->Unit(benchmark::kMicrosecond) + ->MinTime(1.0) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); +#endif + +//================================================================= + +#ifdef HEXL_HAS_AVX512IFMA +// state[0] is the degree +static void BM_EltwiseReduceModAVX512BitShift52( + benchmark::State& state) { // NOLINT + size_t input_size = state.range(0); + size_t modulus = 1152921504606877697; + + AlignedVector64 input1(input_size, 1); + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + AlignedVector64 output(input_size, 0); + + for (auto _ : state) { + EltwiseReduceModAVX512<52>(output.data(), input1.data(), input_size, + modulus, input_mod_factor, output_mod_factor); + } +} + +BENCHMARK(BM_EltwiseReduceModAVX512BitShift52) + ->Unit(benchmark::kMicrosecond) + ->MinTime(1.0) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); +#endif + +//================================================================= + } // namespace hexl } // namespace intel diff --git a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.cpp b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.cpp index 8f87b690..052d603c 100644 --- a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.cpp +++ b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.cpp @@ -3,63 +3,36 @@ #include "eltwise/eltwise-cmp-sub-mod-avx512.hpp" -#include #include -#include "eltwise/eltwise-cmp-sub-mod-internal.hpp" -#include "hexl/number-theory/number-theory.hpp" -#include "hexl/util/check.hpp" -#include "util/avx512-util.hpp" +#include "hexl/util/util.hpp" namespace intel { namespace hexl { -#ifdef HEXL_HAS_AVX512DQ -void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1, - uint64_t n, uint64_t modulus, CMPINT cmp, - uint64_t bound, uint64_t diff) { - HEXL_CHECK(result != nullptr, "Require result != nullptr"); - HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); - HEXL_CHECK(n != 0, "Require n != 0") - HEXL_CHECK(modulus > 1, "Require modulus > 1"); - HEXL_CHECK(diff != 0, "Require diff != 0"); - HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); - - uint64_t n_mod_8 = n % 8; - if (n_mod_8 != 0) { - EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound, - diff); - operand1 += n_mod_8; - result += n_mod_8; - n -= n_mod_8; - } - HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); - - const __m512i* v_op_ptr = reinterpret_cast(operand1); - __m512i* v_result_ptr = reinterpret_cast<__m512i*>(result); - __m512i v_bound = _mm512_set1_epi64(static_cast(bound)); - __m512i v_diff = _mm512_set1_epi64(static_cast(diff)); - __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); - - uint64_t mu = MultiplyFactor(1, 64, modulus).BarrettFactor(); - __m512i v_mu = _mm512_set1_epi64(static_cast(mu)); +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? (\p +/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1 - for (size_t i = n / 8; i > 0; --i) { - __m512i v_op = _mm512_loadu_si512(v_op_ptr); - __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp)); - - v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_mu); - - __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus); - v_to_add = _mm512_sub_epi64(v_to_add, v_diff); - v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0); +#ifdef HEXL_HAS_AVX512DQ +template void EltwiseCmpSubModAVX512<64>(uint64_t* result, + const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff); +#endif - v_op = _mm512_add_epi64(v_op, v_to_add); - _mm512_storeu_si512(v_result_ptr, v_op); - ++v_op_ptr; - ++v_result_ptr; - } -} +#ifdef HEXL_HAS_AVX512IFMA +template void EltwiseCmpSubModAVX512<52>(uint64_t* result, + const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff); #endif } // namespace hexl diff --git a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp index 33768888..4cf35968 100644 --- a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp @@ -3,26 +3,65 @@ #pragma once +#include #include -#include "hexl/util/util.hpp" +#include "eltwise/eltwise-cmp-sub-mod-internal.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" namespace intel { namespace hexl { -/// @brief Computes element-wise conditional modular subtraction. -/// @param[out] result Stores the result -/// @param[in] operand1 Vector of elements to compare -/// @param[in] n Number of elements in \p operand1 -/// @param[in] modulus Modulus to reduce by -/// @param[in] cmp Comparison function -/// @param[in] bound Scalar to compare against -/// @param[in] diff Scalar to subtract by -/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? (\p -/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1 +#ifdef HEXL_HAS_AVX512DQ +template void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1, uint64_t n, uint64_t modulus, CMPINT cmp, - uint64_t bound, uint64_t diff); + uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0") + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound, + diff); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + + const __m512i* v_op_ptr = reinterpret_cast(operand1); + __m512i* v_result_ptr = reinterpret_cast<__m512i*>(result); + __m512i v_bound = _mm512_set1_epi64(static_cast(bound)); + __m512i v_diff = _mm512_set1_epi64(static_cast(diff)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + + uint64_t mu = MultiplyFactor(1, BitShift, modulus).BarrettFactor(); + __m512i v_mu = _mm512_set1_epi64(static_cast(mu)); + + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op = _mm512_loadu_si512(v_op_ptr); + __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp)); + + v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_mu); + + __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus); + v_to_add = _mm512_sub_epi64(v_to_add, v_diff); + v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0); + + v_op = _mm512_add_epi64(v_op, v_to_add); + _mm512_storeu_si512(v_result_ptr, v_op); + ++v_op_ptr; + ++v_result_ptr; + } +} +#endif } // namespace hexl } // namespace intel diff --git a/hexl/eltwise/eltwise-cmp-sub-mod.cpp b/hexl/eltwise/eltwise-cmp-sub-mod.cpp index c8c0bd04..96074502 100644 --- a/hexl/eltwise/eltwise-cmp-sub-mod.cpp +++ b/hexl/eltwise/eltwise-cmp-sub-mod.cpp @@ -26,11 +26,18 @@ void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, #ifdef HEXL_HAS_AVX512DQ if (has_avx512dq) { - EltwiseCmpSubModAVX512(result, operand1, n, modulus, cmp, bound, diff); + if (modulus < (1ULL << 52)) { + EltwiseCmpSubModAVX512<52>(result, operand1, n, modulus, cmp, bound, + diff); + } else { + EltwiseCmpSubModAVX512<64>(result, operand1, n, modulus, cmp, bound, + diff); + } return; } #endif EltwiseCmpSubModNative(result, operand1, n, modulus, cmp, bound, diff); + return; } void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1, diff --git a/hexl/eltwise/eltwise-reduce-mod-avx512.cpp b/hexl/eltwise/eltwise-reduce-mod-avx512.cpp index e27209ac..40128085 100644 --- a/hexl/eltwise/eltwise-reduce-mod-avx512.cpp +++ b/hexl/eltwise/eltwise-reduce-mod-avx512.cpp @@ -3,109 +3,24 @@ #include "eltwise/eltwise-reduce-mod-avx512.hpp" -#include -#include - -#include "eltwise/eltwise-reduce-mod-internal.hpp" -#include "hexl/eltwise/eltwise-reduce-mod.hpp" -#include "hexl/logging/logging.hpp" -#include "hexl/number-theory/number-theory.hpp" -#include "hexl/util/check.hpp" -#include "util/avx512-util.hpp" - namespace intel { namespace hexl { #ifdef HEXL_HAS_AVX512DQ -void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, - uint64_t n, uint64_t modulus, - uint64_t input_mod_factor, - uint64_t output_mod_factor) { - HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr"); - HEXL_CHECK(n != 0, "Require n != 0"); - HEXL_CHECK(modulus > 1, "Require modulus > 1"); - HEXL_CHECK( - input_mod_factor == 0 || input_mod_factor == 2 || input_mod_factor == 4, - "input_mod_factor must be 0 or 2 or 4" << input_mod_factor); - HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, - "output_mod_factor must be 1 or 2 " << output_mod_factor); - HEXL_CHECK(input_mod_factor != output_mod_factor, - "input_mod_factor must not be equal to output_mod_factor "); - - uint64_t n_tmp = n; - uint64_t barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); - __m512i v_bf = _mm512_set1_epi64(static_cast(barrett_factor)); - - // Deals with n not divisible by 8 - uint64_t n_mod_8 = n_tmp % 8; - if (n_mod_8 != 0) { - EltwiseReduceModNative(result, operand, n_mod_8, modulus, input_mod_factor, - output_mod_factor); - operand += n_mod_8; - result += n_mod_8; - n_tmp -= n_mod_8; - } - - uint64_t twice_mod = modulus << 1; - const __m512i* v_operand = reinterpret_cast(operand); - __m512i* v_result = reinterpret_cast<__m512i*>(result); - __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); - __m512i v_twice_mod = _mm512_set1_epi64(static_cast(twice_mod)); - - switch (input_mod_factor) { - case 0: - for (size_t i = 0; i < n_tmp; i += 8) { - __m512i v_op = _mm512_loadu_si512(v_operand); - v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_bf); - HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, - "v_op exceeds bound " << modulus); - _mm512_storeu_si512(v_result, v_op); - ++v_operand; - ++v_result; - } - break; - - case 2: - for (size_t i = 0; i < n_tmp; i += 8) { - __m512i v_op = _mm512_loadu_si512(v_operand); - v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus); - HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, - "v_op exceeds bound " << modulus); - _mm512_storeu_si512(v_result, v_op); - ++v_operand; - ++v_result; - } - break; - - case 4: - if (output_mod_factor == 1) { - for (size_t i = 0; i < n_tmp; i += 8) { - __m512i v_op = _mm512_loadu_si512(v_operand); - v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod); - v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus); - HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, - "v_op exceeds bound " << modulus); - _mm512_storeu_si512(v_result, v_op); - ++v_operand; - ++v_result; - } - } - if (output_mod_factor == 2) { - for (size_t i = 0; i < n_tmp; i += 8) { - __m512i v_op = _mm512_loadu_si512(v_operand); - v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod); - HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, twice_mod, - "v_op exceeds bound " << twice_mod); - _mm512_storeu_si512(v_result, v_op); - ++v_operand; - ++v_result; - } - } - break; - } -} +template void EltwiseReduceModAVX512<64>(uint64_t* result, + const uint64_t* operand, uint64_t n, + uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor); +#endif +#ifdef HEXL_HAS_AVX512IFMA +template void EltwiseReduceModAVX512<52>(uint64_t* result, + const uint64_t* operand, uint64_t n, + uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor); #endif } // namespace hexl diff --git a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp index 939dfc50..3092072f 100644 --- a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp @@ -3,14 +3,122 @@ #pragma once -#include +#include +#include + +#include "eltwise/eltwise-reduce-mod-avx512.hpp" +#include "eltwise/eltwise-reduce-mod-internal.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" namespace intel { namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ +template void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, uint64_t input_mod_factor, - uint64_t output_mod_factor); + uint64_t output_mod_factor) { + HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2 " << output_mod_factor); + HEXL_CHECK(input_mod_factor != output_mod_factor, + "input_mod_factor must not be equal to output_mod_factor "); + + uint64_t n_tmp = n; + uint64_t barrett_factor = + MultiplyFactor(1, BitShift, modulus).BarrettFactor(); + __m512i v_bf = _mm512_set1_epi64(static_cast(barrett_factor)); + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + EltwiseReduceModNative(result, operand, n_mod_8, modulus, input_mod_factor, + output_mod_factor); + operand += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + uint64_t twice_mod = modulus << 1; + const __m512i* v_operand = reinterpret_cast(operand); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(twice_mod)); + + if (input_mod_factor == modulus) { + if (output_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_bf); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } else { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_bf); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + } + + if (input_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + + if (input_mod_factor == 4) { + if (output_mod_factor == 1) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + if (output_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, twice_mod, + "v_op exceeds bound " << twice_mod); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + } +} + +#endif } // namespace hexl } // namespace intel diff --git a/hexl/eltwise/eltwise-reduce-mod-internal.hpp b/hexl/eltwise/eltwise-reduce-mod-internal.hpp index 6f67e387..7f23a5c1 100644 --- a/hexl/eltwise/eltwise-reduce-mod-internal.hpp +++ b/hexl/eltwise/eltwise-reduce-mod-internal.hpp @@ -12,9 +12,9 @@ namespace hexl { // @param[in] n Number of elements in operand // @param[in] modulus Modulus with which to perform modular reduction // @param[in] input_mod_factor Assumes input elements are in [0, -// input_mod_factor * p) Must be 0, 2 or 4. input_mod_factor=0 means, no -// knowledge of input range. Barrett reduction will be used in this case -// input_mod_factor > output_mod_factor unless input_mod_factor == 0 +// input_mod_factor * p) Must be modulus, 2 or 4. input_mod_factor=modulus +// means, input range is [0, p * p]. Barrett reduction will be used in this case +// input_mod_factor > output_mod_factor // @param[in] output_mod_factor output elements will be in [0, output_mod_factor // * p) Must be 1 or 2. for input_mod_factor=0, output_mod_factor will be set // to 1. diff --git a/hexl/eltwise/eltwise-reduce-mod.cpp b/hexl/eltwise/eltwise-reduce-mod.cpp index 941ee431..c522eaa1 100644 --- a/hexl/eltwise/eltwise-reduce-mod.cpp +++ b/hexl/eltwise/eltwise-reduce-mod.cpp @@ -21,9 +21,9 @@ void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, HEXL_CHECK(result != nullptr, "Require result != nullptr"); HEXL_CHECK(n != 0, "Require n != 0"); HEXL_CHECK(modulus > 1, "Require modulus > 1"); - HEXL_CHECK( - input_mod_factor == 0 || input_mod_factor == 2 || input_mod_factor == 4, - "input_mod_factor must be 0 or 2 or 4" << input_mod_factor); + HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, "output_mod_factor must be 1 or 2 " << output_mod_factor); HEXL_CHECK(input_mod_factor != output_mod_factor, @@ -32,37 +32,49 @@ void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, uint64_t barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); uint64_t twice_modulus = modulus << 1; - switch (input_mod_factor) { - case 0: + if (input_mod_factor == modulus) { + if (output_mod_factor == 2) { for (size_t i = 0; i < n; ++i) { - result[i] = BarrettReduce64(operand[i], modulus, barrett_factor); + if (operand[i] >= modulus) { + result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor); + } else { + result[i] = operand[i]; + } } - HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); - break; - - case 2: + } else { for (size_t i = 0; i < n; ++i) { - result[i] = ReduceMod<2>(operand[i], modulus); + if (operand[i] >= modulus) { + result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor); + } else { + result[i] = operand[i]; + } } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); - break; + } + } - case 4: - if (output_mod_factor == 1) { - for (size_t i = 0; i < n; ++i) { - result[i] = ReduceMod<4>(operand[i], modulus, &twice_modulus); - } - HEXL_CHECK_BOUNDS(result, n, modulus, - "result exceeds bound " << modulus); + if (input_mod_factor == 2) { + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(operand[i], modulus); + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); + } + + if (input_mod_factor == 4) { + if (output_mod_factor == 1) { + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<4>(operand[i], modulus, &twice_modulus); } - if (output_mod_factor == 2) { - for (size_t i = 0; i < n; ++i) { - result[i] = ReduceMod<2>(operand[i], twice_modulus); - } - HEXL_CHECK_BOUNDS(result, n, twice_modulus, - "result exceeds bound " << twice_modulus); + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); + } + if (output_mod_factor == 2) { + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(operand[i], twice_modulus); } - break; + HEXL_CHECK_BOUNDS(result, n, twice_modulus, + "result exceeds bound " << twice_modulus); + } } } @@ -73,9 +85,9 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, HEXL_CHECK(result != nullptr, "Require result != nullptr"); HEXL_CHECK(n != 0, "Require n != 0"); HEXL_CHECK(modulus > 1, "Require modulus > 1"); - HEXL_CHECK( - input_mod_factor == 0 || input_mod_factor == 2 || input_mod_factor == 4, - "input_mod_factor must be 0 or 2 or 4" << input_mod_factor); + HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, "output_mod_factor must be 1 or 2 " << output_mod_factor); @@ -87,8 +99,13 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, } #ifdef HEXL_HAS_AVX512DQ if (has_avx512dq) { - EltwiseReduceModAVX512(result, operand, n, modulus, input_mod_factor, - output_mod_factor); + if (modulus < (1ULL << 52)) { + EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor, + output_mod_factor); + } else { + EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor, + output_mod_factor); + } return; } #endif diff --git a/hexl/include/hexl/eltwise/eltwise-reduce-mod.hpp b/hexl/include/hexl/eltwise/eltwise-reduce-mod.hpp index 2cbf9cd0..1341042d 100644 --- a/hexl/include/hexl/eltwise/eltwise-reduce-mod.hpp +++ b/hexl/include/hexl/eltwise/eltwise-reduce-mod.hpp @@ -15,9 +15,9 @@ namespace hexl { /// @param[in] n Number of elements in operand /// @param[in] modulus Modulus with which to perform modular reduction /// @param[in] input_mod_factor Assumes input elements are in [0, -/// input_mod_factor * p) Must be 0, 1, 2 or 4. input_mod_factor=0 means, no -/// knowledge of input range. Barrett reduction will be used in this case. -/// input_mod_factor >= output_mod_factor unless input_mod_factor == 0 +/// input_mod_factor * p) Must be modulus, 1, 2 or 4. input_mod_factor=modulus +/// means, input range is [0, p * p]. Barrett reduction will be used in this +/// case. input_mod_factor > output_mod_factor /// @param[in] output_mod_factor output elements will be in [0, /// output_mod_factor * modulus) Must be 1 or 2. For input_mod_factor=0, /// output_mod_factor will be set to 1. diff --git a/hexl/include/hexl/number-theory/number-theory.hpp b/hexl/include/hexl/number-theory/number-theory.hpp index f5e2c4f9..355efb1a 100644 --- a/hexl/include/hexl/number-theory/number-theory.hpp +++ b/hexl/include/hexl/number-theory/number-theory.hpp @@ -192,12 +192,16 @@ std::vector GeneratePrimes(size_t num_primes, size_t bit_size, /// @param[in] input /// @param[in] modulus /// @param[in] q_barr floor(2^64 / modulus) -inline uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, - uint64_t q_barr) { +template +uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) { HEXL_CHECK(modulus != 0, "modulus == 0"); uint64_t q = MultiplyUInt64Hi<64>(input, q_barr); uint64_t q_times_input = input - q * modulus; - return q_times_input >= modulus ? q_times_input - modulus : q_times_input; + if (OutputModFactor == 2) { + return q_times_input; + } else { + return (q_times_input >= modulus) ? q_times_input - modulus : q_times_input; + } } /// @brief Returns x mod modulus, assuming x < InputModFactor * modulus diff --git a/hexl/include/hexl/util/msvc.hpp b/hexl/include/hexl/util/msvc.hpp index 06654d43..291e9549 100644 --- a/hexl/include/hexl/util/msvc.hpp +++ b/hexl/include/hexl/util/msvc.hpp @@ -273,9 +273,13 @@ inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi, // Returns most-significant bit of the input inline uint64_t MSB(uint64_t input) { + HEXL_CHECK(input == 0, "input cannot be 0: Got " << input); unsigned long index{0}; // NOLINT(runtime/int) _BitScanReverse64(&index, input); - return index; + if (index >= 0 && input > 0) { + return static_cast(index); + } + return 0; } #define HEXL_LOOP_UNROLL_4 \ diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index 309844ff..5d11b3b4 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -371,18 +371,36 @@ inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b, // Returns x mod q, computed via Barrett reduction // @param q_barr floor(2^BitShift / q) -template +template inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, __m512i q_barr) { - __m512i rnd1_hi = _mm512_hexl_mulhi_epi(x, q_barr); + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid bitshift " << BitShift << "; need 52 or 64"); + +#ifdef HEXL_HAS_AVX512IFMA + if (BitShift == 52) { + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr); + // Barrett subtraction + // tmp[0] = input - tmp[1] * q; + __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); + x = _mm512_sub_epi64(x, tmp1_times_mod); + } +#endif + if (BitShift == 64) { + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr); + // Barrett subtraction + // tmp[0] = input - tmp[1] * q; + __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q); + x = _mm512_sub_epi64(x, tmp1_times_mod); + } - // Barrett subtraction - // tmp[0] = input - tmp[1] * q; - __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q); - x = _mm512_sub_epi64(x, tmp1_times_mod); // Correction - x = _mm512_hexl_small_mod_epu64(x, q); - return x; + if (OutputModFactor == 2) { + return x; + } else { + x = _mm512_hexl_small_mod_epu64(x, q); + return x; + } } // Concatenate packed 64-bit integers in x and y, producing an intermediate diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index fd613a9e..1179328b 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -28,7 +28,7 @@ TEST(EltwiseReduceMod, avx512_0_1) { std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; uint64_t modulus = 769; - const uint64_t input_mod_factor = 0; + const uint64_t input_mod_factor = modulus; const uint64_t output_mod_factor = 1; EltwiseReduceModAVX512(result.data(), op.data(), op.size(), modulus, input_mod_factor, output_mod_factor); @@ -113,10 +113,10 @@ TEST(EltwiseReduceMod, AVX512Big_0_1) { std::vector result1(length, 0); std::vector result2(length, 0); - EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, 0, - 1); - EltwiseReduceModAVX512(result2.data(), op2.data(), op1.size(), modulus, 0, - 1); + EltwiseReduceModNative(result1.data(), op1.data(), op1.size(), modulus, + modulus, 1); + EltwiseReduceModAVX512(result2.data(), op2.data(), op1.size(), modulus, + modulus, 1); ASSERT_EQ(result1, result2); ASSERT_EQ(result1, result2); diff --git a/test/test-eltwise-reduce-mod.cpp b/test/test-eltwise-reduce-mod.cpp index bc75ed15..9e5480ef 100644 --- a/test/test-eltwise-reduce-mod.cpp +++ b/test/test-eltwise-reduce-mod.cpp @@ -46,7 +46,7 @@ TEST(EltwiseReduceMod, 0_1) { std::vector result{0, 0, 0, 0}; const uint64_t modulus = 750; - const uint64_t input_mod_factor = 0; + const uint64_t input_mod_factor = modulus; const uint64_t output_mod_factor = 1; EltwiseReduceMod(result.data(), op.data(), op.size(), modulus, input_mod_factor, output_mod_factor); From ef3a689ce2147c155869d22fe21b1f0bae49e828 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Tue, 14 Sep 2021 16:38:32 -0700 Subject: [PATCH 02/18] remove MinTime --- benchmark/bench-eltwise-reduce-mod.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/benchmark/bench-eltwise-reduce-mod.cpp b/benchmark/bench-eltwise-reduce-mod.cpp index a6554471..753af575 100644 --- a/benchmark/bench-eltwise-reduce-mod.cpp +++ b/benchmark/bench-eltwise-reduce-mod.cpp @@ -136,7 +136,6 @@ static void BM_EltwiseReduceModAVX512BitShift64( BENCHMARK(BM_EltwiseReduceModAVX512BitShift64) ->Unit(benchmark::kMicrosecond) - ->MinTime(1.0) ->Args({1024}) ->Args({4096}) ->Args({16384}); @@ -164,7 +163,6 @@ static void BM_EltwiseReduceModAVX512BitShift52( BENCHMARK(BM_EltwiseReduceModAVX512BitShift52) ->Unit(benchmark::kMicrosecond) - ->MinTime(1.0) ->Args({1024}) ->Args({4096}) ->Args({16384}); From 370296d3202a5cb92162fcaaec33c3062640753f Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Wed, 15 Sep 2021 10:45:03 -0700 Subject: [PATCH 03/18] update condition --- hexl/eltwise/eltwise-reduce-mod.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hexl/eltwise/eltwise-reduce-mod.cpp b/hexl/eltwise/eltwise-reduce-mod.cpp index c522eaa1..aebdf805 100644 --- a/hexl/eltwise/eltwise-reduce-mod.cpp +++ b/hexl/eltwise/eltwise-reduce-mod.cpp @@ -44,7 +44,7 @@ void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, } else { for (size_t i = 0; i < n; ++i) { if (operand[i] >= modulus) { - result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor); + result[i] = BarrettReduce64<1>(operand[i], modulus, barrett_factor); } else { result[i] = operand[i]; } From c4697193fd6002940d57f84c16b383705d3255f0 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Thu, 23 Sep 2021 12:50:19 -0700 Subject: [PATCH 04/18] add unit test --- test/test-eltwise-reduce-mod-avx512.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index 1179328b..dea805ee 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -86,6 +86,28 @@ TEST(EltwiseReduceMod, avx512_4_2) { CheckEqual(result, exp_out); } +TEST(EltwiseReduceMod, avx512_mod_1) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + + std::vector op{914704788761805005, 224925333812073588, + 592788284123677125, 142439467624940029, + 146023272535470246, 979015887843024185, + 496780369302017539, 1073741441}; + std::vector exp_out{572243325, 955099389, 432045053, 160411261, + 815223709, 349526397, 10878205, 0}; + std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; + + uint64_t modulus = 1073741441; + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + + EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus, + input_mod_factor, output_mod_factor); + CheckEqual(result, exp_out); +} + // Checks AVX512 and native EltwiseReduceMod implementations match with randomly // generated inputs TEST(EltwiseReduceMod, AVX512Big_0_1) { From e12e734dfcf03d103b4b99deb7194ef438f5c6a7 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Tue, 28 Sep 2021 09:04:26 -0700 Subject: [PATCH 05/18] test algorithm 2 barrett reduction --- hexl/util/avx512-util.hpp | 43 ++++++++++++++++++++++--- test/test-eltwise-reduce-mod-avx512.cpp | 4 +-- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index 5d11b3b4..b1bf05f1 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -379,11 +379,44 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, #ifdef HEXL_HAS_AVX512IFMA if (BitShift == 52) { - __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr); - // Barrett subtraction - // tmp[0] = input - tmp[1] * q; - __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); - x = _mm512_sub_epi64(x, tmp1_times_mod); + int64_t beta = -2; + uint64_t mod = ExtractValues(q)[0]; + // std::cout << "52 bit mod: " << mod << std::endl; + uint64_t ceil_log_mod = Log2(mod) + 1; + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(mod)); + // std::cout << "52 prod_right_shift: " << prod_right_shift << std::endl; + __m512i x_hi = _mm512_srli_epi64(x, static_cast(52ULL)); + __m512i x_intr = _mm512_slli_epi64(x, static_cast(12ULL)); + __m512i x_lo = _mm512_srli_epi64(x_intr, static_cast(12ULL)); + + uint64_t input = ExtractValues(x)[0]; + std::cout << "input: " << input << std::endl; + + // c1 = floor(U / 2^{n + beta}) + __m512i c1_lo = + _mm512_srli_epi64(x_lo, static_cast(prod_right_shift)); + __m512i c1_hi = _mm512_slli_epi64( + x_hi, static_cast(52ULL - (prod_right_shift))); + __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); + + // alpha - beta == 52, so we only need high 52 bits + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr); + + // Z = prod_lo - (p * q_hat)_lo + x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); + x = _mm512_hexl_small_mod_epu64<2>(x, q); + + uint64_t output = ExtractValues(x)[0]; + std::cout << "output: " << output << std::endl; + + return x; + /* + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr); + // Barrett subtraction + // tmp[0] = input - tmp[1] * q; + __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); + x = _mm512_sub_epi64(x, tmp1_times_mod);*/ } #endif if (BitShift == 64) { diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index dea805ee..6eebb33f 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -95,8 +95,8 @@ TEST(EltwiseReduceMod, avx512_mod_1) { 592788284123677125, 142439467624940029, 146023272535470246, 979015887843024185, 496780369302017539, 1073741441}; - std::vector exp_out{572243325, 955099389, 432045053, 160411261, - 815223709, 349526397, 10878205, 0}; + std::vector exp_out{802487803, 754009873, 962097738, 36142730, + 687617508, 519876583, 630345322, 0}; std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; uint64_t modulus = 1073741441; From a18357e6139b07549340a82b4e5b8fd065a075be Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Thu, 7 Oct 2021 16:33:48 -0700 Subject: [PATCH 06/18] use different optimization for reduce mod based on input size --- benchmark/bench-eltwise-reduce-mod.cpp | 52 +++++++++ hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp | 9 +- hexl/eltwise/eltwise-reduce-mod-avx512.hpp | 27 ++++- hexl/util/avx512-util.hpp | 114 +++++++++++++------- test/test-avx512-util.cpp | 77 +++++++++---- test/test-eltwise-reduce-mod-avx512.cpp | 8 +- 6 files changed, 218 insertions(+), 69 deletions(-) diff --git a/benchmark/bench-eltwise-reduce-mod.cpp b/benchmark/bench-eltwise-reduce-mod.cpp index 753af575..d47991e9 100644 --- a/benchmark/bench-eltwise-reduce-mod.cpp +++ b/benchmark/bench-eltwise-reduce-mod.cpp @@ -170,5 +170,57 @@ BENCHMARK(BM_EltwiseReduceModAVX512BitShift52) //================================================================= +#ifdef HEXL_HAS_AVX512IFMA +// state[0] is the degree +static void BM_EltwiseReduceModAVX512BitShift52GT( + benchmark::State& state) { // NOLINT + size_t input_size = state.range(0); + size_t modulus = 1152921504606877697; + + // AlignedVector64 input1(input_size, 1); + auto input1 = GenerateInsecureUniformRandomValues( + input_size, 4503599627370496, 100 * modulus); + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + AlignedVector64 output(input_size, 0); + + for (auto _ : state) { + EltwiseReduceModAVX512<52>(output.data(), input1.data(), input_size, + modulus, input_mod_factor, output_mod_factor); + } +} + +BENCHMARK(BM_EltwiseReduceModAVX512BitShift52GT) + ->Unit(benchmark::kMicrosecond) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); + +static void BM_EltwiseReduceModAVX512BitShift52LT( + benchmark::State& state) { // NOLINT + size_t input_size = state.range(0); + size_t modulus = 1073741441; + + // AlignedVector64 input1(input_size, 1); + auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 2 * modulus); + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + AlignedVector64 output(input_size, 0); + + for (auto _ : state) { + EltwiseReduceModAVX512<52>(output.data(), input1.data(), input_size, + modulus, input_mod_factor, output_mod_factor); + } +} + +BENCHMARK(BM_EltwiseReduceModAVX512BitShift52LT) + ->Unit(benchmark::kMicrosecond) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); +#endif + +//================================================================= + } // namespace hexl } // namespace intel diff --git a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp index 4cf35968..194dcb3f 100644 --- a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp @@ -45,11 +45,18 @@ void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1, uint64_t mu = MultiplyFactor(1, BitShift, modulus).BarrettFactor(); __m512i v_mu = _mm512_set1_epi64(static_cast(mu)); + // Multi-word Barrett reduction precomputation + constexpr int64_t beta = -2; + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + for (size_t i = n / 8; i > 0; --i) { __m512i v_op = _mm512_loadu_si512(v_op_ptr); __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp)); - v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_mu); + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_mu, prod_right_shift, v_neg_mod); __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus); v_to_add = _mm512_sub_epi64(v_to_add, v_diff); diff --git a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp index 3092072f..53d16cfd 100644 --- a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp @@ -35,8 +35,27 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, "input_mod_factor must not be equal to output_mod_factor "); uint64_t n_tmp = n; + + // Multi-word Barrett reduction precomputation + constexpr int64_t alpha = BitShift - 2; + constexpr int64_t beta = -2; + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + uint64_t barrett_factor = - MultiplyFactor(1, BitShift, modulus).BarrettFactor(); + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift, + modulus) + .BarrettFactor(); + + // Single-worded Barrett reduction. + uint64_t barrett_factor52 = MultiplyFactor(1, 52, modulus).BarrettFactor(); + + if (BitShift == 64) { + // Single-worded Barrett reduction. + barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); + } + __m512i v_bf = _mm512_set1_epi64(static_cast(barrett_factor)); // Deals with n not divisible by 8 @@ -59,7 +78,8 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, if (output_mod_factor == 2) { for (size_t i = 0; i < n_tmp; i += 8) { __m512i v_op = _mm512_loadu_si512(v_operand); - v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_bf); + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_bf, prod_right_shift, v_neg_mod); HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, "v_op exceeds bound " << modulus); _mm512_storeu_si512(v_result, v_op); @@ -69,7 +89,8 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, } else { for (size_t i = 0; i < n_tmp; i += 8) { __m512i v_op = _mm512_loadu_si512(v_operand); - v_op = _mm512_hexl_barrett_reduce64(v_op, v_modulus, v_bf); + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_bf, prod_right_shift, v_neg_mod); HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, "v_op exceeds bound " << modulus); _mm512_storeu_si512(v_result, v_op); diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index b1bf05f1..70983759 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -369,54 +369,83 @@ inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b, return _mm512_hexl_cmp_epi64(a, b, CMPINT::LE, match_value); } +//// Returns x mod q, computed via Barrett reduction +//// @param q_barr floor(2^BitShift / q) +// template +// inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, +// __m512i q_barr) { +// HEXL_CHECK(BitShift == 52 || BitShift == 64, +// "Invalid bitshift " << BitShift << "; need 52 or 64"); + +// #ifdef HEXL_HAS_AVX512IFMA +// if (BitShift == 52) { +// __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr); +// // Barrett subtraction +// // tmp[0] = input - tmp[1] * q; +// __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); +// x = _mm512_sub_epi64(x, tmp1_times_mod); +// } +// #endif +// if (BitShift == 64) { +// __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr); +// // Barrett subtraction +// // tmp[0] = input - tmp[1] * q; +// __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q); +// x = _mm512_sub_epi64(x, tmp1_times_mod); +// } + +// // Correction +// if (OutputModFactor == 2) { +// return x; +// } else { +// x = _mm512_hexl_small_mod_epu64(x, q); +// return x; +// } +//} + // Returns x mod q, computed via Barrett reduction // @param q_barr floor(2^BitShift / q) template inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, - __m512i q_barr) { + __m512i q_barr, + uint64_t prod_right_shift, + __m512i v_neg_mod) { HEXL_CHECK(BitShift == 52 || BitShift == 64, "Invalid bitshift " << BitShift << "; need 52 or 64"); #ifdef HEXL_HAS_AVX512IFMA if (BitShift == 52) { - int64_t beta = -2; - uint64_t mod = ExtractValues(q)[0]; - // std::cout << "52 bit mod: " << mod << std::endl; - uint64_t ceil_log_mod = Log2(mod) + 1; - uint64_t prod_right_shift = ceil_log_mod + beta; - __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(mod)); - // std::cout << "52 prod_right_shift: " << prod_right_shift << std::endl; - __m512i x_hi = _mm512_srli_epi64(x, static_cast(52ULL)); - __m512i x_intr = _mm512_slli_epi64(x, static_cast(12ULL)); - __m512i x_lo = _mm512_srli_epi64(x_intr, static_cast(12ULL)); - - uint64_t input = ExtractValues(x)[0]; - std::cout << "input: " << input << std::endl; - - // c1 = floor(U / 2^{n + beta}) - __m512i c1_lo = - _mm512_srli_epi64(x_lo, static_cast(prod_right_shift)); - __m512i c1_hi = _mm512_slli_epi64( - x_hi, static_cast(52ULL - (prod_right_shift))); - __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); - - // alpha - beta == 52, so we only need high 52 bits - __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr); - - // Z = prod_lo - (p * q_hat)_lo - x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); - x = _mm512_hexl_small_mod_epu64<2>(x, q); - - uint64_t output = ExtractValues(x)[0]; - std::cout << "output: " << output << std::endl; - - return x; - /* - __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr); - // Barrett subtraction - // tmp[0] = input - tmp[1] * q; - __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); - x = _mm512_sub_epi64(x, tmp1_times_mod);*/ + uint64_t match_value = 1; + __m512i two_pow_fiftytwo = _mm512_set1_epi64(4503599627370496); + __m512i mask = _mm512_hexl_cmpge_epu64(x, two_pow_fiftytwo, match_value); + uint64_t sum = _mm512_reduce_add_epi64(mask); + if (sum > 0) { + __m512i x_hi = _mm512_srli_epi64(x, static_cast(52ULL)); + __m512i x_intr = _mm512_slli_epi64(x, static_cast(12ULL)); + __m512i x_lo = + _mm512_srli_epi64(x_intr, static_cast(12ULL)); + + // c1 = floor(U / 2^{n + beta}) + __m512i c1_lo = + _mm512_srli_epi64(x_lo, static_cast(prod_right_shift)); + __m512i c1_hi = _mm512_slli_epi64( + x_hi, static_cast(52ULL - (prod_right_shift))); + __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); + + // alpha - beta == 52, so we only need high 52 bits + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr); + + // Z = prod_lo - (p * q_hat)_lo + x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); + // x = _mm512_hexl_small_mod_epu64<2>(x, q); + + } else { + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr); + // Barrett subtraction + // tmp[0] = input - tmp[1] * q; + __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); + x = _mm512_sub_epi64(x, tmp1_times_mod); + } } #endif if (BitShift == 64) { @@ -431,7 +460,12 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, if (OutputModFactor == 2) { return x; } else { - x = _mm512_hexl_small_mod_epu64(x, q); + if (BitShift == 64) { + x = _mm512_hexl_small_mod_epu64(x, q); + } + if (BitShift == 52) { + x = _mm512_hexl_small_mod_epu64<2>(x, q); + } return x; } } diff --git a/test/test-avx512-util.cpp b/test/test-avx512-util.cpp index 33ae98eb..6dfdaf96 100644 --- a/test/test-avx512-util.cpp +++ b/test/test-avx512-util.cpp @@ -294,26 +294,54 @@ TEST(AVX512, _mm512_hexl_barrett_reduce64) { } // Small - { - __m512i a = _mm512_set_epi64(12, 11, 10, 8, 6, 4, 2, 0); - - std::vector moduli{2, 2, 3, 4, 5, 6, 7, 8}; - std::vector barrs(moduli.size()); - for (size_t i = 0; i < barrs.size(); ++i) { - barrs[i] = MultiplyFactor(1, 64, moduli[i]).BarrettFactor(); - } - - __m512i vmoduli = - _mm512_set_epi64(moduli[7], moduli[6], moduli[5], moduli[4], moduli[3], - moduli[2], moduli[1], moduli[0]); - __m512i vbarrs = _mm512_set_epi64(barrs[7], barrs[6], barrs[5], barrs[4], - barrs[3], barrs[2], barrs[1], barrs[0]); - - __m512i expected_out = _mm512_set_epi64(4, 4, 4, 3, 2, 1, 0, 0); - - __m512i c = _mm512_hexl_barrett_reduce64(a, vmoduli, vbarrs); - AssertEqual(c, expected_out); - } + // { + // __m512i a = _mm512_set_epi64(12, 11, 10, 8, 6, 4, 2, 0); + + // std::vector moduli{2, 2, 3, 4, 5, 6, 7, 8}; + // std::vector barrs(moduli.size()); + // for (size_t i = 0; i < barrs.size(); ++i) { + // barrs[i] = MultiplyFactor(1, 64, moduli[i]).BarrettFactor(); + // } + + // // Multi-word Barrett reduction precomputation + // std::vector ceil_log_mod(moduli.size()); + // constexpr int64_t beta = -2; + // for (size_t i = 0; i < ceil_log_mod.size(); ++i) { + // ceil_log_mod[i] = Log2(moduli[i]) + 1; + // } + + // std::vector prod_right_shift(moduli.size()); + // for (size_t i = 0; i < prod_right_shift.size(); ++i) { + // prod_right_shift[i] = ceil_log_mod[i] + beta; + // } + + // //const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from + // Algorithm 2 + // // uint64_t prod_right_shift = ceil_log_mod + beta; + // __m512i v_neg_mod = + // _mm512_set_epi64(-static_cast(moduli[7]), + // -static_cast(moduli[6]), + // -static_cast(moduli[5]), + // -static_cast(moduli[4]), + // -static_cast(moduli[3]), + // -static_cast(moduli[2]), + // -static_cast(moduli[1]), + // -static_cast(moduli[0])); + + // __m512i vmoduli = + // _mm512_set_epi64(moduli[7], moduli[6], moduli[5], moduli[4], + // moduli[3], + // moduli[2], moduli[1], moduli[0]); + // __m512i vbarrs = _mm512_set_epi64(barrs[7], barrs[6], barrs[5], + // barrs[4], + // barrs[3], barrs[2], barrs[1], + // barrs[0]); + + // __m512i expected_out = _mm512_set_epi64(4, 4, 4, 3, 2, 1, 0, 0); + + // __m512i c = _mm512_hexl_barrett_reduce64(a, vmoduli, vbarrs, + // prod_right_shift, v_neg_mod); AssertEqual(c, expected_out); + // } // Random { @@ -322,6 +350,12 @@ TEST(AVX512, _mm512_hexl_barrett_reduce64) { __m512i vbarr = _mm512_set1_epi64(MultiplyFactor(1, 64, modulus).BarrettFactor()); + // Multi-word Barrett reduction precomputation + constexpr int64_t beta = -2; + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + for (size_t trial = 0; trial < 200; ++trial) { auto arg1 = GenerateInsecureUniformRandomValues(8, 0, modulus * modulus); auto exp = arg1; @@ -332,7 +366,8 @@ TEST(AVX512, _mm512_hexl_barrett_reduce64) { __m512i varg1 = _mm512_set_epi64(arg1[7], arg1[6], arg1[5], arg1[4], arg1[3], arg1[2], arg1[1], arg1[0]); - __m512i c = _mm512_hexl_barrett_reduce64(varg1, vmodulus, vbarr); + __m512i c = _mm512_hexl_barrett_reduce64(varg1, vmodulus, vbarr, + prod_right_shift, v_neg_mod); std::vector result = ExtractValues(c); AssertEqual(result, exp); diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index 6eebb33f..a66f695e 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -18,7 +18,7 @@ namespace intel { namespace hexl { #ifdef HEXL_HAS_AVX512DQ -TEST(EltwiseReduceMod, avx512_0_1) { +TEST(EltwiseReduceMod, avx512_mod_1) { if (!has_avx512dq) { GTEST_SKIP(); } @@ -30,8 +30,8 @@ TEST(EltwiseReduceMod, avx512_0_1) { uint64_t modulus = 769; const uint64_t input_mod_factor = modulus; const uint64_t output_mod_factor = 1; - EltwiseReduceModAVX512(result.data(), op.data(), op.size(), modulus, - input_mod_factor, output_mod_factor); + EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus, + input_mod_factor, output_mod_factor); CheckEqual(result, exp_out); } @@ -86,7 +86,7 @@ TEST(EltwiseReduceMod, avx512_4_2) { CheckEqual(result, exp_out); } -TEST(EltwiseReduceMod, avx512_mod_1) { +TEST(EltwiseReduceMod, avx512Big_mod_1) { if (!has_avx512dq) { GTEST_SKIP(); } From 8542603b5f00476e7abe36d46ddd30f184ab01db Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 05:41:58 -0700 Subject: [PATCH 07/18] add more test and benchmark --- benchmark/bench-eltwise-reduce-mod.cpp | 19 +++++----- hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp | 2 +- hexl/eltwise/eltwise-reduce-mod-avx512.hpp | 8 ++-- hexl/util/avx512-util.hpp | 42 ++------------------- test/test-avx512-util.cpp | 2 +- test/test-eltwise-reduce-mod-avx512.cpp | 19 +++++++++- 6 files changed, 37 insertions(+), 55 deletions(-) diff --git a/benchmark/bench-eltwise-reduce-mod.cpp b/benchmark/bench-eltwise-reduce-mod.cpp index d47991e9..48dfb3aa 100644 --- a/benchmark/bench-eltwise-reduce-mod.cpp +++ b/benchmark/bench-eltwise-reduce-mod.cpp @@ -123,9 +123,10 @@ static void BM_EltwiseReduceModAVX512BitShift64( size_t input_size = state.range(0); size_t modulus = 1152921504606877697; - AlignedVector64 input1(input_size, 1); + auto input1 = + GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); const uint64_t input_mod_factor = modulus; - const uint64_t output_mod_factor = 1; + const uint64_t output_mod_factor = 2; AlignedVector64 output(input_size, 0); for (auto _ : state) { @@ -150,9 +151,10 @@ static void BM_EltwiseReduceModAVX512BitShift52( size_t input_size = state.range(0); size_t modulus = 1152921504606877697; - AlignedVector64 input1(input_size, 1); + auto input1 = + GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); const uint64_t input_mod_factor = modulus; - const uint64_t output_mod_factor = 1; + const uint64_t output_mod_factor = 2; AlignedVector64 output(input_size, 0); for (auto _ : state) { @@ -177,11 +179,10 @@ static void BM_EltwiseReduceModAVX512BitShift52GT( size_t input_size = state.range(0); size_t modulus = 1152921504606877697; - // AlignedVector64 input1(input_size, 1); auto input1 = GenerateInsecureUniformRandomValues( input_size, 4503599627370496, 100 * modulus); const uint64_t input_mod_factor = modulus; - const uint64_t output_mod_factor = 1; + const uint64_t output_mod_factor = 2; AlignedVector64 output(input_size, 0); for (auto _ : state) { @@ -201,10 +202,10 @@ static void BM_EltwiseReduceModAVX512BitShift52LT( size_t input_size = state.range(0); size_t modulus = 1073741441; - // AlignedVector64 input1(input_size, 1); - auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 2 * modulus); + auto input1 = + GenerateInsecureUniformRandomValues(input_size, 0, 2251799813685248); const uint64_t input_mod_factor = modulus; - const uint64_t output_mod_factor = 1; + const uint64_t output_mod_factor = 2; AlignedVector64 output(input_size, 0); for (auto _ : state) { diff --git a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp index 194dcb3f..349c0e04 100644 --- a/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-cmp-sub-mod-avx512.hpp @@ -56,7 +56,7 @@ void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1, __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp)); v_op = _mm512_hexl_barrett_reduce64( - v_op, v_modulus, v_mu, prod_right_shift, v_neg_mod); + v_op, v_modulus, v_mu, v_mu, prod_right_shift, v_neg_mod); __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus); v_to_add = _mm512_sub_epi64(v_to_add, v_diff); diff --git a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp index 53d16cfd..8018fe32 100644 --- a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp @@ -48,8 +48,7 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, modulus) .BarrettFactor(); - // Single-worded Barrett reduction. - uint64_t barrett_factor52 = MultiplyFactor(1, 52, modulus).BarrettFactor(); + uint64_t barrett_factor_52 = MultiplyFactor(1, 52, modulus).BarrettFactor(); if (BitShift == 64) { // Single-worded Barrett reduction. @@ -57,6 +56,7 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, } __m512i v_bf = _mm512_set1_epi64(static_cast(barrett_factor)); + __m512i v_bf_52 = _mm512_set1_epi64(static_cast(barrett_factor_52)); // Deals with n not divisible by 8 uint64_t n_mod_8 = n_tmp % 8; @@ -79,7 +79,7 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, for (size_t i = 0; i < n_tmp; i += 8) { __m512i v_op = _mm512_loadu_si512(v_operand); v_op = _mm512_hexl_barrett_reduce64( - v_op, v_modulus, v_bf, prod_right_shift, v_neg_mod); + v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod); HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, "v_op exceeds bound " << modulus); _mm512_storeu_si512(v_result, v_op); @@ -90,7 +90,7 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, for (size_t i = 0; i < n_tmp; i += 8) { __m512i v_op = _mm512_loadu_si512(v_operand); v_op = _mm512_hexl_barrett_reduce64( - v_op, v_modulus, v_bf, prod_right_shift, v_neg_mod); + v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod); HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, "v_op exceeds bound " << modulus); _mm512_storeu_si512(v_result, v_op); diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index 70983759..aac503f9 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -369,45 +369,11 @@ inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b, return _mm512_hexl_cmp_epi64(a, b, CMPINT::LE, match_value); } -//// Returns x mod q, computed via Barrett reduction -//// @param q_barr floor(2^BitShift / q) -// template -// inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, -// __m512i q_barr) { -// HEXL_CHECK(BitShift == 52 || BitShift == 64, -// "Invalid bitshift " << BitShift << "; need 52 or 64"); - -// #ifdef HEXL_HAS_AVX512IFMA -// if (BitShift == 52) { -// __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr); -// // Barrett subtraction -// // tmp[0] = input - tmp[1] * q; -// __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); -// x = _mm512_sub_epi64(x, tmp1_times_mod); -// } -// #endif -// if (BitShift == 64) { -// __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr); -// // Barrett subtraction -// // tmp[0] = input - tmp[1] * q; -// __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q); -// x = _mm512_sub_epi64(x, tmp1_times_mod); -// } - -// // Correction -// if (OutputModFactor == 2) { -// return x; -// } else { -// x = _mm512_hexl_small_mod_epu64(x, q); -// return x; -// } -//} - // Returns x mod q, computed via Barrett reduction // @param q_barr floor(2^BitShift / q) template inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, - __m512i q_barr, + __m512i q_barr, __m512i q_barr_52, uint64_t prod_right_shift, __m512i v_neg_mod) { HEXL_CHECK(BitShift == 52 || BitShift == 64, @@ -416,7 +382,7 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, #ifdef HEXL_HAS_AVX512IFMA if (BitShift == 52) { uint64_t match_value = 1; - __m512i two_pow_fiftytwo = _mm512_set1_epi64(4503599627370496); + __m512i two_pow_fiftytwo = _mm512_set1_epi64(2251799813685248); __m512i mask = _mm512_hexl_cmpge_epu64(x, two_pow_fiftytwo, match_value); uint64_t sum = _mm512_reduce_add_epi64(mask); if (sum > 0) { @@ -437,10 +403,8 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, // Z = prod_lo - (p * q_hat)_lo x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); - // x = _mm512_hexl_small_mod_epu64<2>(x, q); - } else { - __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr); + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr_52); // Barrett subtraction // tmp[0] = input - tmp[1] * q; __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); diff --git a/test/test-avx512-util.cpp b/test/test-avx512-util.cpp index 6dfdaf96..d5d96811 100644 --- a/test/test-avx512-util.cpp +++ b/test/test-avx512-util.cpp @@ -366,7 +366,7 @@ TEST(AVX512, _mm512_hexl_barrett_reduce64) { __m512i varg1 = _mm512_set_epi64(arg1[7], arg1[6], arg1[5], arg1[4], arg1[3], arg1[2], arg1[1], arg1[0]); - __m512i c = _mm512_hexl_barrett_reduce64(varg1, vmodulus, vbarr, + __m512i c = _mm512_hexl_barrett_reduce64(varg1, vmodulus, vbarr, vbarr, prod_right_shift, v_neg_mod); std::vector result = ExtractValues(c); diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index a66f695e..25f7964f 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -18,7 +18,24 @@ namespace intel { namespace hexl { #ifdef HEXL_HAS_AVX512DQ -TEST(EltwiseReduceMod, avx512_mod_1) { +TEST(EltwiseReduceMod, avx512_64_mod_1) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + + std::vector op{0, 111, 250, 340, 769, 900, 1200, 1530}; + std::vector exp_out{0, 111, 250, 340, 0, 131, 431, 761}; + std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; + + uint64_t modulus = 769; + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + EltwiseReduceModAVX512<64>(result.data(), op.data(), op.size(), modulus, + input_mod_factor, output_mod_factor); + CheckEqual(result, exp_out); +} + +TEST(EltwiseReduceMod, avx512_52_mod_1) { if (!has_avx512dq) { GTEST_SKIP(); } From 37ba0325f31e94af8861553a932d4bd63e6ae5a8 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 05:57:17 -0700 Subject: [PATCH 08/18] update include --- hexl/eltwise/eltwise-mult-mod-internal.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/hexl/eltwise/eltwise-mult-mod-internal.hpp b/hexl/eltwise/eltwise-mult-mod-internal.hpp index d778a09f..02cdcde1 100644 --- a/hexl/eltwise/eltwise-mult-mod-internal.hpp +++ b/hexl/eltwise/eltwise-mult-mod-internal.hpp @@ -4,6 +4,7 @@ #pragma once #include +#include #include "eltwise/eltwise-mult-mod-internal.hpp" #include "hexl/eltwise/eltwise-reduce-mod.hpp" From 21e208401e5442f40e75acae64d8479fe0a96b63 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 05:58:54 -0700 Subject: [PATCH 09/18] fix formatting --- hexl/eltwise/eltwise-mult-mod-internal.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hexl/eltwise/eltwise-mult-mod-internal.hpp b/hexl/eltwise/eltwise-mult-mod-internal.hpp index 02cdcde1..9d153e0e 100644 --- a/hexl/eltwise/eltwise-mult-mod-internal.hpp +++ b/hexl/eltwise/eltwise-mult-mod-internal.hpp @@ -3,9 +3,10 @@ #pragma once -#include #include +#include + #include "eltwise/eltwise-mult-mod-internal.hpp" #include "hexl/eltwise/eltwise-reduce-mod.hpp" #include "hexl/number-theory/number-theory.hpp" From 4898a4963dff07e2aaebb92da832340575c4d4a9 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 06:05:18 -0700 Subject: [PATCH 10/18] fix formatting --- hexl/eltwise/eltwise-reduce-mod-internal.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hexl/eltwise/eltwise-reduce-mod-internal.hpp b/hexl/eltwise/eltwise-reduce-mod-internal.hpp index 7f23a5c1..907b081e 100644 --- a/hexl/eltwise/eltwise-reduce-mod-internal.hpp +++ b/hexl/eltwise/eltwise-reduce-mod-internal.hpp @@ -3,6 +3,8 @@ #pragma once +#include + namespace intel { namespace hexl { From ec8a99725b1ed8acf0c56c38e1d4a6f8669e353d Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 09:51:18 -0700 Subject: [PATCH 11/18] rename arg, better perf --- hexl/include/hexl/util/msvc.hpp | 5 +---- hexl/util/avx512-util.hpp | 13 +++++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/hexl/include/hexl/util/msvc.hpp b/hexl/include/hexl/util/msvc.hpp index 291e9549..6e02d1ba 100644 --- a/hexl/include/hexl/util/msvc.hpp +++ b/hexl/include/hexl/util/msvc.hpp @@ -276,10 +276,7 @@ inline uint64_t MSB(uint64_t input) { HEXL_CHECK(input == 0, "input cannot be 0: Got " << input); unsigned long index{0}; // NOLINT(runtime/int) _BitScanReverse64(&index, input); - if (index >= 0 && input > 0) { - return static_cast(index); - } - return 0; + return index; } #define HEXL_LOOP_UNROLL_4 \ diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index aac503f9..7b94958f 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -373,7 +373,8 @@ inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b, // @param q_barr floor(2^BitShift / q) template inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, - __m512i q_barr, __m512i q_barr_52, + __m512i q_barr_64, + __m512i q_barr_52, uint64_t prod_right_shift, __m512i v_neg_mod) { HEXL_CHECK(BitShift == 52 || BitShift == 64, @@ -383,9 +384,9 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, if (BitShift == 52) { uint64_t match_value = 1; __m512i two_pow_fiftytwo = _mm512_set1_epi64(2251799813685248); - __m512i mask = _mm512_hexl_cmpge_epu64(x, two_pow_fiftytwo, match_value); - uint64_t sum = _mm512_reduce_add_epi64(mask); - if (sum > 0) { + __mmask8 mask = + _mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT); + if (mask != 0) { __m512i x_hi = _mm512_srli_epi64(x, static_cast(52ULL)); __m512i x_intr = _mm512_slli_epi64(x, static_cast(12ULL)); __m512i x_lo = @@ -399,7 +400,7 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); // alpha - beta == 52, so we only need high 52 bits - __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr); + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr_64); // Z = prod_lo - (p * q_hat)_lo x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); @@ -413,7 +414,7 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, } #endif if (BitShift == 64) { - __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr); + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr_64); // Barrett subtraction // tmp[0] = input - tmp[1] * q; __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q); From 6e9df9e43ad81f9003cd3fae2f196433fa81424d Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 10:43:15 -0700 Subject: [PATCH 12/18] fix windows build --- hexl/include/hexl/util/msvc.hpp | 1 - hexl/util/avx512-util.hpp | 1 - 2 files changed, 2 deletions(-) diff --git a/hexl/include/hexl/util/msvc.hpp b/hexl/include/hexl/util/msvc.hpp index 6e02d1ba..06654d43 100644 --- a/hexl/include/hexl/util/msvc.hpp +++ b/hexl/include/hexl/util/msvc.hpp @@ -273,7 +273,6 @@ inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi, // Returns most-significant bit of the input inline uint64_t MSB(uint64_t input) { - HEXL_CHECK(input == 0, "input cannot be 0: Got " << input); unsigned long index{0}; // NOLINT(runtime/int) _BitScanReverse64(&index, input); return index; diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index 7b94958f..c992d1ad 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -382,7 +382,6 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, #ifdef HEXL_HAS_AVX512IFMA if (BitShift == 52) { - uint64_t match_value = 1; __m512i two_pow_fiftytwo = _mm512_set1_epi64(2251799813685248); __mmask8 mask = _mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT); From ef648dda733f40ae67db57713b40639f81f0032a Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 11:16:25 -0700 Subject: [PATCH 13/18] fix IFMA52 unit-tests --- test/test-eltwise-reduce-mod-avx512.cpp | 46 +++++++++++++------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index 25f7964f..6b6b9853 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -35,6 +35,7 @@ TEST(EltwiseReduceMod, avx512_64_mod_1) { CheckEqual(result, exp_out); } +#ifdef HEXL_HAS_AVX512IFMA TEST(EltwiseReduceMod, avx512_52_mod_1) { if (!has_avx512dq) { GTEST_SKIP(); @@ -52,6 +53,29 @@ TEST(EltwiseReduceMod, avx512_52_mod_1) { CheckEqual(result, exp_out); } +TEST(EltwiseReduceMod, avx512Big_mod_1) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + + std::vector op{914704788761805005, 224925333812073588, + 592788284123677125, 142439467624940029, + 146023272535470246, 979015887843024185, + 496780369302017539, 1073741441}; + std::vector exp_out{802487803, 754009873, 962097738, 36142730, + 687617508, 519876583, 630345322, 0}; + std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; + + uint64_t modulus = 1073741441; + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + + EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus, + input_mod_factor, output_mod_factor); + CheckEqual(result, exp_out); +} +#endif + TEST(EltwiseReduceMod, avx512_2_1) { if (!has_avx512dq) { GTEST_SKIP(); @@ -103,28 +127,6 @@ TEST(EltwiseReduceMod, avx512_4_2) { CheckEqual(result, exp_out); } -TEST(EltwiseReduceMod, avx512Big_mod_1) { - if (!has_avx512dq) { - GTEST_SKIP(); - } - - std::vector op{914704788761805005, 224925333812073588, - 592788284123677125, 142439467624940029, - 146023272535470246, 979015887843024185, - 496780369302017539, 1073741441}; - std::vector exp_out{802487803, 754009873, 962097738, 36142730, - 687617508, 519876583, 630345322, 0}; - std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; - - uint64_t modulus = 1073741441; - const uint64_t input_mod_factor = modulus; - const uint64_t output_mod_factor = 1; - - EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus, - input_mod_factor, output_mod_factor); - CheckEqual(result, exp_out); -} - // Checks AVX512 and native EltwiseReduceMod implementations match with randomly // generated inputs TEST(EltwiseReduceMod, AVX512Big_0_1) { From 4ec80daaebc5bb60eabe623e2f06bbc933f3537d Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 12:46:19 -0700 Subject: [PATCH 14/18] fix debug benchmark --- benchmark/bench-eltwise-reduce-mod.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/benchmark/bench-eltwise-reduce-mod.cpp b/benchmark/bench-eltwise-reduce-mod.cpp index 48dfb3aa..6846fcf2 100644 --- a/benchmark/bench-eltwise-reduce-mod.cpp +++ b/benchmark/bench-eltwise-reduce-mod.cpp @@ -93,7 +93,7 @@ BENCHMARK(BM_EltwiseReduceModNative) // state[0] is the degree static void BM_EltwiseReduceModAVX512(benchmark::State& state) { // NOLINT size_t input_size = state.range(0); - size_t modulus = 1152921504606877697; + size_t modulus = 0xffffffffffc0001ULL; auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); @@ -102,8 +102,8 @@ static void BM_EltwiseReduceModAVX512(benchmark::State& state) { // NOLINT AlignedVector64 output(input_size, 0); for (auto _ : state) { - EltwiseReduceModAVX512(output.data(), input1.data(), input_size, modulus, - input_mod_factor, output_mod_factor); + EltwiseReduceModAVX512<64>(output.data(), input1.data(), input_size, + modulus, input_mod_factor, output_mod_factor); } } @@ -121,7 +121,7 @@ BENCHMARK(BM_EltwiseReduceModAVX512) static void BM_EltwiseReduceModAVX512BitShift64( benchmark::State& state) { // NOLINT size_t input_size = state.range(0); - size_t modulus = 1152921504606877697; + size_t modulus = 0xffffffffffc0001ULL; auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); @@ -142,14 +142,14 @@ BENCHMARK(BM_EltwiseReduceModAVX512BitShift64) ->Args({16384}); #endif -//================================================================= +////================================================================= #ifdef HEXL_HAS_AVX512IFMA // state[0] is the degree static void BM_EltwiseReduceModAVX512BitShift52( benchmark::State& state) { // NOLINT size_t input_size = state.range(0); - size_t modulus = 1152921504606877697; + size_t modulus = 0xffffffffffc0001ULL; auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 100 * modulus); @@ -170,14 +170,14 @@ BENCHMARK(BM_EltwiseReduceModAVX512BitShift52) ->Args({16384}); #endif -//================================================================= +////================================================================= #ifdef HEXL_HAS_AVX512IFMA // state[0] is the degree static void BM_EltwiseReduceModAVX512BitShift52GT( benchmark::State& state) { // NOLINT size_t input_size = state.range(0); - size_t modulus = 1152921504606877697; + size_t modulus = 0xffffffffffc0001ULL; auto input1 = GenerateInsecureUniformRandomValues( input_size, 4503599627370496, 100 * modulus); @@ -200,7 +200,7 @@ BENCHMARK(BM_EltwiseReduceModAVX512BitShift52GT) static void BM_EltwiseReduceModAVX512BitShift52LT( benchmark::State& state) { // NOLINT size_t input_size = state.range(0); - size_t modulus = 1073741441; + size_t modulus = 0xffffffffffc0001ULL; auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, 2251799813685248); From 58b15c60503ea3cfa48141393caa4d6931150955 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 13:13:43 -0700 Subject: [PATCH 15/18] avoid unused error --- hexl/util/avx512-util.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index c992d1ad..43718429 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -377,6 +377,9 @@ inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, __m512i q_barr_52, uint64_t prod_right_shift, __m512i v_neg_mod) { + HEXL_UNUSED(q_barr_52); + HEXL_UNUSED(prod_right_shift); + HEXL_UNUSED(v_neg_mod); HEXL_CHECK(BitShift == 52 || BitShift == 64, "Invalid bitshift " << BitShift << "; need 52 or 64"); From 5f3473bc0d763386c8559108064c5ceeea9a84b5 Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 13:55:19 -0700 Subject: [PATCH 16/18] update modulo --- hexl/experimental/seal/ckks-switch-key.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hexl/experimental/seal/ckks-switch-key.cpp b/hexl/experimental/seal/ckks-switch-key.cpp index f8761650..3d961325 100644 --- a/hexl/experimental/seal/ckks-switch-key.cpp +++ b/hexl/experimental/seal/ckks-switch-key.cpp @@ -71,9 +71,9 @@ void CkksSwitchKey(uint64_t* result, const uint64_t* t_target_iter_ptr, } } else { // Perform RNS conversion (modular reduction) - intel::hexl::EltwiseReduceMod(t_ntt_ptr, - &t_target_ptr[j * coeff_count], - coeff_count, moduli[key_index], 0, 1); + intel::hexl::EltwiseReduceMod( + t_ntt_ptr, &t_target_ptr[j * coeff_count], coeff_count, + moduli[key_index], moduli[key_index], 1); } // NTT conversion lazy outputs in [0, 4q) @@ -149,7 +149,7 @@ void CkksSwitchKey(uint64_t* result, const uint64_t* t_target_iter_ptr, // TODO(fboemer): Use input_mod_factor != 0 when qk / qi < 4 // TODO(fboemer): Use output_mod_factor == 4? - uint64_t input_mod_factor = (qk > qi) ? 0 : 2; + uint64_t input_mod_factor = (qk > qi) ? moduli[i] : 2; if (qk > qi) { intel::hexl::EltwiseReduceMod(t_ntt_ptr, t_last, coeff_count, moduli[i], input_mod_factor, 1); From df4b0b279b5fc635cb92c271edba6314a34390ef Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 15:34:46 -0700 Subject: [PATCH 17/18] update test --- test/test-avx512-util.cpp | 73 ++++++++++++++------------------------- 1 file changed, 25 insertions(+), 48 deletions(-) diff --git a/test/test-avx512-util.cpp b/test/test-avx512-util.cpp index d5d96811..ede401a4 100644 --- a/test/test-avx512-util.cpp +++ b/test/test-avx512-util.cpp @@ -294,54 +294,31 @@ TEST(AVX512, _mm512_hexl_barrett_reduce64) { } // Small - // { - // __m512i a = _mm512_set_epi64(12, 11, 10, 8, 6, 4, 2, 0); - - // std::vector moduli{2, 2, 3, 4, 5, 6, 7, 8}; - // std::vector barrs(moduli.size()); - // for (size_t i = 0; i < barrs.size(); ++i) { - // barrs[i] = MultiplyFactor(1, 64, moduli[i]).BarrettFactor(); - // } - - // // Multi-word Barrett reduction precomputation - // std::vector ceil_log_mod(moduli.size()); - // constexpr int64_t beta = -2; - // for (size_t i = 0; i < ceil_log_mod.size(); ++i) { - // ceil_log_mod[i] = Log2(moduli[i]) + 1; - // } - - // std::vector prod_right_shift(moduli.size()); - // for (size_t i = 0; i < prod_right_shift.size(); ++i) { - // prod_right_shift[i] = ceil_log_mod[i] + beta; - // } - - // //const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from - // Algorithm 2 - // // uint64_t prod_right_shift = ceil_log_mod + beta; - // __m512i v_neg_mod = - // _mm512_set_epi64(-static_cast(moduli[7]), - // -static_cast(moduli[6]), - // -static_cast(moduli[5]), - // -static_cast(moduli[4]), - // -static_cast(moduli[3]), - // -static_cast(moduli[2]), - // -static_cast(moduli[1]), - // -static_cast(moduli[0])); - - // __m512i vmoduli = - // _mm512_set_epi64(moduli[7], moduli[6], moduli[5], moduli[4], - // moduli[3], - // moduli[2], moduli[1], moduli[0]); - // __m512i vbarrs = _mm512_set_epi64(barrs[7], barrs[6], barrs[5], - // barrs[4], - // barrs[3], barrs[2], barrs[1], - // barrs[0]); - - // __m512i expected_out = _mm512_set_epi64(4, 4, 4, 3, 2, 1, 0, 0); - - // __m512i c = _mm512_hexl_barrett_reduce64(a, vmoduli, vbarrs, - // prod_right_shift, v_neg_mod); AssertEqual(c, expected_out); - // } + { + __m512i a = _mm512_set_epi64(12, 11, 10, 8, 6, 4, 2, 0); + + std::vector moduli{5, 5, 5, 5, 5, 5, 5, 5}; + std::vector barrs(moduli.size()); + for (size_t i = 0; i < barrs.size(); ++i) { + barrs[i] = MultiplyFactor(1, 64, moduli[i]).BarrettFactor(); + } + + // Multi-word Barrett reduction precomputation + constexpr int64_t beta = -2; + uint64_t ceil_log_mod = Log2(moduli[0]) + 1; + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(moduli[0])); + + __m512i vmoduli = _mm512_set1_epi64(moduli[0]); + __m512i vbarrs = _mm512_set_epi64(barrs[7], barrs[6], barrs[5], barrs[4], + barrs[3], barrs[2], barrs[1], barrs[0]); + + __m512i expected_out = _mm512_set_epi64(2, 1, 0, 3, 1, 4, 2, 0); + + __m512i c = _mm512_hexl_barrett_reduce64(a, vmoduli, vbarrs, vbarrs, + prod_right_shift, v_neg_mod); + AssertEqual(c, expected_out); + } // Random { From 0bbd420a8fe585fc8e4b89f5f20a7c7db537f61c Mon Sep 17 00:00:00 2001 From: GelilaSeifu Date: Fri, 8 Oct 2021 15:53:58 -0700 Subject: [PATCH 18/18] simplify test --- test/test-avx512-util.cpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/test/test-avx512-util.cpp b/test/test-avx512-util.cpp index ede401a4..5247864c 100644 --- a/test/test-avx512-util.cpp +++ b/test/test-avx512-util.cpp @@ -297,21 +297,16 @@ TEST(AVX512, _mm512_hexl_barrett_reduce64) { { __m512i a = _mm512_set_epi64(12, 11, 10, 8, 6, 4, 2, 0); - std::vector moduli{5, 5, 5, 5, 5, 5, 5, 5}; - std::vector barrs(moduli.size()); - for (size_t i = 0; i < barrs.size(); ++i) { - barrs[i] = MultiplyFactor(1, 64, moduli[i]).BarrettFactor(); - } + uint64_t modulus = 5; + uint64_t barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); + __m512i vmoduli = _mm512_set1_epi64(modulus); + __m512i vbarrs = _mm512_set1_epi64(barrett_factor); // Multi-word Barrett reduction precomputation constexpr int64_t beta = -2; - uint64_t ceil_log_mod = Log2(moduli[0]) + 1; + uint64_t ceil_log_mod = Log2(modulus) + 1; uint64_t prod_right_shift = ceil_log_mod + beta; - __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(moduli[0])); - - __m512i vmoduli = _mm512_set1_epi64(moduli[0]); - __m512i vbarrs = _mm512_set_epi64(barrs[7], barrs[6], barrs[5], barrs[4], - barrs[3], barrs[2], barrs[1], barrs[0]); + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); __m512i expected_out = _mm512_set_epi64(2, 1, 0, 3, 1, 4, 2, 0);