Skip to content

Commit

Permalink
tls: fix broken --certificate_chain_file flag
Browse files Browse the repository at this point in the history
  • Loading branch information
Chilledheart committed Apr 29, 2024
1 parent cc641ff commit 263c2da
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 45 deletions.
46 changes: 46 additions & 0 deletions src/config/config_tls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include "config/config.hpp"

#include <absl/flags/flag.h>
#include <iostream>
#include "core/utils.hpp"

std::string g_certificate_chain_content;
std::string g_private_key_content;
Expand All @@ -21,3 +23,47 @@ ABSL_FLAG(std::string,
"Tells where to use the specified certificate file to verify the peer. "
"You can override it with YASS_CA_BUNDLE environment variable");
ABSL_FLAG(std::string, capath, "", "Tells where to use the specified certificate directory to verify the peer.");

namespace config {
bool ReadTLSConfigFile() {
do {
static constexpr const size_t kBufferSize = 256 * 1024;
const bool is_server = pType == YASS_SERVER;

ssize_t ret;
if (is_server) {
std::string private_key, private_key_path = absl::GetFlag(FLAGS_private_key_file);
if (private_key_path.empty()) {
std::cerr << "No private key file for certificate provided" << std::endl;
return false;
}
private_key.resize(kBufferSize);
ret = ReadFileToBuffer(private_key_path, absl::MakeSpan(private_key));
if (ret <= 0) {
std::cerr << "private key " << private_key_path << " failed to read" << std::endl;
return -1;
}
private_key.resize(ret);
g_private_key_content = private_key;
std::cerr << "Using private key file: " << private_key_path << std::endl;
}
std::string certificate_chain, certificate_chain_path = absl::GetFlag(FLAGS_certificate_chain_file);
if (is_server && certificate_chain_path.empty()) {
std::cerr << "No certificate file provided" << std::endl;
return false;
}
if (!certificate_chain_path.empty()) {
certificate_chain.resize(kBufferSize);
ret = ReadFileToBuffer(certificate_chain_path, absl::MakeSpan(certificate_chain));
if (ret <= 0) {
std::cerr << "certificate file " << certificate_chain_path << " failed to read" << std::endl;
return false;
}
certificate_chain.resize(ret);
g_certificate_chain_content = certificate_chain;
std::cerr << "Using certificate chain file: " << certificate_chain_path << std::endl;
}
} while (false);
return true;
}
} // namespace config
4 changes: 4 additions & 0 deletions src/config/config_tls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ ABSL_DECLARE_FLAG(bool, insecure_mode);
ABSL_DECLARE_FLAG(std::string, cacert);
ABSL_DECLARE_FLAG(std::string, capath);

namespace config {
bool ReadTLSConfigFile();
} // namespace config

#endif // H_CONFIG_CONFIG_TLS
9 changes: 9 additions & 0 deletions src/config/config_version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "config/config.hpp"
#include "config/config_impl.hpp"
#include "config/config_tls.hpp"

#include <absl/flags/internal/program_name.h>
#include <absl/flags/parse.h>
Expand Down Expand Up @@ -83,6 +84,14 @@ void ReadConfigFileAndArguments(int argc, const char** argv) {
absl::ParseCommandLine(argc, const_cast<char**>(argv));
}

// raise some early warning on SSL client/server setups
auto method = absl::GetFlag(FLAGS_method).method;
if (CIPHER_METHOD_IS_TLS(method)) {
if (!config::ReadTLSConfigFile()) {
exit(-1);
}
}

// first line of logging
LOG(WARNING) << "Application starting: " << YASS_APP_TAG << " type: " << ProgramTypeToStr(pType);
LOG(WARNING) << "Last Change: " << YASS_APP_LAST_CHANGE;
Expand Down
12 changes: 3 additions & 9 deletions src/net/content_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,16 +528,10 @@ class ContentServer {
return;
}

std::string certificate_chain_file = absl::GetFlag(FLAGS_certificate_chain_file);
if (!certificate_chain_file.empty()) {
if (SSL_CTX_use_certificate_chain_file(ctx, certificate_chain_file.c_str()) != 1) {
print_openssl_error();
ec = asio::error::bad_descriptor;
return;
}

VLOG(1) << "Using upstream certificate file: " << certificate_chain_file;
if (upstream_certificate_.empty()) {
upstream_certificate_ = g_certificate_chain_content;
}

const auto& cert = upstream_certificate_;
if (!cert.empty()) {
if (ec) {
Expand Down
36 changes: 0 additions & 36 deletions src/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,42 +73,6 @@ int main(int argc, const char* argv[]) {
auto work_guard =
std::make_unique<asio::executor_work_guard<asio::io_context::executor_type>>(io_context.get_executor());

// raise some early warning on SSL server setups
auto method = absl::GetFlag(FLAGS_method).method;
if (CIPHER_METHOD_IS_TLS(method)) {
ssize_t ret;
std::string private_key, private_key_path = absl::GetFlag(FLAGS_private_key_file);
if (private_key_path.empty()) {
LOG(WARNING) << "No private key file for certificate provided";
return -1;
}
static constexpr const size_t kBufferSize = 256 * 1024;
private_key.resize(kBufferSize);
ret = ReadFileToBuffer(private_key_path, absl::MakeSpan(private_key));
if (ret <= 0) {
LOG(WARNING) << "private key " << private_key_path << " failed to read";
return -1;
}
private_key.resize(ret);
g_private_key_content = private_key;
VLOG(1) << "Using private key file: " << private_key_path;

std::string certificate_chain, certificate_chain_path = absl::GetFlag(FLAGS_certificate_chain_file);
if (certificate_chain_path.empty()) {
LOG(WARNING) << "No certificate file provided";
return -1;
}
certificate_chain.resize(kBufferSize);
ret = ReadFileToBuffer(certificate_chain_path, absl::MakeSpan(certificate_chain));
if (ret <= 0) {
LOG(WARNING) << "certificate file " << certificate_chain_path << " failed to read";
return -1;
}
certificate_chain.resize(ret);
g_certificate_chain_content = certificate_chain;
VLOG(1) << "Using certificate chain file: " << certificate_chain_path;
}

std::vector<asio::ip::tcp::endpoint> endpoints;
std::string host_name = absl::GetFlag(FLAGS_server_host);
std::string host_sni = host_name;
Expand Down

0 comments on commit 263c2da

Please sign in to comment.