From 263c2dae6daca014a1bfd019899b22adcd0e70c5 Mon Sep 17 00:00:00 2001 From: Keeyou Date: Mon, 29 Apr 2024 14:26:42 +0800 Subject: [PATCH] tls: fix broken --certificate_chain_file flag --- src/config/config_tls.cpp | 46 +++++++++++++++++++++++++++++++++++ src/config/config_tls.hpp | 4 +++ src/config/config_version.cpp | 9 +++++++ src/net/content_server.hpp | 12 +++------ src/server/server.cpp | 36 --------------------------- 5 files changed, 62 insertions(+), 45 deletions(-) diff --git a/src/config/config_tls.cpp b/src/config/config_tls.cpp index 8945fd506..9f366b361 100644 --- a/src/config/config_tls.cpp +++ b/src/config/config_tls.cpp @@ -4,6 +4,8 @@ #include "config/config.hpp" #include +#include +#include "core/utils.hpp" std::string g_certificate_chain_content; std::string g_private_key_content; @@ -21,3 +23,47 @@ ABSL_FLAG(std::string, "Tells where to use the specified certificate file to verify the peer. " "You can override it with YASS_CA_BUNDLE environment variable"); ABSL_FLAG(std::string, capath, "", "Tells where to use the specified certificate directory to verify the peer."); + +namespace config { +bool ReadTLSConfigFile() { + do { + static constexpr const size_t kBufferSize = 256 * 1024; + const bool is_server = pType == YASS_SERVER; + + ssize_t ret; + if (is_server) { + std::string private_key, private_key_path = absl::GetFlag(FLAGS_private_key_file); + if (private_key_path.empty()) { + std::cerr << "No private key file for certificate provided" << std::endl; + return false; + } + private_key.resize(kBufferSize); + ret = ReadFileToBuffer(private_key_path, absl::MakeSpan(private_key)); + if (ret <= 0) { + std::cerr << "private key " << private_key_path << " failed to read" << std::endl; + return -1; + } + private_key.resize(ret); + g_private_key_content = private_key; + std::cerr << "Using private key file: " << private_key_path << std::endl; + } + std::string certificate_chain, certificate_chain_path = absl::GetFlag(FLAGS_certificate_chain_file); + if (is_server && certificate_chain_path.empty()) { + std::cerr << "No certificate file provided" << std::endl; + return false; + } + if (!certificate_chain_path.empty()) { + certificate_chain.resize(kBufferSize); + ret = ReadFileToBuffer(certificate_chain_path, absl::MakeSpan(certificate_chain)); + if (ret <= 0) { + std::cerr << "certificate file " << certificate_chain_path << " failed to read" << std::endl; + return false; + } + certificate_chain.resize(ret); + g_certificate_chain_content = certificate_chain; + std::cerr << "Using certificate chain file: " << certificate_chain_path << std::endl; + } + } while (false); + return true; +} +} // namespace config diff --git a/src/config/config_tls.hpp b/src/config/config_tls.hpp index c3cff3ffa..09f5b5a80 100644 --- a/src/config/config_tls.hpp +++ b/src/config/config_tls.hpp @@ -16,4 +16,8 @@ ABSL_DECLARE_FLAG(bool, insecure_mode); ABSL_DECLARE_FLAG(std::string, cacert); ABSL_DECLARE_FLAG(std::string, capath); +namespace config { +bool ReadTLSConfigFile(); +} // namespace config + #endif // H_CONFIG_CONFIG_TLS diff --git a/src/config/config_version.cpp b/src/config/config_version.cpp index e6788f34a..9d68a7657 100644 --- a/src/config/config_version.cpp +++ b/src/config/config_version.cpp @@ -3,6 +3,7 @@ #include "config/config.hpp" #include "config/config_impl.hpp" +#include "config/config_tls.hpp" #include #include @@ -83,6 +84,14 @@ void ReadConfigFileAndArguments(int argc, const char** argv) { absl::ParseCommandLine(argc, const_cast(argv)); } + // raise some early warning on SSL client/server setups + auto method = absl::GetFlag(FLAGS_method).method; + if (CIPHER_METHOD_IS_TLS(method)) { + if (!config::ReadTLSConfigFile()) { + exit(-1); + } + } + // first line of logging LOG(WARNING) << "Application starting: " << YASS_APP_TAG << " type: " << ProgramTypeToStr(pType); LOG(WARNING) << "Last Change: " << YASS_APP_LAST_CHANGE; diff --git a/src/net/content_server.hpp b/src/net/content_server.hpp index d198e1a46..d169c9c51 100644 --- a/src/net/content_server.hpp +++ b/src/net/content_server.hpp @@ -528,16 +528,10 @@ class ContentServer { return; } - std::string certificate_chain_file = absl::GetFlag(FLAGS_certificate_chain_file); - if (!certificate_chain_file.empty()) { - if (SSL_CTX_use_certificate_chain_file(ctx, certificate_chain_file.c_str()) != 1) { - print_openssl_error(); - ec = asio::error::bad_descriptor; - return; - } - - VLOG(1) << "Using upstream certificate file: " << certificate_chain_file; + if (upstream_certificate_.empty()) { + upstream_certificate_ = g_certificate_chain_content; } + const auto& cert = upstream_certificate_; if (!cert.empty()) { if (ec) { diff --git a/src/server/server.cpp b/src/server/server.cpp index 77f021823..0cadc0e45 100644 --- a/src/server/server.cpp +++ b/src/server/server.cpp @@ -73,42 +73,6 @@ int main(int argc, const char* argv[]) { auto work_guard = std::make_unique>(io_context.get_executor()); - // raise some early warning on SSL server setups - auto method = absl::GetFlag(FLAGS_method).method; - if (CIPHER_METHOD_IS_TLS(method)) { - ssize_t ret; - std::string private_key, private_key_path = absl::GetFlag(FLAGS_private_key_file); - if (private_key_path.empty()) { - LOG(WARNING) << "No private key file for certificate provided"; - return -1; - } - static constexpr const size_t kBufferSize = 256 * 1024; - private_key.resize(kBufferSize); - ret = ReadFileToBuffer(private_key_path, absl::MakeSpan(private_key)); - if (ret <= 0) { - LOG(WARNING) << "private key " << private_key_path << " failed to read"; - return -1; - } - private_key.resize(ret); - g_private_key_content = private_key; - VLOG(1) << "Using private key file: " << private_key_path; - - std::string certificate_chain, certificate_chain_path = absl::GetFlag(FLAGS_certificate_chain_file); - if (certificate_chain_path.empty()) { - LOG(WARNING) << "No certificate file provided"; - return -1; - } - certificate_chain.resize(kBufferSize); - ret = ReadFileToBuffer(certificate_chain_path, absl::MakeSpan(certificate_chain)); - if (ret <= 0) { - LOG(WARNING) << "certificate file " << certificate_chain_path << " failed to read"; - return -1; - } - certificate_chain.resize(ret); - g_certificate_chain_content = certificate_chain; - VLOG(1) << "Using certificate chain file: " << certificate_chain_path; - } - std::vector endpoints; std::string host_name = absl::GetFlag(FLAGS_server_host); std::string host_sni = host_name;