Skip to content

Commit

Permalink
DX-67209 updated aes_encrypt/decrypt (dremio#29)
Browse files Browse the repository at this point in the history
* DX-67209 updated aes_encrypt/decrypt
  • Loading branch information
xxlaykxx authored and lriggs committed Apr 25, 2024
1 parent 6a28035 commit b368a05
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 109 deletions.
28 changes: 22 additions & 6 deletions cpp/src/gandiva/encrypt_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,24 @@
// under the License.

#include "gandiva/encrypt_utils.h"
#include <string.h>

#include <stdexcept>

namespace gandiva {
GANDIVA_EXPORT
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
unsigned char* cipher) {
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
int32_t key_len, unsigned char* cipher) {
int32_t cipher_len = 0;
int32_t len = 0;
EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new();
const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len);

if (!en_ctx) {
throw std::runtime_error("could not create a new evp cipher ctx for encryption");
}

if (!EVP_EncryptInit_ex(en_ctx, EVP_aes_128_ecb(), nullptr,
if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr,
reinterpret_cast<const unsigned char*>(key), nullptr)) {
throw std::runtime_error("could not initialize evp cipher ctx for encryption");
}
Expand All @@ -55,17 +57,18 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke
}

GANDIVA_EXPORT
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
unsigned char* plaintext) {
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
int32_t key_len, unsigned char* plaintext) {
int32_t plaintext_len = 0;
int32_t len = 0;
EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new();
const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len);

if (!de_ctx) {
throw std::runtime_error("could not create a new evp cipher ctx for decryption");
}

if (!EVP_DecryptInit_ex(de_ctx, EVP_aes_128_ecb(), nullptr,
if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr,
reinterpret_cast<const unsigned char*>(key), nullptr)) {
throw std::runtime_error("could not initialize evp cipher ctx for decryption");
}
Expand All @@ -87,4 +90,17 @@ int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char*
EVP_CIPHER_CTX_free(de_ctx);
return plaintext_len;
}

const EVP_CIPHER* get_cipher_algo(int32_t key_length){
switch (key_length) {
case 16:
return EVP_aes_128_ecb();
case 24:
return EVP_aes_192_ecb();
case 32:
return EVP_aes_256_ecb();
default:
throw std::runtime_error("unsupported key length");
}
}
} // namespace gandiva
6 changes: 4 additions & 2 deletions cpp/src/gandiva/encrypt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ namespace gandiva {
**/
GANDIVA_EXPORT
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
unsigned char* cipher);
int32_t key_len, unsigned char* cipher);

/**
* Decrypt data using aes algorithm
**/
GANDIVA_EXPORT
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
unsigned char* plaintext);
int32_t key_len, unsigned char* plaintext);

const EVP_CIPHER* get_cipher_algo(int32_t key_length);

} // namespace gandiva
114 changes: 35 additions & 79 deletions cpp/src/gandiva/encrypt_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,38 @@
#include <gtest/gtest.h>

TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) {
// 8 bytes key
auto* key = "1234abcd";
// 16 bytes key
auto* key = "12345678abcdefgh";
auto* to_encrypt = "some test string";

auto key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
auto to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_1[64];

int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_1);
int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_1);

unsigned char decrypted_1[64];
int32_t decrypted_1_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_1),
cipher_1_len, key, decrypted_1);
cipher_1_len, key, key_len, decrypted_1);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_1), decrypted_1_len));

// 16 bytes key
key = "12345678abcdefgh";
// 24 bytes key
key = "12345678abcdefgh12345678";
to_encrypt = "some\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_2[64];

int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_2);
int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_2);

unsigned char decrypted_2[64];
int32_t decrypted_2_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_2),
cipher_2_len, key, decrypted_2);
cipher_2_len, key, key_len, decrypted_2);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_2), decrypted_2_len));
Expand All @@ -58,97 +60,51 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) {
key = "12345678abcdefgh12345678abcdefgh";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_3[64];

int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_3);
int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_3);

unsigned char decrypted_3[64];
int32_t decrypted_3_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_3),
cipher_3_len, key, decrypted_3);
cipher_3_len, key, key_len, decrypted_3);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_3), decrypted_3_len));

// 64 bytes key
// check exception
char cipher[64] = "JBB7oJAQuqhDCx01fvBRi8PcljW1+nbnOSMk+R0Sz7E==";
int32_t cipher_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(cipher)));
unsigned char plain_text[64];

key = "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_4[64];
ASSERT_THROW({
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_4);
}, std::runtime_error);

int32_t cipher_4_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_4);

unsigned char decrypted_4[64];
int32_t decrypted_4_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_4),
cipher_4_len, key, decrypted_4);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_4), decrypted_4_len));

// 128 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh";
to_encrypt = "A much more longer string then the previous one, but without newline";

to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_5[128];

int32_t cipher_5_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_5);

unsigned char decrypted_5[128];
int32_t decrypted_5_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_5),
cipher_5_len, key, decrypted_5);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_5), decrypted_5_len));

// 192 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
"5678abcdefgh12345678abcdefgh";
to_encrypt =
"A much more longer string then the previous one, but with \nnewline, pretty cool, "
"right?";

to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_6[256];

int32_t cipher_6_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_6);

