Skip to content

Commit

Permalink
Support client certificate param
Browse files Browse the repository at this point in the history
  • Loading branch information
UgnineSirdis committed Aug 26, 2024
1 parent 48bd4dd commit 7f1443d
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 5 deletions.
1 change: 1 addition & 0 deletions ydb/core/driver_lib/cli_base/cli_cmds_root.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class TClientCommandRootLite : public TClientCommandRootKikimrBase {
throw TMisuseException() << message;
}
ParseCaCerts(config);
ParseClientCert(config);
config.Address = Address;

if (!hostname) {
Expand Down
1 change: 1 addition & 0 deletions ydb/core/driver_lib/cli_utils/cli_cmds_root.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class TClientCommandRoot : public TClientCommandRootKikimrBase {
config.EnableSsl = endpoint.EnableSsl.GetRef();
}
ParseCaCerts(config);
ParseClientCert(config);

CommandConfig.ClientConfig = NYdbGrpc::TGRpcClientConfig(endpoint.Address);
if (config.EnableSsl) {
Expand Down
1 change: 1 addition & 0 deletions ydb/public/lib/ydb_cli/commands/ydb_command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ TDriverConfig TYdbCommand::CreateDriverConfig(const TConfig& config) {
driverConfig.UseSecureConnection(config.CaCerts);
if (config.IsNetworkIntensive)
driverConfig.SetNetworkThreadsNum(16);
driverConfig.UseClientCertificate(config.ClientCert, config.ClientCertPrivateKey);

return driverConfig;
}
Expand Down
88 changes: 84 additions & 4 deletions ydb/public/lib/ydb_cli/commands/ydb_root_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,18 +340,19 @@ void TClientCommandRootCommon::Parse(TConfig& config) {
TClientCommandRootBase::Parse(config);
ParseDatabase(config);
ParseCaCerts(config);
ParseClientCert(config);
ParseIamEndpoint(config);

config.VerbosityLevel = std::min(static_cast<TConfig::EVerbosityLevel>(VerbosityLevel), TConfig::EVerbosityLevel::DEBUG);
}

namespace {
inline void PrintSettingFromProfile(const TString& setting, std::shared_ptr<IProfile> profile, bool explicitOption) {
inline void PrintSettingFromProfile(const TString& setting, const std::shared_ptr<IProfile>& profile, bool explicitOption) {
Cerr << "Using " << setting << " due to configuration in" << (explicitOption ? "" : " active") << " profile \""
<< profile->GetName() << "\"" << (explicitOption ? " from explicit --profile option" : "") << Endl;
}

inline TString GetProfileSource(std::shared_ptr<IProfile> profile, bool explicitOption) {
inline TString GetProfileSource(const std::shared_ptr<IProfile>& profile, bool explicitOption) {
Y_ABORT_UNLESS(profile, "No profile to get source");
if (explicitOption) {
return TStringBuilder() << "profile \"" << profile->GetName() << "\" from explicit --profile option";
Expand All @@ -360,14 +361,41 @@ namespace {
}
}

bool TClientCommandRootCommon::TryGetParamFromProfile(const TString& name, std::shared_ptr<IProfile> profile, bool explicitOption,
bool TClientCommandRootCommon::TryGetParamFromProfile(const TString& name, const std::shared_ptr<IProfile>& profile, bool explicitOption,
std::function<bool(const TString&, const TString&, bool)> callback) {
if (profile && profile->Has(name)) {
return callback(profile->GetValue(name).as<TString>(), GetProfileSource(profile, explicitOption), explicitOption);
}
return false;
}

bool TClientCommandRootCommon::TryGetParamsPackFromProfile(const std::shared_ptr<IProfile>& profile, bool explicitOption,
std::function<bool(const TString& /*source*/, bool /*explicit*/, const std::vector<TString>& /*values*/)> callback,
const std::initializer_list<TString>& names) {
if (!profile) {
return false;
}
bool hasAtLeastOne = false;
bool doesNotHaveAtLeastOne = false;
for (const TString& name : names) {
hasAtLeastOne |= profile->Has(name);
doesNotHaveAtLeastOne |= !profile->Has(name);
}
if (hasAtLeastOne && doesNotHaveAtLeastOne) {
throw TMisuseException()
<< "Either all or none of the following options must be set in one profile: " << JoinSeq(", ", names);
}
if (hasAtLeastOne) {
std::vector<TString> values;
values.reserve(names.size());
for (const TString& name : names) {
values.emplace_back(profile->GetValue(name).as<TString>());
}
return callback(GetProfileSource(profile, explicitOption), explicitOption, values);
}
return false;
}

void TClientCommandRootCommon::ParseCaCerts(TConfig& config) {
auto getCaFile = [this, &config] (const TString& param, const TString& sourceText, bool explicitOption) {
if (!IsCaCertsFileSet && (explicitOption || !Profile)) {
Expand Down Expand Up @@ -395,16 +423,68 @@ void TClientCommandRootCommon::ParseCaCerts(TConfig& config) {
}
}

void TClientCommandRootCommon::ParseClientCert(TConfig& config) {
auto getClientCertFiles = [this, &config] (const TString& sourceText, bool explicitOption, const std::vector<TString>& values) {
Y_ABORT_UNLESS(values.size() == 2);
const TString& clientCertFileParam = values[0];
const TString& clientCertPrivateKeyFileParam = values[1];
if (!IsClientCertFileSet && (explicitOption || !Profile)) {
config.ClientCertFile = clientCertFileParam;
config.ClientCertPrivateKeyFile = clientCertPrivateKeyFileParam;
IsClientCertFileSet = true;
GetClientCert(config);
}
if (!IsVerbose()) {
return true;
}
Cerr << "Using client certificate from file: " << clientCertFileParam << Endl;
config.ConnectionParams["client-cert-file"].push_back({clientCertFileParam, sourceText});
config.ConnectionParams["client-cert-key-file"].push_back({clientCertPrivateKeyFileParam, sourceText});
return false;
};
// Priority 1. Explicit --client-cert-file/--client-cert-key-file options
if (!ClientCertFile.empty() || !ClientCertPrivateKeyFile.empty()) {
if (ClientCertFile.empty() || ClientCertPrivateKeyFile.empty()) { // One option is set, another is not set
throw TMisuseException()
<< "Both \"client-cert-file\" and \"client-cert-key-file\" options must be provided.";
}
if (ClientCertFile && getClientCertFiles("explicit --client-cert-file/--client-cert-key-file option", true, { ClientCertFile, ClientCertPrivateKeyFile })) {
return;
}
}
// Priority 2. Explicit --profile option
if (TryGetParamsPackFromProfile(Profile, true, getClientCertFiles, { "client-cert-file", "client-cert-key-file" })) {
return;
}
// Priority 3. Active profile (if --profile option is not specified)
if (TryGetParamsPackFromProfile(ProfileManager->GetActiveProfile(), false, getClientCertFiles, { "client-cert-file", "client-cert-key-file" })) {
return;
}
}

void TClientCommandRootCommon::GetCaCerts(TConfig& config) {
if (!config.EnableSsl && !config.CaCertsFile.empty()) {
throw TMisuseException()
<< "\"ca-file\" option provided for a non-ssl connection. Use grpcs:// prefix for host to connect using SSL.";
<< "\"ca-file\" option is provided for a non-ssl connection. Use grpcs:// prefix for host to connect using SSL.";
}
if (!config.CaCertsFile.empty()) {
config.CaCerts = ReadFromFile(config.CaCertsFile, "CA certificates");
}
}

void TClientCommandRootCommon::GetClientCert(TConfig& config) {
if (!config.EnableSsl && (!config.ClientCertFile.empty() || !config.ClientCertPrivateKeyFile.empty())) {
throw TMisuseException()
<< "\"client-cert-file\"/\"client-cert-key-file\" options are provided for a non-ssl connection. Use grpcs:// prefix for host to connect using SSL.";
}
if (!config.ClientCertFile.empty()) {
config.ClientCert = ReadFromFile(config.ClientCertFile, "Client certificate");
}
if (!config.ClientCertPrivateKeyFile.empty()) {
config.ClientCertPrivateKey = ReadFromFile(config.ClientCertPrivateKeyFile, "Client certificate private key");
}
}

void TClientCommandRootCommon::ParseAddress(TConfig& config) {
auto getAddress = [this, &config] (const TString& param, const TString& sourceText, bool explicitOption) {
TString address;
Expand Down
11 changes: 10 additions & 1 deletion ydb/public/lib/ydb_cli/commands/ydb_root_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,20 @@ class TClientCommandRootCommon : public TClientCommandRootBase {
void ParseDatabase(TConfig& config);
void ParseIamEndpoint(TConfig& config);
void ParseCaCerts(TConfig& config) override;
void ParseClientCert(TConfig& config) override;
void GetAddressFromString(TConfig& config, TString* result = nullptr);
bool ParseProtocolNoConfig(TString& message);
void GetCaCerts(TConfig& config);
bool TryGetParamFromProfile(const TString& name, std::shared_ptr<IProfile> profile, bool explicitOption,
void GetClientCert(TConfig& config);
bool TryGetParamFromProfile(const TString& name, const std::shared_ptr<IProfile>& profile, bool explicitOption,
std::function<bool(const TString&, const TString&, bool)> callback);

// Gets more than one params from one profile source.
// Checks that if at least one param of pack is set, then all must be set (to avoid ambiguity of params source).
bool TryGetParamsPackFromProfile(const std::shared_ptr<IProfile>& profile, bool explicitOption,
std::function<bool(const TString& /*source*/, bool /*explicit*/, const std::vector<TString>& /*values*/)> callback,
const std::initializer_list<TString>& names);

TString Database;

ui32 VerbosityLevel = 0;
Expand Down Expand Up @@ -84,6 +92,7 @@ class TClientCommandRootCommon : public TClientCommandRootBase {
bool IsDatabaseSet = false;
bool IsIamEndpointSet = false;
bool IsCaCertsFileSet = false;
bool IsClientCertFileSet = false;
bool IsAuthSet = false;
};

Expand Down
4 changes: 4 additions & 0 deletions ydb/public/lib/ydb_cli/common/command.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ class TClientCommand {
TString Database;
TString CaCerts;
TString CaCertsFile;
TString ClientCert;
TString ClientCertPrivateKey;
TString ClientCertFile;
TString ClientCertPrivateKeyFile;
TMap<TString, TVector<TConnectionParam>> ConnectionParams;
bool EnableSsl = false;
bool IsNetworkIntensive = false;
Expand Down
22 changes: 22 additions & 0 deletions ydb/public/lib/ydb_cli/common/root.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ void TClientCommandRootBase::Config(TConfig& config) {
"Path to a file containing the PEM encoding of the server root certificates for tls connections.\n"
"If this parameter is empty, the default roots will be used.")
.RequiredArgument("PATH").StoreResult(&CaCertsFile);
opts.AddLongOption("client-cert-file",
"Path to a file containing the PEM encoding of the client certificate for tls connections")
.RequiredArgument("PATH").StoreResult(&ClientCertFile);
opts.AddLongOption("client-cert-key-file",
"Path to a file containing the PEM encoding of the client certificate private key for tls connections")
.RequiredArgument("PATH").StoreResult(&ClientCertPrivateKeyFile);

opts.SetCustomUsage(config.ArgV[0]);
config.SetFreeArgsMin(1);
Expand Down Expand Up @@ -94,6 +100,22 @@ void TClientCommandRootBase::ParseCaCerts(TConfig& config) {
config.CaCerts = ReadFromFile(CaCertsFile, "CA certificates");
}

void TClientCommandRootBase::ParseClientCert(TConfig& config) {
if (ClientCertFile.empty() && ClientCertPrivateKeyFile.empty()) {
return;
}
if (ClientCertFile.empty() || ClientCertPrivateKeyFile.empty()) { // One option is set, another is not set
throw TMisuseException()
<< "Both \"client-cert-file\" and \"client-cert-key-file\" options must be provided.";
}
if (!config.EnableSsl) {
throw TMisuseException()
<< "\"client-cert-file\" option provided for a non-ssl connection. Use grpcs:// prefix for host to connect using SSL.";
}
config.ClientCert = ReadFromFile(ClientCertFile, "Client certificate");
config.ClientCertPrivateKey = ReadFromFile(ClientCertPrivateKeyFile, "Client certificate private key");
}

void TClientCommandRootBase::ParseCredentials(TConfig& config) {
ParseToken(Token, TokenFile, "YDB_TOKEN", true);
if (!Token.empty()) {
Expand Down
3 changes: 3 additions & 0 deletions ydb/public/lib/ydb_cli/common/root.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class TClientCommandRootBase : public TClientCommandTree {
TString Token;
TString TokenFile;
TString CaCertsFile;
TString ClientCertFile;
TString ClientCertPrivateKeyFile;

virtual void Config(TConfig& config) override;
virtual void Parse(TConfig& config) override;
Expand All @@ -27,6 +29,7 @@ class TClientCommandRootBase : public TClientCommandTree {
void ParseToken(TString& token, TString& tokenFile, const TString& envName, bool useDefaultToken = false);
bool ParseProtocol(TConfig& config, TString& message);
virtual void ParseCaCerts(TConfig& config);
virtual void ParseClientCert(TConfig& config);
virtual void ParseCredentials(TConfig& config);
virtual void ParseAddress(TConfig& config) = 0;
};
Expand Down

0 comments on commit 7f1443d

Please sign in to comment.