diff --git a/crypto/fipsmodule/bn/rsaz_exp_x2.c b/crypto/fipsmodule/bn/rsaz_exp_x2.c index 4aee1ada4a..a967671f64 100644 --- a/crypto/fipsmodule/bn/rsaz_exp_x2.c +++ b/crypto/fipsmodule/bn/rsaz_exp_x2.c @@ -94,8 +94,8 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1, const uint64_t *b, const uint64_t *m, uint64_t k0); int ret = 0; - // Number of word-size (uint64_t) digits to store in redundant - // representation. + // Number of word-size (uint64_t) digits to store values in + // redundant representation. int red_digits = number_of_digits(modlen + 2, DIGIT_SIZE); // n = modlen, d = DIGIT_SIZE, s = d * ceil((n+2)/d) > n @@ -124,7 +124,7 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1, uint64_t *storage = NULL; uint64_t *storage_aligned = NULL; int storage_len_bytes = 7 * regs_capacity * sizeof(uint64_t) - + 64; + + 64; // alignment const uint64_t *exp[2] = {0}; uint64_t k0[2] = {0}; @@ -177,17 +177,19 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1, // - We have AMM(t, 2^k) = R^4 * 2^{4*(s-n)} / R'^2 mod m -- (2) // = R'^4 / R'^2 mod m // = R'^2 mod m + // For example, for n = 1024, s = 1040, k = 64, + // RR = 2^2048 mod m, RR' = 2^2080 mod m OPENSSL_memset(coeff_red, 0, red_digits * sizeof(uint64_t)); // coeff_red = 2^k = 1 << bitlen_diff taking into account the // redundant representation in digits of DIGIT_SIZE bits set_bit(coeff_red, 64 * (int)(bitlen_diff / DIGIT_SIZE) + bitlen_diff % DIGIT_SIZE); - amm(rr1_red, rr1_red, rr1_red, m1_red, k0_1); - amm(rr1_red, rr1_red, coeff_red, m1_red, k0_1); + amm(rr1_red, rr1_red, rr1_red, m1_red, k0_1); // (1) for m1 + amm(rr1_red, rr1_red, coeff_red, m1_red, k0_1); // (2) for m1 - amm(rr2_red, rr2_red, rr2_red, m2_red, k0_2); - amm(rr2_red, rr2_red, coeff_red, m2_red, k0_2); + amm(rr2_red, rr2_red, rr2_red, m2_red, k0_2); // (1) for m2 + amm(rr2_red, rr2_red, coeff_red, m2_red, k0_2); // (2) for m2 exp[0] = exp1; exp[1] = exp2; @@ -316,6 +318,11 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out, red_table = red_X + 2 * red_digits; expz 
= red_table + 2 * red_digits * two_to_exp_win_size; + // Compute table of powers base^i mod m, + // i = 0, ..., (2^EXP_WIN_SIZE) - 1 + // using the dual multiplication. Each table entry contains + // base1^i mod m1, then base2^i mod m2. + red_X[0 * red_digits] = 1; red_X[1 * red_digits] = 1; damm(&red_table[0 * 2 * red_digits], (const uint64_t*)red_X, rr, m, k0); @@ -367,9 +374,17 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out, // `rem` is { 1024, 1536, 2048 } % 5 which is { 4, 1, 3 } // respectively. // - // If this assertion ever fails the fix above is easy. + // If this assertion ever fails, the easy fix is to set + // exp_bit_no = modlen - exp_win_size. assert(rem == 4 || rem == 1 || rem == 3); + + // Find the location of the 5-bit window in the exponent which + // is stored in 64-bit digits. Left pad it with 0s to form a + // 64-bit digit to become an index in the precomputed table. + // The window location in the exponent is identified by its + // least significant bit `exp_bit_no`. + #define EXP_CHUNK(i) (exp_chunk_no) + ((i) * (exp_digits + 1)) #define EXP_CHUNK1(i) (exp_chunk_no) + 1 + ((i) * (exp_digits + 1)) @@ -395,7 +410,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out, exp_chunk_no = exp_bit_no / 64; exp_chunk_shift = exp_bit_no % 64; { - red_table_idx_1 = expz[exp_chunk_no + 0 * (exp_digits + 1)]; + red_table_idx_1 = expz[EXP_CHUNK(0)]; T = expz[EXP_CHUNK1(0)]; red_table_idx_1 >>= exp_chunk_shift; @@ -408,7 +423,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out, red_table_idx_1 &= table_idx_mask; } { - red_table_idx_2 = expz[exp_chunk_no + 1 * (exp_digits + 1)]; + red_table_idx_2 = expz[EXP_CHUNK(1)]; T = expz[EXP_CHUNK1(1)]; red_table_idx_2 >>= exp_chunk_shift; @@ -425,7 +440,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out, (int)red_table_idx_1, (int)red_table_idx_2); } - // Series of squaring + // The number of squarings is equal to the window size. 
DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0); DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0); DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0); diff --git a/crypto/impl_dispatch_test.cc b/crypto/impl_dispatch_test.cc index 814c94fde1..6adf990e14 100644 --- a/crypto/impl_dispatch_test.cc +++ b/crypto/impl_dispatch_test.cc @@ -247,36 +247,93 @@ TEST_F(ImplDispatchTest, SHA512) { } #endif // OPENSSL_AARCH64 - #if defined(OPENSSL_X86_64) && !defined(MY_ASSEMBLER_IS_TOO_OLD_512AVX) && \ defined(RSAZ_512_ENABLED) + +#include "test/file_test.h" + +// GetBIGNUM parses attribute |attr| of |t| as a hex-encoded BIGNUM. It returns +// nullptr if the attribute is absent or BN_hex2bn's result differs from its length. +static bssl::UniquePtr<BIGNUM> GetBIGNUM(FileTest *t, const char *attr) { + std::string hex; + if (!t->GetAttribute(&hex, attr)) { + return nullptr; + } + + BIGNUM *raw = nullptr; + int size = BN_hex2bn(&raw, hex.c_str()); + // Take ownership first so |raw| is released on the error path as well. + bssl::UniquePtr<BIGNUM> ret(raw); + if (size != static_cast<int>(hex.size())) { + t->PrintLine("Could not decode '%s'.", hex.c_str()); + return nullptr; + } + + return ret; +} + TEST_F(ImplDispatchTest, BN_mod_exp_mont_consttime_x2) { - AssertFunctionsHit( + FileTestGTest( + "crypto/fipsmodule/bn/test/mod_exp_x2_tests.txt", + [&](FileTest *t) { + AssertFunctionsHit( { - {kFlag_RSAZ_mod_exp_avx512_x2, - is_x86_64_ && - !is_assembler_too_old_avx512 && - ifma_avx512}, + {kFlag_RSAZ_mod_exp_avx512_x2, + is_x86_64_ && + !is_assembler_too_old_avx512 && + ifma_avx512}, }, - [] { - uint64_t res1 = 0; - uint64_t base1 = 0; - uint64_t exp1 = 0; - uint64_t m1 = 0; - uint64_t rr1 = 0; - uint64_t k0_1 = 0; - uint64_t res2 = 0; - uint64_t base2 = 0; - uint64_t exp2 = 0; - uint64_t m2 = 0; - uint64_t rr2 = 0; - uint64_t k0_2 = 0; - int modlen = 0; - - RSAZ_mod_exp_avx512_x2(&res1, &base1, &exp1, &m1, &rr1, k0_1, - &res2, &base2, &exp2, &m2, &rr2, k0_2, - modlen); + [&]() { + // Owning wrappers: a failed ASSERT returns early and must not leak. + bssl::UniquePtr<BN_CTX> ctx(BN_CTX_new()); + ASSERT_TRUE(ctx); + bssl::BN_CTXScope scope(ctx.get()); + bssl::UniquePtr<BIGNUM> a1 = GetBIGNUM(t, "A1"); + bssl::UniquePtr<BIGNUM> e1 = GetBIGNUM(t, "E1"); + bssl::UniquePtr<BIGNUM> m1 = GetBIGNUM(t, 
"M1"); + bssl::UniquePtr<BIGNUM> mod_exp1 = GetBIGNUM(t, "ModExp1"); + ASSERT_TRUE(a1); + ASSERT_TRUE(e1); + ASSERT_TRUE(m1); + ASSERT_TRUE(mod_exp1); + + bssl::UniquePtr<BIGNUM> a2 = GetBIGNUM(t, "A2"); + bssl::UniquePtr<BIGNUM> e2 = GetBIGNUM(t, "E2"); + bssl::UniquePtr<BIGNUM> m2 = GetBIGNUM(t, "M2"); + bssl::UniquePtr<BIGNUM> mod_exp2 = GetBIGNUM(t, "ModExp2"); + ASSERT_TRUE(a2); + ASSERT_TRUE(e2); + ASSERT_TRUE(m2); + ASSERT_TRUE(mod_exp2); + + bssl::UniquePtr<BIGNUM> ret1(BN_new()); + ASSERT_TRUE(ret1); + + bssl::UniquePtr<BIGNUM> ret2(BN_new()); + ASSERT_TRUE(ret2); + + ASSERT_TRUE(BN_nnmod(a1.get(), a1.get(), m1.get(), ctx.get())); + ASSERT_TRUE(BN_nnmod(a2.get(), a2.get(), m2.get(), ctx.get())); + + bssl::UniquePtr<BN_MONT_CTX> mont1(BN_MONT_CTX_new()); + ASSERT_TRUE(mont1); + ASSERT_TRUE(BN_MONT_CTX_set(mont1.get(), m1.get(), ctx.get())); + bssl::UniquePtr<BN_MONT_CTX> mont2(BN_MONT_CTX_new()); + ASSERT_TRUE(mont2); + ASSERT_TRUE(BN_MONT_CTX_set(mont2.get(), m2.get(), ctx.get())); + + // NOTE(review): assumes a nonzero return means success -- confirm. + ASSERT_TRUE(BN_mod_exp_mont_consttime_x2( + ret1.get(), a1.get(), e1.get(), m1.get(), mont1.get(), + ret2.get(), a2.get(), e2.get(), m2.get(), mont2.get(), + ctx.get())); + + // The dispatch flag says nothing about the output; compare against ModExp1/2, + // presumably the expected results (TODO confirm against the test file). + EXPECT_EQ(0, BN_cmp(ret1.get(), mod_exp1.get())); + EXPECT_EQ(0, BN_cmp(ret2.get(), mod_exp2.get())); }); + }); } #endif // OPENSSL_X86_64 && !MY_ASSEMBLER_IS_TOO_OLD_512AVX && RSAZ_512_ENABLED