unsigned char decrypted_6[256];
int32_t decrypted_6_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_6),
cipher_6_len, key, decrypted_6);
ASSERT_THROW({
gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text);
}, std::runtime_error);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_6), decrypted_6_len));

// 256 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
"5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456"
"78abcdefgh";
to_encrypt =
"A much more longer string then the previous one, but with \nnewline, pretty cool, "
"right?";
key = "12345678";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_7[256];

int32_t cipher_7_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_7);

unsigned char decrypted_7[256];
int32_t decrypted_7_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_7),
cipher_7_len, key, decrypted_7);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_7), decrypted_7_len));
unsigned char cipher_5[64];
ASSERT_THROW({
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_5);
}, std::runtime_error);
ASSERT_THROW({
gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text);
}, std::runtime_error);
}
26 changes: 21 additions & 5 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8)
#undef GDV_FN_CAST_VARCHAR_INTEGER
#undef GDV_FN_CAST_VARCHAR_REAL

static constexpr int64_t kAesBlockSize = 16; // bytes

GANDIVA_EXPORT
const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_len,
const char* key_data, int32_t key_data_len,
Expand All @@ -318,6 +316,15 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l
return "";
}

int64_t kAesBlockSize = 0;
if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) {
kAesBlockSize = static_cast<int64_t>(key_data_len);
} else {
gdv_fn_context_set_error_msg(context, "invalid key length");
*out_len = 0;
return nullptr;
}

*out_len =
static_cast<int32_t>(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize));
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
Expand All @@ -329,7 +336,7 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l
}

try {
*out_len = gandiva::aes_encrypt(data, data_len, key_data,
*out_len = gandiva::aes_encrypt(data, data_len, key_data, key_data_len,
reinterpret_cast<unsigned char*>(ret));
} catch (const std::runtime_error& e) {
gdv_fn_context_set_error_msg(context, e.what());
Expand All @@ -349,6 +356,15 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l
return "";
}

int64_t kAesBlockSize = 0;
if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) {
kAesBlockSize = static_cast<int64_t>(key_data_len);
} else {
gdv_fn_context_set_error_msg(context, "invalid key length");
*out_len = 0;
return nullptr;
}

*out_len =
static_cast<int32_t>(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize));
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
Expand All @@ -360,13 +376,13 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l
}

try {
*out_len = gandiva::aes_decrypt(data, data_len, key_data,
*out_len = gandiva::aes_decrypt(data, data_len, key_data, key_data_len,
reinterpret_cast<unsigned char*>(ret));
} catch (const std::runtime_error& e) {
gdv_fn_context_set_error_msg(context, e.what());
return nullptr;
}

ret[*out_len] = '\0';
return ret;
}

Expand Down
70 changes: 70 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1345,4 +1345,74 @@ TEST(TestGdvFnStubs, TestMask) {
EXPECT_EQ(std::string(result, out_len), expected);
}

TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) {
gandiva::ExecutionContext ctx;
std::string key16 = "12345678abcdefgh";
auto key16_len = static_cast<int32_t>(key16.length());
int32_t cipher_len = 0;
int32_t decrypted_len = 0;
std::string data = "test string";
auto data_len = static_cast<int32_t>(data.length());
int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, &cipher_len);
const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, &decrypted_len);

EXPECT_EQ(data, std::string(reinterpret_cast<const char*>(decrypted_value), decrypted_len));
}

TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) {
gandiva::ExecutionContext ctx;
std::string key24 = "12345678abcdefgh12345678";
auto key24_len = static_cast<int32_t>(key24.length());
int32_t cipher_len = 0;
int32_t decrypted_len = 0;
std::string data = "test string";
auto data_len = static_cast<int32_t>(data.length());
int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, &cipher_len);

const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, &decrypted_len);

EXPECT_EQ(data, std::string(reinterpret_cast<const char*>(decrypted_value), decrypted_len));
}

TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) {
gandiva::ExecutionContext ctx;
std::string key32 = "12345678abcdefgh12345678abcdefgh";
auto key32_len = static_cast<int32_t>(key32.length());
int32_t cipher_len = 0;
int32_t decrypted_len = 0;
std::string data = "test string";
auto data_len = static_cast<int32_t>(data.length());
int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, &cipher_len);

const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, &decrypted_len);

EXPECT_EQ(data, std::string(reinterpret_cast<const char*>(decrypted_value), decrypted_len));
}

TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) {
gandiva::ExecutionContext ctx;
std::string key33 = "12345678abcdefgh12345678abcdefghb";
auto key33_len = static_cast<int32_t>(key33.length());
int32_t decrypted_len = 0;
std::string data = "test string";
auto data_len = static_cast<int32_t>(data.length());
int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);
std::string cipher = "12345678abcdefgh12345678abcdefghb";
auto cipher_len = static_cast<int32_t>(cipher.length());

gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len, &cipher_len);
EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("invalid key length"));
ctx.Reset();

gdv_fn_aes_decrypt(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len, &decrypted_len); EXPECT_THAT(ctx.get_error(),
::testing::HasSubstr("invalid key length"));
ctx.Reset();
}
} // namespace gandiva
Loading

0 comments on commit b368a05

Please sign in to comment.