Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DX-67209 updated aes_encrypt/decrypt #29

Merged
merged 8 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions cpp/src/gandiva/encrypt_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,38 @@
// under the License.

#include "gandiva/encrypt_utils.h"
#include <string.h>

#include <stdexcept>

namespace gandiva {
GANDIVA_EXPORT
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
unsigned char* cipher) {
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
int32_t key_len, unsigned char* cipher) {
int32_t cipher_len = 0;
int32_t len = 0;
EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new();
const EVP_CIPHER* cipher_algo = nullptr;

if (!en_ctx) {
throw std::runtime_error("could not create a new evp cipher ctx for encryption");
}

if (!EVP_EncryptInit_ex(en_ctx, EVP_aes_128_ecb(), nullptr,
switch (key_len) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be combined with the other similar code below into a helper function.

case 16:
cipher_algo = EVP_aes_128_ecb();
break;
case 24:
cipher_algo = EVP_aes_192_ecb();
break;
case 32:
cipher_algo = EVP_aes_256_ecb();
break;
default:
throw std::runtime_error("unsupported key length");
}

if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr,
reinterpret_cast<const unsigned char*>(key), nullptr)) {
throw std::runtime_error("could not initialize evp cipher ctx for encryption");
}
Expand All @@ -55,17 +71,32 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke
}

GANDIVA_EXPORT
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
unsigned char* plaintext) {
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
int32_t key_len, unsigned char* plaintext) {
int32_t plaintext_len = 0;
int32_t len = 0;
EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new();
const EVP_CIPHER* cipher_algo = nullptr;

if (!de_ctx) {
throw std::runtime_error("could not create a new evp cipher ctx for decryption");
}

if (!EVP_DecryptInit_ex(de_ctx, EVP_aes_128_ecb(), nullptr,
switch (key_len) {
case 16:
cipher_algo = EVP_aes_128_ecb();
break;
case 24:
cipher_algo = EVP_aes_192_ecb();
break;
case 32:
cipher_algo = EVP_aes_256_ecb();
break;
default:
throw std::runtime_error("unsupported key length");
}

if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr,
reinterpret_cast<const unsigned char*>(key), nullptr)) {
throw std::runtime_error("could not initialize evp cipher ctx for decryption");
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/gandiva/encrypt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ namespace gandiva {
**/
GANDIVA_EXPORT
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
unsigned char* cipher);
int32_t key_len, unsigned char* cipher);

/**
* Decrypt data using aes algorithm
**/
GANDIVA_EXPORT
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
unsigned char* plaintext);
int32_t key_len, unsigned char* plaintext);

} // namespace gandiva
114 changes: 35 additions & 79 deletions cpp/src/gandiva/encrypt_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,38 @@
#include <gtest/gtest.h>

TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) {
// 8 bytes key
auto* key = "1234abcd";
// 16 bytes key
auto* key = "12345678abcdefgh";
auto* to_encrypt = "some test string";

auto key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
auto to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_1[64];

int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_1);
int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_1);

unsigned char decrypted_1[64];
int32_t decrypted_1_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_1),
cipher_1_len, key, decrypted_1);
cipher_1_len, key, key_len, decrypted_1);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_1), decrypted_1_len));

// 16 bytes key
key = "12345678abcdefgh";
// 24 bytes key
key = "12345678abcdefgh12345678";
to_encrypt = "some\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_2[64];

int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_2);
int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_2);

unsigned char decrypted_2[64];
int32_t decrypted_2_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_2),
cipher_2_len, key, decrypted_2);
cipher_2_len, key, key_len, decrypted_2);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_2), decrypted_2_len));
Expand All @@ -58,97 +60,51 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) {
key = "12345678abcdefgh12345678abcdefgh";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_3[64];

int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_3);
int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_3);

unsigned char decrypted_3[64];
int32_t decrypted_3_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_3),
cipher_3_len, key, decrypted_3);
cipher_3_len, key, key_len, decrypted_3);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_3), decrypted_3_len));

// 64 bytes key
// check exception
char cipher[64] = "JBB7oJAQuqhDCx01fvBRi8PcljW1+nbnOSMk+R0Sz7E==";
int32_t cipher_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(cipher)));
unsigned char plain_text[64];

key = "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_4[64];
ASSERT_THROW({
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_4);
}, std::runtime_error);

int32_t cipher_4_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_4);

unsigned char decrypted_4[64];
int32_t decrypted_4_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_4),
cipher_4_len, key, decrypted_4);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_4), decrypted_4_len));

// 128 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh";
to_encrypt = "A much more longer string then the previous one, but without newline";

to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_5[128];

int32_t cipher_5_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_5);

unsigned char decrypted_5[128];
int32_t decrypted_5_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_5),
cipher_5_len, key, decrypted_5);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_5), decrypted_5_len));

