diff --git a/dbms/src/Common/TiFlashSecurity.h b/dbms/src/Common/TiFlashSecurity.h index a02afc9d04e..67daaa40590 100644 --- a/dbms/src/Common/TiFlashSecurity.h +++ b/dbms/src/Common/TiFlashSecurity.h @@ -39,8 +39,6 @@ extern const int INVALID_CONFIG_PARAMETER; class TiFlashSecurityConfig : public ConfigObject { public: - TiFlashSecurityConfig() = default; - explicit TiFlashSecurityConfig(const LoggerPtr & log_) : log(log_) {} @@ -54,12 +52,6 @@ class TiFlashSecurityConfig : public ConfigObject } } - void setLog(const LoggerPtr & log_) - { - std::unique_lock lock(mu); - log = log_; - } - bool hasTlsConfig() { std::unique_lock lock(mu); @@ -94,31 +86,7 @@ class TiFlashSecurityConfig : public ConfigObject bool update(Poco::Util::AbstractConfiguration & config) { std::unique_lock lock(mu); - if (config.has("security")) - { - if (inited && !has_security) - { - LOG_WARNING(log, "Can't add security config online"); - return false; - } - has_security = true; - - bool cert_file_updated = updateCertPath(config); - - if (config.has("security.cert_allowed_cn") && has_tls_config) - { - String verify_cns = config.getString("security.cert_allowed_cn"); - parseAllowedCN(verify_cns); - } - - // Mostly options name are combined with "_", keep this style - if (config.has("security.redact_info_log")) - { - redact_info_log = config.getBool("security.redact_info_log"); - } - return cert_file_updated; - } - else + if (!config.has("security")) { if (inited && has_security) { @@ -128,16 +96,41 @@ class TiFlashSecurityConfig : public ConfigObject { LOG_INFO(log, "security config is not set"); } + return false; } - return false; + + assert(config.has("security")); + if (inited && !has_security) + { + LOG_WARNING(log, "Can't add security config online"); + return false; + } + has_security = true; + + bool cert_file_updated = updateCertPath(config); + + if (config.has("security.cert_allowed_cn") && has_tls_config) + { + String verify_cns = config.getString("security.cert_allowed_cn"); + allowed_common_names = parseAllowedCN(verify_cns); + } + + // Mostly options name are combined with "_", keep this style + if (config.has("security.redact_info_log")) + { + redact_info_log = config.getBool("security.redact_info_log"); + } + return cert_file_updated; } - void parseAllowedCN(String verify_cns) + static std::set parseAllowedCN(String verify_cns) { if (verify_cns.size() > 2 && verify_cns[0] == '[' && verify_cns[verify_cns.size() - 1] == ']') { verify_cns = verify_cns.substr(1, verify_cns.size() - 2); } + + std::set common_names; Poco::StringTokenizer string_tokens(verify_cns, ","); for (const auto & string_token : string_tokens) { @@ -146,8 +139,9 @@ class TiFlashSecurityConfig : public ConfigObject { cn = cn.substr(1, cn.size() - 2); } - allowed_common_names.insert(std::move(cn)); + common_names.insert(std::move(cn)); } + return common_names; } bool checkGrpcContext(const grpc::ServerContext * grpc_context) const @@ -236,18 +230,18 @@ class TiFlashSecurityConfig : public ConfigObject bool updated = false; if (config.has("security.ca_path")) { - new_ca_path = config.getString("security.ca_path"); - miss_ca_path = false; + new_ca_path = Poco::trim(config.getString("security.ca_path")); + miss_ca_path = new_ca_path.empty(); } if (config.has("security.cert_path")) { - new_cert_path = config.getString("security.cert_path"); - miss_cert_path = false; + new_cert_path = Poco::trim(config.getString("security.cert_path")); + miss_cert_path = new_cert_path.empty(); } if (config.has("security.key_path")) { - new_key_path = config.getString("security.key_path"); - miss_key_path = false; + new_key_path = Poco::trim(config.getString("security.key_path")); + miss_key_path = new_key_path.empty(); } if (miss_ca_path && miss_cert_path && miss_key_path) { @@ -322,8 +316,10 @@ class TiFlashSecurityConfig : public ConfigObject String key_path; FilesChangesTracker cert_files; - bool redact_info_log = false; std::set allowed_common_names; + + bool redact_info_log = false; + bool has_tls_config = false; bool has_security = false; bool inited = false; diff --git a/dbms/src/Common/tests/gtest_tiflash_security.cpp b/dbms/src/Common/tests/gtest_tiflash_security.cpp index 7319744ffe1..194d78b1d70 100644 --- a/dbms/src/Common/tests/gtest_tiflash_security.cpp +++ b/dbms/src/Common/tests/gtest_tiflash_security.cpp @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include #include +#include #include #include @@ -30,54 +32,134 @@ class TiFlashSecurityTest : public ext::Singleton TEST(TiFlashSecurityTest, Config) { - TiFlashSecurityConfig tiflash_config; - const auto log = Logger::get(); - tiflash_config.setLog(log); - - tiflash_config.parseAllowedCN(String("[abc,efg]")); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("abc"), 1); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("efg"), 1); - - tiflash_config.allowedCommonNames().clear(); - - tiflash_config.parseAllowedCN(String(R"(["abc","efg"])")); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("abc"), 1); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("efg"), 1); + { + auto cns = TiFlashSecurityConfig::parseAllowedCN(String("[abc,efg]")); + ASSERT_EQ(cns.count("abc"), 1); + ASSERT_EQ(cns.count("efg"), 1); + } - tiflash_config.allowedCommonNames().clear(); + { + auto cns = TiFlashSecurityConfig::parseAllowedCN(String(R"(["abc","efg"])")); + ASSERT_EQ(cns.count("abc"), 1); + ASSERT_EQ(cns.count("efg"), 1); + } - tiflash_config.parseAllowedCN(String("[ abc , efg ]")); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("abc"), 1); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("efg"), 1); + { + auto cns = TiFlashSecurityConfig::parseAllowedCN(String("[ abc , efg ]")); + ASSERT_EQ(cns.count("abc"), 1); + ASSERT_EQ(cns.count("efg"), 1); + } - tiflash_config.allowedCommonNames().clear(); + { + auto cns = TiFlashSecurityConfig::parseAllowedCN(String(R"([ "abc", "efg" ])")); + ASSERT_EQ(cns.count("abc"), 1); + ASSERT_EQ(cns.count("efg"), 1); + } - tiflash_config.parseAllowedCN(String(R"([ "abc", "efg" ])")); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("abc"), 1); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("efg"), 1); + const auto log = Logger::get(); - String test = - R"( + { + auto new_config = loadConfigFromString(R"( [security] ca_path="security/ca.pem" cert_path="security/cert.pem" key_path="security/key.pem" cert_allowed_cn="tidb" - )"; - auto new_config = loadConfigFromString(test); - tiflash_config.update(*new_config); - ASSERT_EQ((int)tiflash_config.allowedCommonNames().count("tidb"), 1); + )"); + TiFlashSecurityConfig tiflash_config(log); + tiflash_config.init(*new_config); + ASSERT_TRUE(tiflash_config.hasTlsConfig()); + ASSERT_EQ(tiflash_config.allowedCommonNames().count("tidb"), 1); + } - test = - R"( + { + auto new_config = loadConfigFromString(R"( [security] cert_allowed_cn="tidb" - )"; - new_config = loadConfigFromString(test); - auto new_tiflash_config = TiFlashSecurityConfig(log); - new_tiflash_config.init(*new_config); - ASSERT_EQ((int)new_tiflash_config.allowedCommonNames().count("tidb"), 0); + )"); + auto new_tiflash_config = TiFlashSecurityConfig(log); + new_tiflash_config.init(*new_config); + ASSERT_FALSE(new_tiflash_config.hasTlsConfig()); + // allowed common names is ignored when tls is not enabled + ASSERT_EQ(new_tiflash_config.allowedCommonNames().count("tidb"), 0); + } +} + +TEST(TiFlashSecurityTest, EmptyConfig) +try +{ + const auto log = Logger::get(); + + for (const auto & c : Strings{ + // empty strings + R"([security] +ca_path="" +cert_path="" +key_path="")", + // non-empty strings with space only + R"([security] +ca_path=" " +cert_path="" +key_path="")", + }) + { + SCOPED_TRACE(fmt::format("case: {}", c)); + TiFlashSecurityConfig tiflash_config(log); + auto new_config = loadConfigFromString(c); + tiflash_config.init(*new_config); + ASSERT_FALSE(tiflash_config.hasTlsConfig()); + } +} +CATCH + +TEST(TiFlashSecurityTest, InvalidConfig) +try +{ + const auto log = Logger::get(); + + for (const auto & c : Strings{ + // only a part of ssl path is set + R"([security] +ca_path="security/ca.pem" +cert_path="" +key_path="")", + R"([security] +ca_path="" +cert_path="security/cert.pem" +key_path="")", + R"([security] +ca_path="" +cert_path="" +key_path="security/key.pem")", + R"([security] +ca_path="" +cert_path="security/cert.pem" +key_path="security/key.pem")", + // comment out + R"([security] +ca_path="security/ca.pem" +#cert_path="security/cert.pem" +key_path="security/key.pem")", + }) + { + SCOPED_TRACE(fmt::format("case: {}", c)); + TiFlashSecurityConfig tiflash_config(log); + auto new_config = loadConfigFromString(c); + try + { + tiflash_config.init(*new_config); + ASSERT_FALSE(true) << "should raise exception"; + } + catch (Exception & e) + { + // has_tls remains false when an exception raise + ASSERT_FALSE(tiflash_config.hasTlsConfig()); + // the error code must be INVALID_CONFIG_PARAMETER + ASSERT_EQ(e.code(), ErrorCodes::INVALID_CONFIG_PARAMETER); + } + } } +CATCH TEST(TiFlashSecurityTest, Update) { diff --git a/dbms/src/Server/Server.cpp b/dbms/src/Server/Server.cpp index 3da05ad9ba3..0799f00984d 100644 --- a/dbms/src/Server/Server.cpp +++ b/dbms/src/Server/Server.cpp @@ -1393,9 +1393,7 @@ int Server::main(const std::vector & /*args*/) } { // update TiFlashSecurity and related config in client for ssl certificate reload. - bool updated - = global_context->getSecurityConfig()->update(*config); // Whether the cert path or file is updated. - if (updated) + if (bool updated = global_context->getSecurityConfig()->update(*config); updated) { auto raft_config = TiFlashRaftConfig::parseSettings(*config, log); auto cluster_config