Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid retrying for 5 minutes after failed key retrieval. #102

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/scitokens_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ remove_issuer_entry(sqlite3 *db, const std::string &issuer, bool new_transaction


bool
scitokens::Validator::get_public_keys_from_db(const std::string issuer, int64_t now, picojson::value &keys, int64_t &next_update) {
scitokens::Validator::get_public_keys_from_db(const std::string issuer, int64_t now, picojson::value &keys, int64_t &next_update, int64_t &expires) {
auto cache_fname = get_cache_file();
if (cache_fname.size() == 0) {return false;}

Expand Down Expand Up @@ -177,6 +177,7 @@ scitokens::Validator::get_public_keys_from_db(const std::string issuer, int64_t
sqlite3_close(db);
return false;
}
expires = expiry;
sqlite3_close(db);
iter = top_obj.find("next_update");
if (iter == top_obj.end() || !iter->second.is<int64_t>()) {
Expand Down
10 changes: 6 additions & 4 deletions src/scitokens_internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,8 @@ Validator::get_jwks(const std::string &issuer)
{
auto now = std::time(NULL);
picojson::value jwks;
int64_t next_update;
if (get_public_keys_from_db(issuer, now, jwks, next_update)) {
int64_t next_update, expires;
if (get_public_keys_from_db(issuer, now, jwks, next_update, expires)) {
return jwks.serialize();
}
return std::string("{\"keys\": []}");
Expand Down Expand Up @@ -578,13 +578,15 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid,
picojson::value keys;
int64_t next_update, expires;
auto now = std::time(NULL);
if (get_public_keys_from_db(issuer, now, keys, next_update)) {
if (get_public_keys_from_db(issuer, now, keys, next_update, expires)) {
if (now > next_update) {
try {
get_public_keys_from_web(issuer, SimpleCurlGet::default_timeout, keys, next_update, expires);
store_public_keys(issuer, keys, next_update, expires);
} catch (std::runtime_error &) {
// ignore the exception: we have a valid set of keys already/
// ignore the exception: we have a valid set of keys already. However, we don't want to continuously
// hammer the upstream server which is not currently working ... move forward the next_update by 5 minutes.
store_public_keys(issuer, keys, now + 300, expires);
}
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/scitokens_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ class Validator {
private:
void get_public_key_pem(const std::string &issuer, const std::string &kid, std::string &public_pem, std::string &algorithm);
static void get_public_keys_from_web(const std::string &issuer, unsigned timeout, picojson::value &keys, int64_t &next_update, int64_t &expires);
static bool get_public_keys_from_db(const std::string issuer, int64_t now, picojson::value &keys, int64_t &next_update);
static bool get_public_keys_from_db(const std::string issuer, int64_t now, picojson::value &keys, int64_t &next_update, int64_t &expires);
static bool store_public_keys(const std::string &issuer, const picojson::value &keys, int64_t next_update, int64_t expires);

bool m_validate_all_claims{true};
Expand Down
293 changes: 293 additions & 0 deletions test/main.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
#include "../src/scitokens.h"

#include <pwd.h>
#include <memory>
#include <gtest/gtest.h>

#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ec.h>
#include <openssl/pem.h>

#ifndef PICOJSON_USE_INT64
#define PICOJSON_USE_INT64
#endif
#include <picojson/picojson.h>
#include <sqlite3.h>

namespace {

const char ec_private[] = "-----BEGIN EC PRIVATE KEY-----\n"
Expand All @@ -27,6 +39,216 @@ const char ec_public_2[] = "-----BEGIN PUBLIC KEY-----\n"
"XWCq4E/g2ME/uBOdP8RE0tqle8fxYcaPikgMcppGq2ycTiLGgEYXgsq2JA==\n"
"-----END PUBLIC KEY-----\n";

/**
* Duplicate of get_cache_file from scitokens_cache.cpp; used for direct
* SQLite manipulation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What’s the purpose of these duplicate functions rather than calling the library functions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly did this because the underlying functions are internal and not exported by the library as symbols... any ideas on how to get them into the executable instead of duplicating?

*/
std::string
get_cache_file() {

const char *xdg_cache_home = getenv("XDG_CACHE_HOME");

auto bufsize = sysconf(_SC_GETPW_R_SIZE_MAX);
bufsize = (bufsize == -1) ? 16384 : bufsize;

std::unique_ptr<char[]> buf(new char[bufsize]);

std::string home_dir;
struct passwd pwd, *result = NULL;
getpwuid_r(geteuid(), &pwd, buf.get(), bufsize, &result);
if (result && result->pw_dir) {
home_dir = result->pw_dir;
home_dir += "/.cache";
}

std::string cache_dir(xdg_cache_home ? xdg_cache_home : home_dir.c_str());
if (cache_dir.size() == 0) {
return "";
}

int r = mkdir(cache_dir.c_str(), 0700);
if ((r < 0) && errno != EEXIST) {
return "";
}

std::string keycache_dir = cache_dir + "/scitokens";
r = mkdir(keycache_dir.c_str(), 0700);
if ((r < 0) && errno != EEXIST) {
return "";
}

std::string keycache_file = keycache_dir + "/scitokens_cpp.sqllite";
// Assume this isn't needed; we'll trigger it via the "real" cache routines.
//initialize_cachedb(keycache_file);

return keycache_file;
}

/**
* Duplicate of remove_issuer_entry from scitokens_cache.cpp; used for direct cache manipulation
*/
void
remove_issuer_entry(sqlite3 *db, const std::string &issuer, bool new_transaction) {

if (new_transaction) sqlite3_exec(db, "BEGIN", 0, 0 , 0);

sqlite3_stmt *stmt;
int rc = sqlite3_prepare_v2(db, "DELETE FROM keycache WHERE issuer = ?", -1, &stmt, NULL);
if (rc != SQLITE_OK) {
sqlite3_close(db);
return;
}

if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return;
}

rc = sqlite3_step(stmt);
if (rc != SQLITE_DONE) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return;
}

sqlite3_finalize(stmt);

if (new_transaction) sqlite3_exec(db, "COMMIT", 0, 0 , 0);
}

/**
* Duplicate of store_public_keys from scitokens_cache.cpp; used for direct cache manipulation.
*/
bool
store_public_keys(const std::string &issuer, const std::string &keys, int64_t next_update, int64_t expires) {

picojson::value json_obj;
auto err = picojson::parse(json_obj, keys);
if (!err.empty() || !json_obj.is<picojson::object>()) {
return false;
}

picojson::object top_obj;
top_obj["jwks"] = json_obj;
top_obj["next_update"] = picojson::value(next_update);
top_obj["expires"] = picojson::value(expires);
picojson::value db_value(top_obj);
std::string db_str = db_value.serialize();

auto cache_fname = get_cache_file();
if (cache_fname.size() == 0) {return false;}

sqlite3 *db;
int rc = sqlite3_open(cache_fname.c_str(), &db);
if (rc) {
sqlite3_close(db);
return false;
}

sqlite3_exec(db, "BEGIN", 0, 0 , 0);

remove_issuer_entry(db, issuer, false);

sqlite3_stmt *stmt;
rc = sqlite3_prepare_v2(db, "INSERT INTO keycache VALUES (?, ?)", -1, &stmt, NULL);
if (rc != SQLITE_OK) {
sqlite3_close(db);
return false;
}

if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}

if (sqlite3_bind_text(stmt, 2, db_str.c_str(), db_str.size(), SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}

rc = sqlite3_step(stmt);
if (rc != SQLITE_DONE) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}

sqlite3_exec(db, "COMMIT", 0, 0 , 0);

sqlite3_finalize(stmt);
sqlite3_close(db);
return true;
}

bool
get_public_keys_from_db(const std::string issuer, int64_t &expires, int64_t &next_update) {
auto cache_fname = get_cache_file();
if (cache_fname.size() == 0) {return false;}

sqlite3 *db;
int rc = sqlite3_open(cache_fname.c_str(), &db);
if (rc) {
sqlite3_close(db);
return false;
}

sqlite3_stmt *stmt;
rc = sqlite3_prepare_v2(db, "SELECT keys from keycache where issuer = ?", -1, &stmt, NULL);
if (rc != SQLITE_OK) {
sqlite3_close(db);
return false;
}

if (sqlite3_bind_text(stmt, 1, issuer.c_str(), issuer.size(), SQLITE_STATIC) != SQLITE_OK) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}

rc = sqlite3_step(stmt);
if (rc == SQLITE_ROW) {
const unsigned char * data = sqlite3_column_text(stmt, 0);
std::string metadata(reinterpret_cast<const char *>(data));
sqlite3_finalize(stmt);
picojson::value json_obj;
auto err = picojson::parse(json_obj, metadata);
if (!err.empty() || !json_obj.is<picojson::object>()) {
sqlite3_close(db);
return false;
}
auto top_obj = json_obj.get<picojson::object>();
auto iter = top_obj.find("jwks");
auto keys_local = iter->second;
iter = top_obj.find("expires");
if (iter == top_obj.end() || !iter->second.is<int64_t>()) {
sqlite3_close(db);
return false;
}
auto expiry = iter->second.get<int64_t>();
sqlite3_close(db);
iter = top_obj.find("next_update");
if (iter == top_obj.end() || !iter->second.is<int64_t>()) {
next_update = expiry - 4*3600;
} else {
next_update = iter->second.get<int64_t>();
}
expires = expiry;
return true;
} else if (rc == SQLITE_DONE) {
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
} else {
// TODO: log error?
sqlite3_finalize(stmt);
sqlite3_close(db);
return false;
}
}

TEST(SciTokenTest, CreateToken) {
SciToken token = scitoken_create(nullptr);
ASSERT_TRUE(token != nullptr);
Expand Down Expand Up @@ -63,6 +285,7 @@ class KeycacheTest : public ::testing::Test
{
protected:
std::string demo_scitokens_url = "https://demo.scitokens.org";
std::string demo_invalid_url = "https://demo.scitokens.org/invalid";

void SetUp() override {
char *err_msg;
Expand All @@ -77,6 +300,76 @@ class KeycacheTest : public ::testing::Test
};


// Emulate the case of an issuer failure. Store a public key that
// is in the need of an update. Make sure, on failure, the next_update
// is 5 minutes ahead of the present.
TEST_F(KeycacheTest, FailureTest) {
time_t now = time(NULL);
const time_t expiry = now + 86400;
// Insert a public key that requires an update on next token verification.
ASSERT_TRUE(store_public_keys(demo_invalid_url, demo_scitokens2, now - 600, expiry));

// Create a new token with an invalid signature.
OpenSSL_add_all_algorithms();
ERR_load_BIO_strings();
ERR_load_crypto_strings();
auto outbio = BIO_new(BIO_s_mem());
ASSERT_TRUE(outbio != nullptr);
auto eccgrp = OBJ_txt2nid("secp256k1");
auto ecc = EC_KEY_new_by_curve_name(eccgrp);
ASSERT_TRUE(1 == EC_KEY_generate_key(ecc));

auto pkey = EVP_PKEY_new();
ASSERT_TRUE(1 == EVP_PKEY_assign_EC_KEY(pkey, ecc));
ASSERT_TRUE(1 == PEM_write_bio_PrivateKey(outbio, pkey, NULL, NULL, 0, 0, NULL));

char *pem_data;
long pem_len = BIO_get_mem_data(outbio, &pem_data);
std::string pem_str(pem_data, pem_len);

// Generate a serialized token from the new key.
auto key = scitoken_key_create("test_key", "ES256", "", pem_str.c_str(), nullptr);
ASSERT_TRUE(key != nullptr);

auto token = scitoken_create(key);
ASSERT_TRUE(token != nullptr);

auto rv = scitoken_set_claim_string(token, "iss", demo_invalid_url.c_str(), nullptr);
ASSERT_TRUE(rv == 0);

rv = scitoken_set_claim_string(token, "sub", "test_user", nullptr);
ASSERT_TRUE(rv == 0);

scitoken_set_lifetime(token, 86400);

char *token_encoded;
rv = scitoken_serialize(token, &token_encoded, nullptr);
ASSERT_TRUE(rv == 0);
std::string token_str(token_encoded);
free(token_encoded);

// Try to deserialize the newly generated token. Should fail as the key doesn't match.
auto token_read = scitoken_create(nullptr);
ASSERT_TRUE(token_read != nullptr);
rv = scitoken_deserialize_v2(token_str.c_str(), token_read, nullptr, nullptr);
ASSERT_FALSE(rv == 0);

// Now, for the real test -- what's the value of expired and next_update?
int64_t new_expiry, new_next_update;
ASSERT_TRUE(get_public_keys_from_db(demo_invalid_url, new_expiry, new_next_update));

EXPECT_EQ(new_expiry, expiry);
EXPECT_GE(new_next_update, now + 300);

// Second test: if the expiration is behind us, fetching the key should trigger
// a deletion of the key cache.
ASSERT_TRUE(store_public_keys(demo_invalid_url, demo_scitokens2, now - 600, now - 600));

rv = scitoken_deserialize_v2(token_str.c_str(), token_read, nullptr, nullptr);

ASSERT_FALSE(get_public_keys_from_db(demo_invalid_url, new_expiry, new_next_update));
}

TEST_F(KeycacheTest, RefreshTest) {
char *err_msg;
auto rv = keycache_refresh_jwks(demo_scitokens_url.c_str(), &err_msg);
Expand Down