From 16ebb1f015506105e112d928d2c3874bf6523bef Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Fri, 30 Jun 2023 17:35:53 +0300 Subject: [PATCH 1/8] DX-67209 updated aes_encrypt/decrypt --- cpp/src/gandiva/encrypt_utils.cc | 45 ++++++++++++- cpp/src/gandiva/encrypt_utils_test.cc | 97 +++++++-------------------- cpp/src/gandiva/gdv_function_stubs.cc | 18 ++++- 3 files changed, 82 insertions(+), 78 deletions(-) diff --git a/cpp/src/gandiva/encrypt_utils.cc b/cpp/src/gandiva/encrypt_utils.cc index 7a274af792834..9dd8fc5b41d87 100644 --- a/cpp/src/gandiva/encrypt_utils.cc +++ b/cpp/src/gandiva/encrypt_utils.cc @@ -16,6 +16,7 @@ // under the License. #include "gandiva/encrypt_utils.h" +#include #include @@ -25,13 +26,33 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke unsigned char* cipher) { int32_t cipher_len = 0; int32_t len = 0; + int32_t current_key_len = static_cast(strlen(key)); EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = nullptr; if (!en_ctx) { throw std::runtime_error("could not create a new evp cipher ctx for encryption"); } - if (!EVP_EncryptInit_ex(en_ctx, EVP_aes_128_ecb(), nullptr, + if (current_key_len != 16 && current_key_len != 24 && current_key_len != 32) { + throw std::runtime_error("invalid key length"); + } + + switch (current_key_len) { + case 16: + cipher_algo = EVP_aes_128_ecb(); + break; + case 24: + cipher_algo = EVP_aes_192_ecb(); + break; + case 32: + cipher_algo = EVP_aes_256_ecb(); + break; + default: + throw std::runtime_error("unsupported key length"); + } + + if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, reinterpret_cast(key), nullptr)) { throw std::runtime_error("could not initialize evp cipher ctx for encryption"); } @@ -60,12 +81,32 @@ int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* int32_t plaintext_len = 0; int32_t len = 0; EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); + int32_t current_key_len = static_cast(strlen(key)); + const EVP_CIPHER* cipher_algo = nullptr; if (!de_ctx) { throw std::runtime_error("could not create a new evp cipher ctx for decryption"); } - if (!EVP_DecryptInit_ex(de_ctx, EVP_aes_128_ecb(), nullptr, + if (current_key_len != 16 && current_key_len != 24 && current_key_len != 32) { + throw std::runtime_error("invalid key length"); + } + + switch (current_key_len) { + case 16: + cipher_algo = EVP_aes_128_ecb(); + break; + case 24: + cipher_algo = EVP_aes_192_ecb(); + break; + case 32: + cipher_algo = EVP_aes_256_ecb(); + break; + default: + throw std::runtime_error("unsupported key length"); + } + + if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, reinterpret_cast(key), nullptr)) { throw std::runtime_error("could not initialize evp cipher ctx for decryption"); } diff --git a/cpp/src/gandiva/encrypt_utils_test.cc b/cpp/src/gandiva/encrypt_utils_test.cc index 9c75515aedb67..6fb2ec7b28647 100644 --- a/cpp/src/gandiva/encrypt_utils_test.cc +++ b/cpp/src/gandiva/encrypt_utils_test.cc @@ -20,8 +20,8 @@ #include TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { - // 8 bytes key - auto* key = "1234abcd"; + // 16 bytes key + auto* key = "12345678abcdefgh"; auto* to_encrypt = "some test string"; auto to_encrypt_len = @@ -37,8 +37,8 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_1), decrypted_1_len)); - // 16 bytes key - key = "12345678abcdefgh"; + // 24 bytes key + key = "12345678abcdefgh12345678"; to_encrypt = "some\ntest\nstring"; to_encrypt_len = @@ -71,84 +71,35 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_3), decrypted_3_len)); - // 64 bytes key + // check exception + char* cipher = "JBB7oJAQuqhDCx01fvBRi8PcljW1+nbnOSMk+R0Sz7E=="; + int32_t cipher_len = static_cast(strlen(reinterpret_cast(cipher))); + unsigned char plain_text[64]; + key = "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh"; to_encrypt = "New\ntest\nstring"; to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_4[64]; + ASSERT_THROW({ + gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_4); + }, std::runtime_error); - int32_t cipher_4_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_4); - - unsigned char decrypted_4[64]; - int32_t decrypted_4_len = gandiva::aes_decrypt(reinterpret_cast(cipher_4), - cipher_4_len, key, decrypted_4); - - EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), - std::string(reinterpret_cast(decrypted_4), decrypted_4_len)); - - // 128 bytes key - key = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh"; - to_encrypt = "A much more longer string then the previous one, but without newline"; - - to_encrypt_len = - static_cast(strlen(reinterpret_cast(to_encrypt))); - unsigned char cipher_5[128]; - - int32_t cipher_5_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_5); - - unsigned char decrypted_5[128]; - int32_t decrypted_5_len = gandiva::aes_decrypt(reinterpret_cast(cipher_5), - cipher_5_len, key, decrypted_5); - - EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), - std::string(reinterpret_cast(decrypted_5), decrypted_5_len)); - - // 192 bytes key - key = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234" - "5678abcdefgh12345678abcdefgh"; - to_encrypt = - "A much more longer string then the previous one, but with \nnewline, pretty cool, " - "right?"; - - to_encrypt_len = - static_cast(strlen(reinterpret_cast(to_encrypt))); - unsigned char cipher_6[256]; - - int32_t cipher_6_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_6); - - unsigned char decrypted_6[256]; - int32_t decrypted_6_len = gandiva::aes_decrypt(reinterpret_cast(cipher_6), - cipher_6_len, key, decrypted_6); + ASSERT_THROW({ + gandiva::aes_decrypt(cipher, cipher_len, key, plain_text); + }, std::runtime_error); - EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), - std::string(reinterpret_cast(decrypted_6), decrypted_6_len)); - - // 256 bytes key - key = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234" - "5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456" - "78abcdefgh"; - to_encrypt = - "A much more longer string then the previous one, but with \nnewline, pretty cool, " - "right?"; + key = "12345678"; + to_encrypt = "New\ntest\nstring"; to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); - unsigned char cipher_7[256]; - - int32_t cipher_7_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_7); - - unsigned char decrypted_7[256]; - int32_t decrypted_7_len = gandiva::aes_decrypt(reinterpret_cast(cipher_7), - cipher_7_len, key, decrypted_7); - - EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), - std::string(reinterpret_cast(decrypted_7), decrypted_7_len)); + unsigned char cipher_5[64]; + ASSERT_THROW({ + gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_5); + }, std::runtime_error); + ASSERT_THROW({ + gandiva::aes_decrypt(cipher, cipher_len, key, plain_text); + }, std::runtime_error); } diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index d2aeb883a3122..c55dc53dff3e5 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -306,8 +306,6 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8) #undef GDV_FN_CAST_VARCHAR_INTEGER #undef GDV_FN_CAST_VARCHAR_REAL -static constexpr int64_t kAesBlockSize = 16; // bytes - GANDIVA_EXPORT const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, @@ -318,6 +316,13 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l return ""; } + int64_t kAesBlockSize = 0; + if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) { + kAesBlockSize = static_cast(key_data_len); + } else { + throw std::runtime_error("invalid key length"); + } + *out_len = static_cast(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize)); char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); @@ -351,6 +356,13 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l return ""; } + int64_t kAesBlockSize = 0; + if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) { + kAesBlockSize = static_cast(key_data_len); + } else { + throw std::runtime_error("invalid key length"); + } + *out_len = static_cast(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize)); char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); @@ -370,7 +382,7 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l *out_len = 0; return nullptr; } - + ret[*out_len] = '\0'; return ret; } From 37cfd5cab63f51e0d2baed797ef3350b6e319cf6 Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Fri, 30 Jun 2023 21:03:29 +0300 Subject: [PATCH 2/8] DX-67209 fixed tests --- cpp/src/gandiva/tests/projector_test.cc | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 462fae64393fd..e29b50035b1c1 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -2818,22 +2818,14 @@ TEST_F(TestProjector, TestAesEncryptDecrypt) { int num_records = 4; + const char* key_16_bytes = "12345678abcdefgh"; + const char* key_24_bytes = "12345678abcdefgh12345678"; const char* key_32_bytes = "12345678abcdefgh12345678abcdefgh"; - const char* key_64_bytes = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh"; - const char* key_128_bytes = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh"; - const char* key_256_bytes = - "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12" - "345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234" - "5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456" - "78abcdefgh"; auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted", "hyah\n"}, {true, true, true, true}); auto array_key = - MakeArrowArrayUtf8({key_32_bytes, key_64_bytes, key_128_bytes, key_256_bytes}, + MakeArrowArrayUtf8({key_16_bytes, key_24_bytes, key_32_bytes}, {true, true, true, true}); auto array_holder_en = MakeArrowArrayUtf8({"", "", "", ""}, {true, true, true, true}); From ce60e00496487924b950b9f335c4872a0c8f04a2 Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Sun, 2 Jul 2023 08:20:14 +0300 Subject: [PATCH 3/8] fixed tests --- cpp/src/gandiva/tests/projector_test.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index e29b50035b1c1..15dd9d355f2de 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -2816,19 +2816,19 @@ TEST_F(TestProjector, TestAesEncryptDecrypt) { std::shared_ptr projector_en; ASSERT_OK(Projector::Make(schema, {encrypt_expr}, TestConfiguration(), &projector_en)); - int num_records = 4; + int num_records = 3; const char* key_16_bytes = "12345678abcdefgh"; const char* key_24_bytes = "12345678abcdefgh12345678"; const char* key_32_bytes = "12345678abcdefgh12345678abcdefgh"; - auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted", "hyah\n"}, + auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted"}, {true, true, true, true}); auto array_key = MakeArrowArrayUtf8({key_16_bytes, key_24_bytes, key_32_bytes}, - {true, true, true, true}); + {true, true, true}); - auto array_holder_en = MakeArrowArrayUtf8({"", "", "", ""}, {true, true, true, true}); + auto array_holder_en = MakeArrowArrayUtf8({"", "", ""}, {true, true, true}); auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_data, array_key}); From 68ee2699d268505eb8ab62a43dc355ae0853c14b Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Sun, 2 Jul 2023 08:45:54 +0300 Subject: [PATCH 4/8] fixed tests --- cpp/src/gandiva/tests/projector_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 15dd9d355f2de..6817cad94da32 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -2818,9 +2818,9 @@ TEST_F(TestProjector, TestAesEncryptDecrypt) { int num_records = 3; - const char* key_16_bytes = "12345678abcdefgh"; - const char* key_24_bytes = "12345678abcdefgh12345678"; - const char* key_32_bytes = "12345678abcdefgh12345678abcdefgh"; + const char* key_16_bytes = "12345678abcdefg"; + const char* key_24_bytes = "12345678abcdefgh1234567"; + const char* key_32_bytes = "12345678abcdefgh12345678abcdefg"; auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted"}, {true, true, true, true}); From 3b3b257790b246f63b4b3969f8c74663e36fde97 Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Sun, 2 Jul 2023 10:25:17 +0300 Subject: [PATCH 5/8] DX-67209 chaged error handling, fixed tests --- cpp/src/gandiva/encrypt_utils.cc | 22 ++++++-------------- cpp/src/gandiva/encrypt_utils.h | 4 ++-- cpp/src/gandiva/encrypt_utils_test.cc | 27 +++++++++++++++---------- cpp/src/gandiva/gdv_function_stubs.cc | 12 +++++++---- cpp/src/gandiva/tests/projector_test.cc | 8 ++++---- 5 files changed, 36 insertions(+), 37 deletions(-) diff --git a/cpp/src/gandiva/encrypt_utils.cc b/cpp/src/gandiva/encrypt_utils.cc index 9dd8fc5b41d87..8a4370d97f15b 100644 --- a/cpp/src/gandiva/encrypt_utils.cc +++ b/cpp/src/gandiva/encrypt_utils.cc @@ -22,11 +22,10 @@ namespace gandiva { GANDIVA_EXPORT -int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key, - unsigned char* cipher) { +int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, unsigned char* cipher) { int32_t cipher_len = 0; int32_t len = 0; - int32_t current_key_len = static_cast(strlen(key)); EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); const EVP_CIPHER* cipher_algo = nullptr; @@ -34,11 +33,7 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke throw std::runtime_error("could not create a new evp cipher ctx for encryption"); } - if (current_key_len != 16 && current_key_len != 24 && current_key_len != 32) { - throw std::runtime_error("invalid key length"); - } - - switch (current_key_len) { + switch (key_len) { case 16: cipher_algo = EVP_aes_128_ecb(); break; @@ -76,23 +71,18 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke } GANDIVA_EXPORT -int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, - unsigned char* plaintext) { +int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, unsigned char* plaintext) { int32_t plaintext_len = 0; int32_t len = 0; EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); - int32_t current_key_len = static_cast(strlen(key)); const EVP_CIPHER* cipher_algo = nullptr; if (!de_ctx) { throw std::runtime_error("could not create a new evp cipher ctx for decryption"); } - if (current_key_len != 16 && current_key_len != 24 && current_key_len != 32) { - throw std::runtime_error("invalid key length"); - } - - switch (current_key_len) { + switch (key_len) { case 16: cipher_algo = EVP_aes_128_ecb(); break; diff --git a/cpp/src/gandiva/encrypt_utils.h b/cpp/src/gandiva/encrypt_utils.h index ea0e3580821a9..06e178fd65efa 100644 --- a/cpp/src/gandiva/encrypt_utils.h +++ b/cpp/src/gandiva/encrypt_utils.h @@ -28,13 +28,13 @@ namespace gandiva { **/ GANDIVA_EXPORT int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key, - unsigned char* cipher); + int32_t key_len, unsigned char* cipher); /** * Decrypt data using aes algorithm **/ GANDIVA_EXPORT int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, - unsigned char* plaintext); + int32_t key_len, unsigned char* plaintext); } // namespace gandiva diff --git a/cpp/src/gandiva/encrypt_utils_test.cc b/cpp/src/gandiva/encrypt_utils_test.cc index 6fb2ec7b28647..689f20ab03298 100644 --- a/cpp/src/gandiva/encrypt_utils_test.cc +++ b/cpp/src/gandiva/encrypt_utils_test.cc @@ -24,15 +24,16 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { auto* key = "12345678abcdefgh"; auto* to_encrypt = "some test string"; + auto key_len = static_cast(strlen(reinterpret_cast(key))); auto to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_1[64]; - int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_1); + int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_1); unsigned char decrypted_1[64]; int32_t decrypted_1_len = gandiva::aes_decrypt(reinterpret_cast(cipher_1), - cipher_1_len, key, decrypted_1); + cipher_1_len, key, key_len, decrypted_1); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_1), decrypted_1_len)); @@ -41,15 +42,16 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { key = "12345678abcdefgh12345678"; to_encrypt = "some\ntest\nstring"; + key_len = static_cast(strlen(reinterpret_cast(key))); to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_2[64]; - int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_2); + int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_2); unsigned char decrypted_2[64]; int32_t decrypted_2_len = gandiva::aes_decrypt(reinterpret_cast(cipher_2), - cipher_2_len, key, decrypted_2); + cipher_2_len, key, key_len, decrypted_2); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_2), decrypted_2_len)); @@ -58,48 +60,51 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) { key = "12345678abcdefgh12345678abcdefgh"; to_encrypt = "New\ntest\nstring"; + key_len = static_cast(strlen(reinterpret_cast(key))); to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_3[64]; - int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_3); + int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_3); unsigned char decrypted_3[64]; int32_t decrypted_3_len = gandiva::aes_decrypt(reinterpret_cast(cipher_3), - cipher_3_len, key, decrypted_3); + cipher_3_len, key, key_len, decrypted_3); EXPECT_EQ(std::string(reinterpret_cast(to_encrypt), to_encrypt_len), std::string(reinterpret_cast(decrypted_3), decrypted_3_len)); // check exception - char* cipher = "JBB7oJAQuqhDCx01fvBRi8PcljW1+nbnOSMk+R0Sz7E=="; + char cipher[64] = "JBB7oJAQuqhDCx01fvBRi8PcljW1+nbnOSMk+R0Sz7E=="; int32_t cipher_len = static_cast(strlen(reinterpret_cast(cipher))); unsigned char plain_text[64]; key = "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh"; to_encrypt = "New\ntest\nstring"; + key_len = static_cast(strlen(reinterpret_cast(key))); to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_4[64]; ASSERT_THROW({ - gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_4); + gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_4); }, std::runtime_error); ASSERT_THROW({ - gandiva::aes_decrypt(cipher, cipher_len, key, plain_text); + gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text); }, std::runtime_error); key = "12345678"; to_encrypt = "New\ntest\nstring"; + key_len = static_cast(strlen(reinterpret_cast(key))); to_encrypt_len = static_cast(strlen(reinterpret_cast(to_encrypt))); unsigned char cipher_5[64]; ASSERT_THROW({ - gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_5); + gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_5); }, std::runtime_error); ASSERT_THROW({ - gandiva::aes_decrypt(cipher, cipher_len, key, plain_text); + gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text); }, std::runtime_error); } diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index c55dc53dff3e5..5146f7fa1990a 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -320,7 +320,9 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) { kAesBlockSize = static_cast(key_data_len); } else { - throw std::runtime_error("invalid key length"); + gdv_fn_context_set_error_msg(context, "invalid key length"); + *out_len = 0; + return nullptr; } *out_len = @@ -335,7 +337,7 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l } try { - *out_len = gandiva::aes_encrypt(data, data_len, key_data, + *out_len = gandiva::aes_encrypt(data, data_len, key_data, key_data_len, reinterpret_cast(ret)); } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); @@ -360,7 +362,9 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) { kAesBlockSize = static_cast(key_data_len); } else { - throw std::runtime_error("invalid key length"); + gdv_fn_context_set_error_msg(context, "invalid key length"); + *out_len = 0; + return nullptr; } *out_len = @@ -375,7 +379,7 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l } try { - *out_len = gandiva::aes_decrypt(data, data_len, key_data, + *out_len = gandiva::aes_decrypt(data, data_len, key_data, key_data_len, reinterpret_cast(ret)); } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 6817cad94da32..d230170a4ead7 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -2818,12 +2818,12 @@ TEST_F(TestProjector, TestAesEncryptDecrypt) { int num_records = 3; - const char* key_16_bytes = "12345678abcdefg"; - const char* key_24_bytes = "12345678abcdefgh1234567"; - const char* key_32_bytes = "12345678abcdefgh12345678abcdefg"; + const char* key_16_bytes = "12345678abcdefgh"; + const char* key_24_bytes = "12345678abcdefgh12345678"; + const char* key_32_bytes = "12345678abcdefgh12345678abcdefgh"; auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted"}, - {true, true, true, true}); + {true, true, true}); auto array_key = MakeArrowArrayUtf8({key_16_bytes, key_24_bytes, key_32_bytes}, {true, true, true}); From e3822ffef65d7d2a770b0f7d957201168825bced Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Fri, 7 Jul 2023 20:56:24 +0300 Subject: [PATCH 6/8] move determining of cipher algo to separate method --- cpp/src/gandiva/encrypt_utils.cc | 45 +++++++++++--------------------- cpp/src/gandiva/encrypt_utils.h | 2 ++ 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/cpp/src/gandiva/encrypt_utils.cc b/cpp/src/gandiva/encrypt_utils.cc index 8a4370d97f15b..9dee10cfdb6ff 100644 --- a/cpp/src/gandiva/encrypt_utils.cc +++ b/cpp/src/gandiva/encrypt_utils.cc @@ -27,26 +27,12 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke int32_t cipher_len = 0; int32_t len = 0; EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); - const EVP_CIPHER* cipher_algo = nullptr; + const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len); if (!en_ctx) { throw std::runtime_error("could not create a new evp cipher ctx for encryption"); } - switch (key_len) { - case 16: - cipher_algo = EVP_aes_128_ecb(); - break; - case 24: - cipher_algo = EVP_aes_192_ecb(); - break; - case 32: - cipher_algo = EVP_aes_256_ecb(); - break; - default: - throw std::runtime_error("unsupported key length"); - } - if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, reinterpret_cast(key), nullptr)) { throw std::runtime_error("could not initialize evp cipher ctx for encryption"); @@ -76,26 +62,12 @@ int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* int32_t plaintext_len = 0; int32_t len = 0; EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); - const EVP_CIPHER* cipher_algo = nullptr; + const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len); if (!de_ctx) { throw std::runtime_error("could not create a new evp cipher ctx for decryption"); } - switch (key_len) { - case 16: - cipher_algo = EVP_aes_128_ecb(); - break; - case 24: - cipher_algo = EVP_aes_192_ecb(); - break; - case 32: - cipher_algo = EVP_aes_256_ecb(); - break; - default: - throw std::runtime_error("unsupported key length"); - } - if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, reinterpret_cast(key), nullptr)) { throw std::runtime_error("could not initialize evp cipher ctx for decryption"); @@ -118,4 +90,17 @@ int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* EVP_CIPHER_CTX_free(de_ctx); return plaintext_len; } + +const EVP_CIPHER* get_cipher_algo(int32_t key_length){ + switch (key_length) { + case 16: + return EVP_aes_128_ecb(); + case 24: + return EVP_aes_192_ecb(); + case 32: + return EVP_aes_256_ecb(); + default: + throw std::runtime_error("unsupported key length"); + } +} } // namespace gandiva diff --git a/cpp/src/gandiva/encrypt_utils.h b/cpp/src/gandiva/encrypt_utils.h index 06e178fd65efa..f02b029f01b64 100644 --- a/cpp/src/gandiva/encrypt_utils.h +++ b/cpp/src/gandiva/encrypt_utils.h @@ -37,4 +37,6 @@ GANDIVA_EXPORT int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key, int32_t key_len, unsigned char* plaintext); +const EVP_CIPHER* get_cipher_algo(int32_t key_length); + } // namespace gandiva From c92c8bc8c58ce8e63e38a87e4cf7fae0f7f961dd Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Mon, 10 Jul 2023 16:47:56 +0300 Subject: [PATCH 7/8] Add more unit tests --- cpp/src/gandiva/gdv_function_stubs_test.cc | 73 ++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index a8dfcd088ab17..c09609b638a2c 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1345,4 +1345,77 @@ TEST(TestGdvFnStubs, TestMask) { EXPECT_EQ(std::string(result, out_len), expected); } +TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, &cipher_len); + const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, &decrypted_len); + + EXPECT_EQ(std::string(data, data_len), + std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { + gandiva::ExecutionContext ctx; + std::string key24 = "12345678abcdefgh12345678"; + auto key24_len = static_cast(key24.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, &cipher_len); + + const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, &decrypted_len); + + EXPECT_EQ(std::string(data, data_len), + std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { + gandiva::ExecutionContext ctx; + std::string key32 = "12345678abcdefgh12345678abcdefgh"; + auto key32_len = static_cast(key32.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, &cipher_len); + + const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, &decrypted_len); + + EXPECT_EQ(std::string(data, data_len), + std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) { + gandiva::ExecutionContext ctx; + std::string key33 = "12345678abcdefgh12345678abcdefghb"; + auto key33_len = static_cast(key33.length()); + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + std::string cipher = "12345678abcdefgh12345678abcdefghb"; + auto cipher_len = static_cast(cipher.length()); + + gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("invalid key length")); + ctx.Reset(); + + gdv_fn_aes_decrypt(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len, &decrypted_len); EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("invalid key length")); + ctx.Reset(); +} } // namespace gandiva From 3e069245cba179392f7c0bbfa92bf3cd3af91cd6 Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Mon, 10 Jul 2023 17:06:37 +0300 Subject: [PATCH 8/8] fixed tests --- cpp/src/gandiva/gdv_function_stubs_test.cc | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index c09609b638a2c..552a972f73709 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1358,8 +1358,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) { const char* cipher = gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, &cipher_len); const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, &decrypted_len); - EXPECT_EQ(std::string(data, data_len), - std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { @@ -1376,8 +1375,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, &decrypted_len); - EXPECT_EQ(std::string(data, data_len), - std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { @@ -1394,8 +1392,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { const char* decrypted_value = gdv_fn_aes_decrypt(ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, &decrypted_len); - EXPECT_EQ(std::string(data, data_len), - std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) {