diff --git a/cpp/src/gandiva/encrypt_utils.cc b/cpp/src/gandiva/encrypt_utils.cc index 7a274af792834..9dee10cfdb6ff 100644 --- a/cpp/src/gandiva/encrypt_utils.cc +++ b/cpp/src/gandiva/encrypt_utils.cc @@ -16,22 +16,24 @@ // under the License. #include "gandiva/encrypt_utils.h" +#include #include namespace gandiva { GANDIVA_EXPORT -int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key, - unsigned char* cipher) { +int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, unsigned char* cipher) { int32_t cipher_len = 0; int32_t len = 0; EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len); if (!en_ctx) { throw std::runtime_error("could not create a new evp cipher ctx for encryption"); } - if (!EVP_EncryptInit_ex(en_ctx, EVP_aes_128_ecb(), nullptr, + if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, reinterpret_cast(key), nullptr)) { throw std::runtime_error("could not initialize evp cipher ctx for encryption"); } @@ -55,17 +57,18 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke } GANDIVA_EXPORT -int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, - unsigned char* plaintext) { +int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, unsigned char* plaintext) { int32_t plaintext_len = 0; int32_t len = 0; EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len); if (!de_ctx) { throw std::runtime_error("could not create a new evp cipher ctx for decryption"); } - if (!EVP_DecryptInit_ex(de_ctx, EVP_aes_128_ecb(), nullptr, + if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, reinterpret_cast(key), nullptr)) { throw std::runtime_error("could not initialize evp cipher ctx for decryption"); } @@ -87,4 +90,17 @@ int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* EVP_CIPHER_CTX_free(de_ctx); return plaintext_len; } + +const EVP_CIPHER* get_cipher_algo(int32_t key_length){ + switch (key_length) { + case 16: + return EVP_aes_128_ecb(); + case 24: + return EVP_aes_192_ecb(); + case 32: + return EVP_aes_256_ecb(); + default: + throw std::runtime_error("unsupported key length"); + } +} } // namespace gandiva diff --git a/cpp/src/gandiva/encrypt_utils.h b/cpp/src/gandiva/encrypt_utils.h index ea0e3580821a9..f02b029f01b64 100644 --- a/cpp/src/gandiva/encrypt_utils.h +++ b/cpp/src/gandiva/encrypt_utils.h @@ -28,13 +28,15 @@ namespace gandiva { **/ GANDIVA_EXPORT int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key, - unsigned char* cipher); + int32_t key_len, unsigned char* cipher); /** * Decrypt data using aes algorithm **/ GANDIVA_EXPORT int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, - unsigned char* plaintext); + int32_t key_len, unsigned char* plaintext); + +const EVP_CIPHER* get_cipher_algo(int32_t key_length); } // namespace gandiva diff --git a/cpp/src/gandiva/encrypt_utils_test.cc b/cpp/src/gandiva/encrypt_utils_test.cc index 9c75515aedb67..689f20ab03298 100644 --- a/cpp/src/gandiva/encrypt_utils_test.cc +++ b/cpp/src/gandiva/encrypt_utils_test.cc @@ -20,36 +20,38 @@ #include TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { - // 8 bytes key - auto* key = "1234abcd"; + // 16 bytes key + auto* key = "12345678abcdefgh"; auto* to_encrypt = "some test string"; + auto key_len = static_cast(strlen(reinterpret_cast(key))); auto to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_1[64]; - int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_1); + int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_1); unsigned char decrypted_1[64]; int32_t decrypted_1_len = gandiva::aes_decrypt(reinterpret_cast(cipher_1), - cipher_1_len, key, decrypted_1); + cipher_1_len, key, key_len, decrypted_1); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_1), decrypted_1_len)); - // 16 bytes key - key = "12345678abcdefgh"; + // 24 bytes key + key = "12345678abcdefgh12345678"; to_encrypt = "some\ntest\nstring"; + key_len = static_cast(strlen(reinterpret_cast(key))); to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_2[64]; - int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_2); + int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_2); unsigned char decrypted_2[64]; int32_t decrypted_2_len = gandiva::aes_decrypt(reinterpret_cast(cipher_2), - cipher_2_len, key, decrypted_2); + cipher_2_len, key, key_len, decrypted_2); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_2), decrypted_2_len)); @@ -58,97 +60,51 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { key = "12345678abcdefgh12345678abcdefgh"; to_encrypt = "New\ntest\nstring"; + key_len = static_cast(strlen(reinterpret_cast(key))); to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_3[64]; - int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_3); + int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_3); unsigned char decrypted_3[64]; int32_t decrypted_3_len = gandiva::aes_decrypt(reinterpret_cast(cipher_3), - cipher_3_len, key, decrypted_3); + cipher_3_len, key, key_len, decrypted_3); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_3), decrypted_3_len)); - // 64 bytes key + // check exception + char cipher[64] = "JBB7oJAQuqhDCx01fvBRi8PcljW1+nbnOSMk+R0Sz7E=="; + int32_t cipher_len = static_cast(strlen(reinterpret_cast(cipher))); + unsigned char plain_text[64]; + key = "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh"; to_encrypt = "New\ntest\nstring"; + key_len = static_cast(strlen(reinterpret_cast(key))); to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_4[64]; + ASSERT_THROW({ + gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_4); + }, std::runtime_error); - int32_t cipher_4_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_4); - - unsigned char decrypted_4[64]; - int32_t decrypted_4_len = gandiva::aes_decrypt(reinterpret_cast(cipher_4), - cipher_4_len, key, decrypted_4); - - EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), - std::string(reinterpret_cast(decrypted_4), decrypted_4_len)); - - // 128 bytes key - key = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh"; - to_encrypt = "A much more longer string then the previous one, but without newline"; - - to_encrypt_len = - static_cast(strlen(reinterpret_cast(to_encrypt))); - unsigned char cipher_5[128]; - - int32_t cipher_5_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_5); - - unsigned char decrypted_5[128]; - int32_t decrypted_5_len = gandiva::aes_decrypt(reinterpret_cast(cipher_5), - cipher_5_len, key, decrypted_5); - - EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), - std::string(reinterpret_cast(decrypted_5), decrypted_5_len)); - - // 192 bytes key - key = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234" - "5678abcdefgh12345678abcdefgh"; - to_encrypt = - "A much more longer string then the previous one, but with \nnewline, pretty cool, " - "right?"; - - to_encrypt_len = - static_cast(strlen(reinterpret_cast(to_encrypt))); - unsigned char cipher_6[256]; - - int32_t cipher_6_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_6); - - unsigned char decrypted_6[256]; - int32_t decrypted_6_len = gandiva::aes_decrypt(reinterpret_cast(cipher_6), - cipher_6_len, key, decrypted_6); + ASSERT_THROW({ + gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text); + }, std::runtime_error); - EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), - std::string(reinterpret_cast(decrypted_6), decrypted_6_len)); - - // 256 bytes key - key = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234" - "5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456" - "78abcdefgh"; - to_encrypt = - "A much more longer string then the previous one, but with \nnewline, pretty cool, " - "right?"; + key = "12345678"; + to_encrypt = "New\ntest\nstring"; + key_len = static_cast(strlen(reinterpret_cast(key))); to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); - unsigned char cipher_7[256]; - - int32_t cipher_7_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_7); - - unsigned char decrypted_7[256]; - int32_t decrypted_7_len = gandiva::aes_decrypt(reinterpret_cast(cipher_7), - cipher_7_len, key, decrypted_7); - - EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), - std::string(reinterpret_cast(decrypted_7), decrypted_7_len)); + unsigned char cipher_5[64]; + ASSERT_THROW({ + gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_5); + }, std::runtime_error); + ASSERT_THROW({ + gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text); + }, std::runtime_error); } diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index bcef954a473ea..8c6d32162cbaf 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -306,8 +306,6 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8) #undef GDV_FN_CAST_VARCHAR_INTEGER #undef GDV_FN_CAST_VARCHAR_REAL -static constexpr int64_t kAesBlockSize = 16; // bytes - GANDIVA_EXPORT const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, @@ -318,6 +316,15 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l return ""; } + int64_t kAesBlockSize = 0; + if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) { + kAesBlockSize = static_cast(key_data_len); + } else { + gdv_fn_context_set_error_msg(context, "invalid key length"); + *out_len = 0; + return nullptr; + } + *out_len = static_cast(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize)); char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); @@ -329,7 +336,7 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l } try { - *out_len = gandiva::aes_encrypt(data, data_len, key_data, + *out_len = gandiva::aes_encrypt(data, data_len, key_data, key_data_len, reinterpret_cast(ret)); } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); @@ -349,6 +356,15 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l return ""; } + int64_t kAesBlockSize = 0; + if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) { + kAesBlockSize = static_cast(key_data_len); + } else { + gdv_fn_context_set_error_msg(context, "invalid key length"); + *out_len = 0; + return nullptr; + } + *out_len = static_cast(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize)); char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); @@ -360,13 +376,13 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l } try { - *out_len = gandiva::aes_decrypt(data, data_len, key_data, + *out_len = gandiva::aes_decrypt(data, data_len, key_data, key_data_len, reinterpret_cast(ret)); } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); return nullptr; } - + ret[*out_len] = '\0'; return ret; } diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index 3e403828a4cce..134f7dcd27dde 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1345,4 +1345,74 @@ TEST(TestGdvFnStubs, TestMask) { EXPECT_EQ(std::string(result, out_len), expected); } +TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, &cipher_len); + const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, &decrypted_len); + + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { + gandiva::ExecutionContext ctx; + std::string key24 = "12345678abcdefgh12345678"; + auto key24_len = static_cast(key24.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, &cipher_len); + + const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, &decrypted_len); + + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { + gandiva::ExecutionContext ctx; + std::string key32 = "12345678abcdefgh12345678abcdefgh"; + auto key32_len = static_cast(key32.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, &cipher_len); + + const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, &decrypted_len); + + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) { + gandiva::ExecutionContext ctx; + std::string key33 = "12345678abcdefgh12345678abcdefghb"; + auto key33_len = static_cast(key33.length()); + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + std::string cipher = "12345678abcdefgh12345678abcdefghb"; + auto cipher_len = static_cast(cipher.length()); + + gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("invalid key length")); + ctx.Reset(); + + gdv_fn_aes_decrypt(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len, &decrypted_len); EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("invalid key length")); + ctx.Reset(); +} } // namespace gandiva diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 59eeb3d92f19a..a6ad07830d121 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -2817,27 +2817,19 @@ TEST_F(TestProjector, TestAesEncryptDecrypt) { std::shared_ptr projector_en; ASSERT_OK(Projector::Make(schema, {encrypt_expr}, TestConfiguration(), &projector_en)); - int num_records = 4; + int num_records = 3; + const char* key_16_bytes = "12345678abcdefgh"; + const char* key_24_bytes = "12345678abcdefgh12345678"; const char* key_32_bytes = "12345678abcdefgh12345678abcdefgh"; - const char* key_64_bytes = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh"; - const char* key_128_bytes = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh"; - const char* key_256_bytes = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234" - "5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456" - "78abcdefgh"; - - auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted", "hyah\n"}, - {true, true, true, true}); + + auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted"}, + {true, true, true}); auto array_key = - MakeArrowArrayUtf8({key_32_bytes, key_64_bytes, key_128_bytes, key_256_bytes}, - {true, true, true, true}); + MakeArrowArrayUtf8({key_16_bytes, key_24_bytes, key_32_bytes}, + {true, true, true}); - auto array_holder_en = MakeArrowArrayUtf8({"", "", "", ""}, {true, true, true, true}); + auto array_holder_en = MakeArrowArrayUtf8({"", "", ""}, {true, true, true}); auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_data, array_key});