// 192 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
"5678abcdefgh12345678abcdefgh";
to_encrypt =
"A much more longer string then the previous one, but with \nnewline, pretty cool, "
"right?";

to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_6[256];

int32_t cipher_6_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_6);

unsigned char decrypted_6[256];
int32_t decrypted_6_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_6),
cipher_6_len, key, decrypted_6);
ASSERT_THROW({
gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text);
}, std::runtime_error);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_6), decrypted_6_len));

// 256 bytes key
key =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
"5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456"
"78abcdefgh";
to_encrypt =
"A much more longer string then the previous one, but with \nnewline, pretty cool, "
"right?";
key = "12345678";
to_encrypt = "New\ntest\nstring";

key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
to_encrypt_len =
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
unsigned char cipher_7[256];

int32_t cipher_7_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_7);

unsigned char decrypted_7[256];
int32_t decrypted_7_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_7),
cipher_7_len, key, decrypted_7);

EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
std::string(reinterpret_cast<const char*>(decrypted_7), decrypted_7_len));
unsigned char cipher_5[64];
ASSERT_THROW({
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_5);
}, std::runtime_error);
ASSERT_THROW({
gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text);
}, std::runtime_error);
}
26 changes: 21 additions & 5 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8)
#undef GDV_FN_CAST_VARCHAR_INTEGER
#undef GDV_FN_CAST_VARCHAR_REAL

static constexpr int64_t kAesBlockSize = 16; // bytes

GANDIVA_EXPORT
const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_len,
const char* key_data, int32_t key_data_len,
Expand All @@ -318,6 +316,15 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l
return "";
}

int64_t kAesBlockSize = 0;
if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) {
kAesBlockSize = static_cast<int64_t>(key_data_len);
} else {
gdv_fn_context_set_error_msg(context, "invalid key length");
*out_len = 0;
return nullptr;
Comment on lines +320 to +325
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't the helper functions already check the length? I think it would be better to just do the check in one place since the code would be easier to read and better tested. It doesn't look like this path is being tested since the unit tests call the helper functions.

}

*out_len =
static_cast<int32_t>(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize));
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
Expand All @@ -330,7 +337,7 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l
}

try {
*out_len = gandiva::aes_encrypt(data, data_len, key_data,
*out_len = gandiva::aes_encrypt(data, data_len, key_data, key_data_len,
reinterpret_cast<unsigned char*>(ret));
} catch (const std::runtime_error& e) {
gdv_fn_context_set_error_msg(context, e.what());
Expand All @@ -351,6 +358,15 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l
return "";
}

int64_t kAesBlockSize = 0;
if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) {
kAesBlockSize = static_cast<int64_t>(key_data_len);
} else {
gdv_fn_context_set_error_msg(context, "invalid key length");
*out_len = 0;
return nullptr;
}

*out_len =
static_cast<int32_t>(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize));
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
Expand All @@ -363,14 +379,14 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l
}

try {
*out_len = gandiva::aes_decrypt(data, data_len, key_data,
*out_len = gandiva::aes_decrypt(data, data_len, key_data, key_data_len,
reinterpret_cast<unsigned char*>(ret));
} catch (const std::runtime_error& e) {
gdv_fn_context_set_error_msg(context, e.what());
*out_len = 0;
return nullptr;
}

ret[*out_len] = '\0';
return ret;
}

Expand Down
26 changes: 9 additions & 17 deletions cpp/src/gandiva/tests/projector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2816,27 +2816,19 @@ TEST_F(TestProjector, TestAesEncryptDecrypt) {
std::shared_ptr<Projector> projector_en;
ASSERT_OK(Projector::Make(schema, {encrypt_expr}, TestConfiguration(), &projector_en));

int num_records = 4;
int num_records = 3;

const char* key_16_bytes = "12345678abcdefgh";
const char* key_24_bytes = "12345678abcdefgh12345678";
const char* key_32_bytes = "12345678abcdefgh12345678abcdefgh";
const char* key_64_bytes =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh";
const char* key_128_bytes =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh";
const char* key_256_bytes =
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
"5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456"
"78abcdefgh";

auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted", "hyah\n"},
{true, true, true, true});

auto array_data = MakeArrowArrayUtf8({"abc", "some words", "to be encrypted"},
{true, true, true});
auto array_key =
MakeArrowArrayUtf8({key_32_bytes, key_64_bytes, key_128_bytes, key_256_bytes},
{true, true, true, true});
MakeArrowArrayUtf8({key_16_bytes, key_24_bytes, key_32_bytes},
{true, true, true});

auto array_holder_en = MakeArrowArrayUtf8({"", "", "", ""}, {true, true, true, true});
auto array_holder_en = MakeArrowArrayUtf8({"", "", ""}, {true, true, true});

auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_data, array_key});

Expand Down