Skip to content

Commit

Permalink
fix dispatch test
Browse files Browse the repository at this point in the history
  • Loading branch information
pittma committed Sep 6, 2024
1 parent e626c2c commit 92b9e3f
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 35 deletions.
37 changes: 26 additions & 11 deletions crypto/fipsmodule/bn/rsaz_exp_x2.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1,
const uint64_t *b, const uint64_t *m, uint64_t k0);
int ret = 0;

// Number of word-size (uint64_t) digits to store in redundant
// representation.
// Number of word-size (uint64_t) digits to store values in
// redundant representation.
int red_digits = number_of_digits(modlen + 2, DIGIT_SIZE);

// n = modlen, d = DIGIT_SIZE, s = d * ceil((n+2)/d) > n
Expand Down Expand Up @@ -124,7 +124,7 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1,
uint64_t *storage = NULL;
uint64_t *storage_aligned = NULL;
int storage_len_bytes = 7 * regs_capacity * sizeof(uint64_t)
+ 64;
+ 64; // alignment

const uint64_t *exp[2] = {0};
uint64_t k0[2] = {0};
Expand Down Expand Up @@ -177,17 +177,19 @@ int RSAZ_mod_exp_avx512_x2(uint64_t *res1,
// - We have AMM(t, 2^k) = R^4 * 2^{4*(s-n)} / R'^2 mod m -- (2)
// = R'^4 / R'^2 mod m
// = R'^2 mod m
// For example, for n = 1024, s = 1040, k = 64,
// RR = 2^2048 mod m, RR' = 2^2080 mod m

OPENSSL_memset(coeff_red, 0, red_digits * sizeof(uint64_t));
// coeff_red = 2^k = 1 << bitlen_diff taking into account the
// redundant representation in digits of DIGIT_SIZE bits
set_bit(coeff_red, 64 * (int)(bitlen_diff / DIGIT_SIZE) + bitlen_diff % DIGIT_SIZE);

amm(rr1_red, rr1_red, rr1_red, m1_red, k0_1);
amm(rr1_red, rr1_red, coeff_red, m1_red, k0_1);
amm(rr1_red, rr1_red, rr1_red, m1_red, k0_1); // (1) for m1
amm(rr1_red, rr1_red, coeff_red, m1_red, k0_1); // (2) for m1

amm(rr2_red, rr2_red, rr2_red, m2_red, k0_2);
amm(rr2_red, rr2_red, coeff_red, m2_red, k0_2);
amm(rr2_red, rr2_red, rr2_red, m2_red, k0_2); // (1) for m2
amm(rr2_red, rr2_red, coeff_red, m2_red, k0_2); // (2) for m2

exp[0] = exp1;
exp[1] = exp2;
Expand Down Expand Up @@ -316,6 +318,11 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
red_table = red_X + 2 * red_digits;
expz = red_table + 2 * red_digits * two_to_exp_win_size;

// Compute table of powers base^i mod m,
// i = 0, ..., (2^EXP_WIN_SIZE) - 1
// using the dual multiplication. Each table entry contains
// base1^i mod m1, then base2^i mod m2.

red_X[0 * red_digits] = 1;
red_X[1 * red_digits] = 1;
damm(&red_table[0 * 2 * red_digits], (const uint64_t*)red_X, rr, m, k0);
Expand Down Expand Up @@ -367,9 +374,17 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
// `rem` is { 1024, 1536, 2048 } % 5 which is { 4, 1, 3 }
// respectively.
//
// If this assertion ever fails the fix above is easy.
// If this assertion ever fails then we should set this easy
// fix exp_bit_no = modlen - exp_win_size
assert(rem == 4 || rem == 1 || rem == 3);


// Find the location of the 5-bit window in the exponent which
// is stored in 64-bit digits. Left pad it with 0s to form a
// 64-bit digit to become an index in the precomputed table.
// The window location in the exponent is identified by its
// least significant bit `exp_bit_no`.

#define EXP_CHUNK(i) (exp_chunk_no) + ((i) * (exp_digits + 1))
#define EXP_CHUNK1(i) (exp_chunk_no) + 1 + ((i) * (exp_digits + 1))

Expand All @@ -395,7 +410,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
exp_chunk_no = exp_bit_no / 64;
exp_chunk_shift = exp_bit_no % 64;
{
red_table_idx_1 = expz[exp_chunk_no + 0 * (exp_digits + 1)];
red_table_idx_1 = expz[EXP_CHUNK(0)];
T = expz[EXP_CHUNK1(0)];

red_table_idx_1 >>= exp_chunk_shift;
Expand All @@ -408,7 +423,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
red_table_idx_1 &= table_idx_mask;
}
{
red_table_idx_2 = expz[exp_chunk_no + 1 * (exp_digits + 1)];
red_table_idx_2 = expz[EXP_CHUNK(1)];
T = expz[EXP_CHUNK1(1)];

red_table_idx_2 >>= exp_chunk_shift;
Expand All @@ -425,7 +440,7 @@ int rsaz_mod_exp_x2_ifma256(uint64_t *out,
(int)red_table_idx_1, (int)red_table_idx_2);
}

// Series of squaring
// The number of squarings is equal to the window size.
DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0);
DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0);
DAMS((uint64_t*)red_Y, (const uint64_t*)red_Y, m, k0);
Expand Down
101 changes: 77 additions & 24 deletions crypto/impl_dispatch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,36 +247,89 @@ TEST_F(ImplDispatchTest, SHA512) {
}
#endif // OPENSSL_AARCH64


