Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cryptomb: reduce memory copy in ECDSA #33201

Merged
merged 5 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,65 @@ ssl_private_key_result_t ecdsaPrivateKeyDecrypt(SSL*, uint8_t*, size_t*, size_t,
return ssl_private_key_failure;
}

ssl_private_key_result_t ecdsaPrivateKeyCompleteInternal(CryptoMbPrivateKeyConnection* ops,
uint8_t* out, size_t* out_len,
size_t max_out) {
if (ops == nullptr) {
return ssl_private_key_failure;
}

// Check if the MB operation is ready yet. This can happen if someone calls
// the top-level SSL function too early. The op status is only set from this
// thread.
if (ops->mb_ctx_->getStatus() == RequestStatus::Retry) {
return ssl_private_key_retry;
}

// If this point is reached, the MB processing must be complete.

// See if the operation failed.
if (ops->mb_ctx_->getStatus() != RequestStatus::Success) {
ops->logWarnMsg("private key operation failed.");
return ssl_private_key_failure;
}

CryptoMbEcdsaContextSharedPtr mb_ctx =
std::static_pointer_cast<CryptoMbEcdsaContext>(ops->mb_ctx_);
if (mb_ctx->sig_len_ > max_out) {
return ssl_private_key_failure;
}

ECDSA_SIG* sig = ECDSA_SIG_new();
if (sig == nullptr) {
return ssl_private_key_failure;
}
BIGNUM* sig_r = BN_bin2bn(mb_ctx->sig_r_, 32, nullptr);
BIGNUM* sig_s = BN_bin2bn(mb_ctx->sig_s_, 32, nullptr);
ECDSA_SIG_set0(sig, sig_r, sig_s);

// Marshal signature into out.
CBB cbb;
if (!CBB_init_fixed(&cbb, out, mb_ctx->sig_len_) || !ECDSA_SIG_marshal(&cbb, sig) ||
!CBB_finish(&cbb, nullptr, out_len)) {
CBB_cleanup(&cbb);
ECDSA_SIG_free(sig);
return ssl_private_key_failure;
}

ECDSA_SIG_free(sig);

return ssl_private_key_success;
}

ssl_private_key_result_t ecdsaPrivateKeyComplete(SSL* ssl, uint8_t* out, size_t* out_len,
size_t max_out) {
return ssl == nullptr ? ssl_private_key_failure
: ecdsaPrivateKeyCompleteInternal(
static_cast<CryptoMbPrivateKeyConnection*>(SSL_get_ex_data(
ssl, CryptoMbPrivateKeyMethodProvider::connectionIndex())),
out, out_len, max_out);
}

ssl_private_key_result_t rsaPrivateKeySignInternal(CryptoMbPrivateKeyConnection* ops, uint8_t*,
size_t*, size_t, uint16_t signature_algorithm,
const uint8_t* in, size_t in_len) {
Expand Down Expand Up @@ -337,8 +396,9 @@ ssl_private_key_result_t rsaPrivateKeyDecrypt(SSL* ssl, uint8_t* out, size_t* ou
out, out_len, max_out, in, in_len);
}

