Skip to content

Commit

Permalink
Fboemer/fix 32 bit invntt (#73)
Browse files Browse the repository at this point in the history
* Fix 32-bit AVX512DQ InvNTT

* Refactor NTT tests for better coverage
  • Loading branch information
fboemer committed Nov 8, 2021
1 parent e2acbab commit 52117a9
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 320 deletions.
2 changes: 1 addition & 1 deletion benchmark/bench-ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ BENCHMARK(BM_InvNTT_AVX512IFMALazy)
static void BM_InvNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT
size_t ntt_size = state.range(0);
uint64_t output_mod_factor = state.range(1);
size_t modulus = GeneratePrimes(1, 30, true, ntt_size)[0];
size_t modulus = GeneratePrimes(1, 29, true, ntt_size)[0];

auto input = GenerateInsecureUniformRandomValues(ntt_size, 0, modulus);
NTT ntt(ntt_size, modulus);
Expand Down
32 changes: 30 additions & 2 deletions hexl/include/hexl/ntt/ntt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,43 @@ class NTT {

/// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the
/// inverse transform
static const size_t s_max_inv_32_modulus{1ULL << (32 - 1)};
static const size_t s_max_inv_32_modulus{1ULL << (32 - 2)};

/// @brief Maximum modulus to use AVX512-IFMA acceleration for the forward
/// transform
static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)};

/// @brief Maximum modulus to use AVX512-IFMA acceleration for the inverse
/// transform
static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 1)};
static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)};

/// @brief Maximum modulus to use AVX512-DQ acceleration for the inverse
/// transform
static const size_t s_max_inv_dq_modulus{1ULL << (s_default_shift_bits - 2)};

/// @brief Looks up the modulus bound for the forward transform that
/// corresponds to a given bit shift
/// @param[in] bit_shift Bit shift of the implementation; 32, 52 and 64 are
/// the recognized values
/// @return s_max_fwd_32_modulus for 32, s_max_fwd_ifma_modulus for 52, and
/// 2^MaxModulusBits() for 64. Any other value is reported through HEXL_CHECK
/// and, if that check does not abort, 0 is returned
static size_t s_max_fwd_modulus(int bit_shift) {
  switch (bit_shift) {
    case 32:
      return s_max_fwd_32_modulus;
    case 52:
      return s_max_fwd_ifma_modulus;
    case 64:
      return 1ULL << MaxModulusBits();
    default:
      break;
  }
  // Unrecognized bit shift: flag it, then fall back to 0
  HEXL_CHECK(false, "Invalid bit_shift " << bit_shift);
  return 0;
}

/// @brief Looks up the modulus bound for the inverse transform that
/// corresponds to a given bit shift
/// @param[in] bit_shift Bit shift of the implementation; 32, 52 and 64 are
/// the recognized values
/// @return s_max_inv_32_modulus for 32, s_max_inv_ifma_modulus for 52, and
/// 2^MaxModulusBits() for 64. Any other value is reported through HEXL_CHECK
/// and, if that check does not abort, 0 is returned
static size_t s_max_inv_modulus(int bit_shift) {
  switch (bit_shift) {
    case 32:
      return s_max_inv_32_modulus;
    case 52:
      return s_max_inv_ifma_modulus;
    case 64:
      return 1ULL << MaxModulusBits();
    default:
      break;
  }
  // Unrecognized bit shift: flag it, then fall back to 0
  HEXL_CHECK(false, "Invalid bit_shift " << bit_shift);
  return 0;
}

private:
void ComputeRootOfUnityPowers();
Expand Down
4 changes: 2 additions & 2 deletions hexl/include/hexl/util/aligned-allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class AlignedAllocator {
explicit AlignedAllocator(AllocatorStrategyPtr strategy = nullptr) noexcept
: m_alloc_impl((strategy != nullptr) ? strategy : mallocStrategy) {}

AlignedAllocator(const AlignedAllocator& src)
: m_alloc_impl(src.m_alloc_impl) {}
AlignedAllocator(const AlignedAllocator& src) = default;
AlignedAllocator& operator=(const AlignedAllocator& src) = default;

template <typename U>
AlignedAllocator(const AlignedAllocator<U, Alignment>& src)
Expand Down
5 changes: 3 additions & 2 deletions hexl/ntt/fwd-ntt-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ void ForwardTransformToBitReverseAVX512(
uint64_t output_mod_factor, uint64_t recursion_depth,
uint64_t recursion_half) {
HEXL_CHECK(NTT::CheckArguments(n, modulus), "");
HEXL_CHECK(modulus < MaximumValue(BitShift) / 4,
HEXL_CHECK(modulus < NTT::s_max_fwd_modulus(BitShift),
"modulus " << modulus << " too large for BitShift " << BitShift
<< " => maximum value " << MaximumValue(BitShift) / 4);
<< " => maximum value "
<< NTT::s_max_fwd_modulus(BitShift));
HEXL_CHECK_BOUNDS(precon_root_of_unity_powers, n, MaximumValue(BitShift),
"precon_root_of_unity_powers too large");
HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large");
Expand Down
5 changes: 3 additions & 2 deletions hexl/ntt/inv-ntt-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,10 @@ void InverseTransformFromBitReverseAVX512(
"InverseTransformFromBitReverseAVX512 doesn't support small "
"transforms. Need n >= 16, got n = "
<< n);
HEXL_CHECK(modulus < MaximumValue(BitShift) / 2,
HEXL_CHECK(modulus < NTT::s_max_inv_modulus(BitShift),
"modulus " << modulus << " too large for BitShift " << BitShift
<< " => maximum value " << MaximumValue(BitShift) / 2);
<< " => maximum value "
<< NTT::s_max_inv_modulus(BitShift));
HEXL_CHECK_BOUNDS(precon_inv_root_of_unity_powers, n, MaximumValue(BitShift),
"precon_inv_root_of_unity_powers too large");
HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large");
Expand Down
Loading

0 comments on commit 52117a9

Please sign in to comment.