#if defined(OPENSSL_X86_64) && !defined(MY_ASSEMBLER_IS_TOO_OLD_512AVX) && \
defined(RSAZ_512_ENABLED)

#include "test/file_test.h"

static bssl::UniquePtr<BIGNUM> GetBIGNUM(FileTest *t, const char *attr);

static bssl::UniquePtr<BIGNUM> GetBIGNUM(FileTest *t, const char *attr) {
std::string hex;
if (!t->GetAttribute(&hex, attr)) {
return nullptr;
}

BIGNUM *raw = NULL;
int size = BN_hex2bn(&raw, hex.c_str());
if (size != static_cast<int>(hex.size())) {
t->PrintLine("Could not decode '%s'.", hex.c_str());
return nullptr;
}

bssl::UniquePtr<BIGNUM> ret;
(&ret)->reset(raw);
return ret;
}

TEST_F(ImplDispatchTest, BN_mod_exp_mont_consttime_x2) {
AssertFunctionsHit(
FileTestGTest(
"crypto/fipsmodule/bn/test/mod_exp_x2_tests.txt",
[&](FileTest *t) {
AssertFunctionsHit(
{
{kFlag_RSAZ_mod_exp_avx512_x2,
is_x86_64_ &&
!is_assembler_too_old_avx512 &&
ifma_avx512},
{kFlag_RSAZ_mod_exp_avx512_x2,
is_x86_64_ &&
!is_assembler_too_old_avx512 &&
ifma_avx512},
},
[] {
uint64_t res1 = 0;
uint64_t base1 = 0;
uint64_t exp1 = 0;
uint64_t m1 = 0;
uint64_t rr1 = 0;
uint64_t k0_1 = 0;
uint64_t res2 = 0;
uint64_t base2 = 0;
uint64_t exp2 = 0;
uint64_t m2 = 0;
uint64_t rr2 = 0;
uint64_t k0_2 = 0;
int modlen = 0;

RSAZ_mod_exp_avx512_x2(&res1, &base1, &exp1, &m1, &rr1, k0_1,
&res2, &base2, &exp2, &m2, &rr2, k0_2,
modlen);
[&]() {
BN_CTX *ctx = BN_CTX_new();
BN_CTX_start(ctx);
bssl::UniquePtr<BIGNUM> a1 = GetBIGNUM(t, "A1");
bssl::UniquePtr<BIGNUM> e1 = GetBIGNUM(t, "E1");
bssl::UniquePtr<BIGNUM> m1 = GetBIGNUM(t, "M1");
bssl::UniquePtr<BIGNUM> mod_exp1 = GetBIGNUM(t, "ModExp1");
ASSERT_TRUE(a1);
ASSERT_TRUE(e1);
ASSERT_TRUE(m1);
ASSERT_TRUE(mod_exp1);

bssl::UniquePtr<BIGNUM> a2 = GetBIGNUM(t, "A2");
bssl::UniquePtr<BIGNUM> e2 = GetBIGNUM(t, "E2");
bssl::UniquePtr<BIGNUM> m2 = GetBIGNUM(t, "M2");
bssl::UniquePtr<BIGNUM> mod_exp2 = GetBIGNUM(t, "ModExp2");
ASSERT_TRUE(a2);
ASSERT_TRUE(e2);
ASSERT_TRUE(m2);
ASSERT_TRUE(mod_exp2);

bssl::UniquePtr<BIGNUM> ret1(BN_new());
ASSERT_TRUE(ret1);

bssl::UniquePtr<BIGNUM> ret2(BN_new());
ASSERT_TRUE(ret2);

ASSERT_TRUE(BN_nnmod(a1.get(), a1.get(), m1.get(), ctx));
ASSERT_TRUE(BN_nnmod(a2.get(), a2.get(), m2.get(), ctx));

BN_MONT_CTX *mont1 = NULL;
BN_MONT_CTX *mont2 = NULL;

ASSERT_TRUE(mont1 = BN_MONT_CTX_new());
ASSERT_TRUE(BN_MONT_CTX_set(mont1, m1.get(), ctx));
ASSERT_TRUE(mont2 = BN_MONT_CTX_new());
ASSERT_TRUE(BN_MONT_CTX_set(mont2, m2.get(), ctx));

BN_mod_exp_mont_consttime_x2(ret1.get(), a1.get(), e1.get(), m1.get(), mont1,
ret2.get(), a2.get(), e2.get(), m2.get(), mont2,
ctx);

BN_CTX_end(ctx);
BN_MONT_CTX_free(mont1);
BN_MONT_CTX_free(mont2);
});
});
}
#endif // OPENSSL_X86_64 && !MY_ASSEMBLER_IS_TOO_OLD_512AVX && RSAZ_512_ENABLED

Expand Down

0 comments on commit 92b9e3f

Please sign in to comment.