ssl_private_key_result_t privateKeyCompleteInternal(CryptoMbPrivateKeyConnection* ops, uint8_t* out,
size_t* out_len, size_t max_out) {
ssl_private_key_result_t rsaPrivateKeyCompleteInternal(CryptoMbPrivateKeyConnection* ops,
uint8_t* out, size_t* out_len,
size_t max_out) {
if (ops == nullptr) {
return ssl_private_key_failure;
}
Expand All @@ -358,21 +418,22 @@ ssl_private_key_result_t privateKeyCompleteInternal(CryptoMbPrivateKeyConnection
return ssl_private_key_failure;
}

*out_len = ops->mb_ctx_->out_len_;
CryptoMbRsaContextSharedPtr mb_ctx = std::static_pointer_cast<CryptoMbRsaContext>(ops->mb_ctx_);
*out_len = mb_ctx->out_len_;

if (*out_len > max_out) {
return ssl_private_key_failure;
}

memcpy(out, ops->mb_ctx_->out_buf_, *out_len); // NOLINT(safe-memcpy)
memcpy(out, mb_ctx->out_buf_, *out_len); // NOLINT(safe-memcpy)

return ssl_private_key_success;
}

ssl_private_key_result_t privateKeyComplete(SSL* ssl, uint8_t* out, size_t* out_len,
size_t max_out) {
ssl_private_key_result_t rsaPrivateKeyComplete(SSL* ssl, uint8_t* out, size_t* out_len,
size_t max_out) {
return ssl == nullptr ? ssl_private_key_failure
: privateKeyCompleteInternal(
: rsaPrivateKeyCompleteInternal(
static_cast<CryptoMbPrivateKeyConnection*>(SSL_get_ex_data(
ssl, CryptoMbPrivateKeyMethodProvider::connectionIndex())),
out, out_len, max_out);
Expand All @@ -381,9 +442,15 @@ ssl_private_key_result_t privateKeyComplete(SSL* ssl, uint8_t* out, size_t* out_
} // namespace

// External linking, meant for testing without SSL context.
ssl_private_key_result_t privateKeyCompleteForTest(CryptoMbPrivateKeyConnection* ops, uint8_t* out,
size_t* out_len, size_t max_out) {
return privateKeyCompleteInternal(ops, out, out_len, max_out);
ssl_private_key_result_t ecdsaPrivateKeyCompleteForTest(CryptoMbPrivateKeyConnection* ops,
uint8_t* out, size_t* out_len,
size_t max_out) {
return ecdsaPrivateKeyCompleteInternal(ops, out, out_len, max_out);
}
ssl_private_key_result_t rsaPrivateKeyCompleteForTest(CryptoMbPrivateKeyConnection* ops,
uint8_t* out, size_t* out_len,
size_t max_out) {
return rsaPrivateKeyCompleteInternal(ops, out, out_len, max_out);
}
ssl_private_key_result_t ecdsaPrivateKeySignForTest(CryptoMbPrivateKeyConnection* ops, uint8_t* out,
size_t* out_len, size_t max_out,
Expand Down Expand Up @@ -519,12 +586,8 @@ void CryptoMbQueue::processRsaRequests() {
}

void CryptoMbQueue::processEcdsaRequests() {
uint8_t sig_r[MULTIBUFF_BATCH][32];
uint8_t sig_s[MULTIBUFF_BATCH][32];
uint8_t* pa_sig_r[MULTIBUFF_BATCH] = {sig_r[0], sig_r[1], sig_r[2], sig_r[3],
sig_r[4], sig_r[5], sig_r[6], sig_r[7]};
uint8_t* pa_sig_s[MULTIBUFF_BATCH] = {sig_s[0], sig_s[1], sig_s[2], sig_s[3],
sig_s[4], sig_s[5], sig_s[6], sig_s[7]};
uint8_t* pa_sig_r[MULTIBUFF_BATCH] = {};
uint8_t* pa_sig_s[MULTIBUFF_BATCH] = {};
const unsigned char* digest[MULTIBUFF_BATCH] = {nullptr};
const BIGNUM* eph_key[MULTIBUFF_BATCH] = {nullptr};
const BIGNUM* priv_key[MULTIBUFF_BATCH] = {nullptr};
Expand All @@ -533,6 +596,8 @@ void CryptoMbQueue::processEcdsaRequests() {
for (unsigned req_num = 0; req_num < request_queue_.size(); req_num++) {
CryptoMbEcdsaContextSharedPtr mb_ctx =
std::static_pointer_cast<CryptoMbEcdsaContext>(request_queue_[req_num]);
pa_sig_r[req_num] = mb_ctx->sig_r_;
pa_sig_s[req_num] = mb_ctx->sig_s_;
digest[req_num] = mb_ctx->in_buf_.get();
eph_key[req_num] = mb_ctx->k_;
priv_key[req_num] = mb_ctx->priv_key_;
Expand All @@ -551,11 +616,7 @@ void CryptoMbQueue::processEcdsaRequests() {
enum RequestStatus ctx_status;
if (ipp_->mbxGetSts(ecdsa_sts, req_num)) {
ENVOY_LOG(debug, "Multibuffer ECDSA request {} success", req_num);
if (postprocessEcdsaRequest(mb_ctx, pa_sig_r[req_num], pa_sig_s[req_num])) {
status[req_num] = RequestStatus::Success;
} else {
status[req_num] = RequestStatus::Error;
}
status[req_num] = RequestStatus::Success;
} else {
ENVOY_LOG(debug, "Multibuffer ECDSA request {} failure", req_num);
status[req_num] = RequestStatus::Error;
Expand All @@ -569,29 +630,6 @@ void CryptoMbQueue::processEcdsaRequests() {
}
}

bool CryptoMbQueue::postprocessEcdsaRequest(CryptoMbEcdsaContextSharedPtr mb_ctx,
const uint8_t* pa_sig_r, const uint8_t* pa_sig_s) {
ECDSA_SIG* sig = ECDSA_SIG_new();
if (sig == nullptr) {
return false;
}
BIGNUM* sig_r = BN_bin2bn(pa_sig_r, 32, nullptr);
BIGNUM* sig_s = BN_bin2bn(pa_sig_s, 32, nullptr);
ECDSA_SIG_set0(sig, sig_r, sig_s);

// Marshal signature into out_buf_.
CBB cbb;
if (!CBB_init_fixed(&cbb, mb_ctx->out_buf_, mb_ctx->sig_len_) || !ECDSA_SIG_marshal(&cbb, sig) ||
!CBB_finish(&cbb, nullptr, &mb_ctx->out_len_)) {
CBB_cleanup(&cbb);
ECDSA_SIG_free(sig);
return false;
}

ECDSA_SIG_free(sig);
return true;
}

CryptoMbPrivateKeyConnection::CryptoMbPrivateKeyConnection(Ssl::PrivateKeyConnectionCallbacks& cb,
Event::Dispatcher& dispatcher,
bssl::UniquePtr<EVP_PKEY> pkey,
Expand Down Expand Up @@ -678,7 +716,7 @@ CryptoMbPrivateKeyMethodProvider::CryptoMbPrivateKeyMethodProvider(

method_->sign = rsaPrivateKeySign;
method_->decrypt = rsaPrivateKeyDecrypt;
method_->complete = privateKeyComplete;
method_->complete = rsaPrivateKeyComplete;

RSA* rsa = EVP_PKEY_get0_RSA(pkey.get());
switch (RSA_bits(rsa)) {
Expand Down Expand Up @@ -722,7 +760,7 @@ CryptoMbPrivateKeyMethodProvider::CryptoMbPrivateKeyMethodProvider(

method_->sign = ecdsaPrivateKeySign;
method_->decrypt = ecdsaPrivateKeyDecrypt;
method_->complete = privateKeyComplete;
method_->complete = ecdsaPrivateKeyComplete;

const EC_GROUP* ecdsa_group = EC_KEY_get0_group(EVP_PKEY_get0_EC_KEY(pkey.get()));
if (ecdsa_group == nullptr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ class CryptoMbContext {
enum RequestStatus getStatus() { return status_; }
void scheduleCallback(enum RequestStatus status);

// Buffer length is the same as the max signature length (4096 bits = 512 bytes)
unsigned char out_buf_[MAX_SIGNATURE_SIZE];
// The real length of the signature.
size_t out_len_{};
// Incoming data buffer.
std::unique_ptr<uint8_t[]> in_buf_;

Expand Down Expand Up @@ -75,6 +71,10 @@ class CryptoMbEcdsaContext : public CryptoMbContext {
// BoringSSL ECDSA key structure, so not wrapped in smart pointers.
const BIGNUM* priv_key_{};
size_t sig_len_{};

// ECDSA signature.
uint8_t sig_r_[32]{};
uint8_t sig_s_[32]{};
};

// CryptoMbRsaContext is a CryptoMbContext which holds the extra RSA parameters and has
Expand Down Expand Up @@ -103,6 +103,11 @@ class CryptoMbRsaContext : public CryptoMbContext {

// Buffer for `Lenstra` check.
unsigned char lenstra_to_[MAX_SIGNATURE_SIZE];

// Buffer length is the same as the max signature length (4096 bits = 512 bytes)
unsigned char out_buf_[MAX_SIGNATURE_SIZE];
// The real length of the signature.
size_t out_len_{};
};

using CryptoMbContextSharedPtr = std::shared_ptr<CryptoMbContext>;
Expand All @@ -123,8 +128,6 @@ class CryptoMbQueue : public Logger::Loggable<Logger::Id::connection> {
void processRequests();
void processRsaRequests();
void processEcdsaRequests();
bool postprocessEcdsaRequest(CryptoMbEcdsaContextSharedPtr mb_ctx, const uint8_t* sign_r,
const uint8_t* sign_s);
void startTimer();
void stopTimer();

Expand Down
15 changes: 15 additions & 0 deletions contrib/cryptomb/private_key_providers/test/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load(
"//bazel:envoy_build_system.bzl",
"envoy_cc_benchmark_binary",
"envoy_cc_test",
"envoy_cc_test_library",
"envoy_contrib_package",
Expand Down Expand Up @@ -82,3 +83,17 @@ envoy_cc_test(
"//test/test_common:utility_lib",
],
)

envoy_cc_benchmark_binary(
name = "speed_test",
srcs = ["speed_test.cc"],
external_deps = [
"benchmark",
"ssl",
],
deps = [
"//contrib/cryptomb/private_key_providers/source:ipp_crypto_wrapper_lib",
"//source/common/common:assert_lib",
"//source/common/common:utility_lib",
],
)
Loading
Loading