Skip to content

Commit

Permalink
fix(transport): not enable s2a when there is endpoint override (#2368)
Browse files Browse the repository at this point in the history
  • Loading branch information
xmenxk authored Jan 24, 2024
1 parent 2d69d97 commit 73fc7fd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 76 deletions.
21 changes: 3 additions & 18 deletions internal/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ func getTransportConfig(settings *DialSettings) (*transportConfig, error) {
return nil, errUniverseNotSupportedMTLS
}

s2aMTLSEndpoint := settings.DefaultMTLSEndpoint
// If there is endpoint override, honor it.
if settings.Endpoint != "" {
s2aMTLSEndpoint = endpoint
}
s2aAddress := GetS2AAddress()
if s2aAddress == "" {
return &defaultTransportConfig, nil
Expand All @@ -124,7 +119,7 @@ func getTransportConfig(settings *DialSettings) (*transportConfig, error) {
clientCertSource: clientCertSource,
endpoint: endpoint,
s2aAddress: s2aAddress,
s2aMTLSEndpoint: s2aMTLSEndpoint,
s2aMTLSEndpoint: settings.DefaultMTLSEndpoint,
}, nil
}

Expand Down Expand Up @@ -293,12 +288,8 @@ func shouldUseS2A(clientCertSource cert.Source, settings *DialSettings) bool {
if !isGoogleS2AEnabled() {
return false
}
// If DefaultMTLSEndpoint is not set and no endpoint override, skip S2A.
if settings.DefaultMTLSEndpoint == "" && settings.Endpoint == "" {
return false
}
// If MTLS is not enabled for this endpoint, skip S2A.
if !mtlsEndpointEnabledForS2A() {
// If DefaultMTLSEndpoint is not set or has endpoint override, skip S2A.
if settings.DefaultMTLSEndpoint == "" || settings.Endpoint != "" {
return false
}
// If custom HTTP client is provided, skip S2A.
Expand All @@ -308,12 +299,6 @@ func shouldUseS2A(clientCertSource cert.Source, settings *DialSettings) bool {
return true
}

// mtlsEndpointEnabledForS2A checks if the endpoint is indeed MTLS-enabled, so that we can use S2A for MTLS connection.
var mtlsEndpointEnabledForS2A = func() bool {
// TODO(xmenxk): determine this via discovery config.
return true
}

func isGoogleS2AEnabled() bool {
return strings.ToLower(os.Getenv(googleAPIUseS2AEnv)) == "true"
}
77 changes: 19 additions & 58 deletions internal/cba_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,67 +141,60 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
Desc string
InputSettings *DialSettings
S2ARespFunc func() (string, error)
MTLSEnabled func() bool
WantEndpoint string
}{
{
"no client cert, endpoint is MTLS enabled, S2A address not empty",
"has client cert",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
ClientCertSource: dummyClientCertSource,
},
validConfigResp,
func() bool { return true },
testMTLSEndpoint,
},
{
"has client cert",
"no client cert, S2A address not empty",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
ClientCertSource: dummyClientCertSource,
},
validConfigResp,
func() bool { return true },
testMTLSEndpoint,
},
{
"no client cert, endpoint is not MTLS enabled",
"no client cert, S2A address empty",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
},
validConfigResp,
func() bool { return false },
invalidConfigResp,
testRegularEndpoint,
},
{
"no client cert, endpoint is MTLS enabled, S2A address empty",
"no client cert, S2A address not empty, override endpoint",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
Endpoint: testOverrideEndpoint,
},
invalidConfigResp,
func() bool { return true },
testRegularEndpoint,
validConfigResp,
testOverrideEndpoint,
},
{
"no client cert, endpoint is MTLS enabled, S2A address not empty, override endpoint",
"no client cert, S2A address not empty, DefaultMTLSEndpoint not set",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultMTLSEndpoint: "",
DefaultEndpoint: testRegularEndpoint,
Endpoint: testOverrideEndpoint,
},
validConfigResp,
func() bool { return true },
testOverrideEndpoint,
testRegularEndpoint,
},
}
defer setupTest()()

for _, tc := range testCases {
httpGetMetadataMTLSConfig = tc.S2ARespFunc
mtlsEndpointEnabledForS2A = tc.MTLSEnabled
if tc.InputSettings.ClientCertSource != nil {
os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
} else {
Expand All @@ -221,21 +214,9 @@ func TestGetHTTPTransportConfigAndEndpoint_s2a(t *testing.T) {
Desc string
InputSettings *DialSettings
S2ARespFunc func() (string, error)
MTLSEnabled func() bool
WantEndpoint string
DialFuncNil bool
}{
{
"no client cert, endpoint is MTLS enabled, S2A address not empty",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
},
validConfigResp,
func() bool { return true },
testMTLSEndpoint,
false,
},
{
"has client cert",
&DialSettings{
Expand All @@ -244,43 +225,39 @@ func TestGetHTTPTransportConfigAndEndpoint_s2a(t *testing.T) {
ClientCertSource: dummyClientCertSource,
},
validConfigResp,
func() bool { return true },
testMTLSEndpoint,
true,
},
{
"no client cert, endpoint is not MTLS enabled",
"no client cert, S2A address not empty",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
},
validConfigResp,
func() bool { return false },
testRegularEndpoint,
true,
testMTLSEndpoint,
false,
},
{
"no client cert, endpoint is MTLS enabled, S2A address empty",
"no client cert, S2A address empty",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
},
invalidConfigResp,
func() bool { return true },
testRegularEndpoint,
true,
},
{
"no client cert, endpoint is MTLS enabled, S2A address not empty, override endpoint",
"no client cert, S2A address not empty, override endpoint",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
Endpoint: testOverrideEndpoint,
},
validConfigResp,
func() bool { return true },
testOverrideEndpoint,
false,
true,
},
{
"no client cert, S2A address not empty, but DefaultMTLSEndpoint is not set",
Expand All @@ -289,30 +266,17 @@ func TestGetHTTPTransportConfigAndEndpoint_s2a(t *testing.T) {
DefaultEndpoint: testRegularEndpoint,
},
validConfigResp,
func() bool { return true },
testRegularEndpoint,
true,
},
{
"no client cert, S2A address not empty, override endpoint is set",
&DialSettings{
DefaultMTLSEndpoint: "",
Endpoint: testOverrideEndpoint,
},
validConfigResp,
func() bool { return true },
testOverrideEndpoint,
false,
},
{
"no client cert, endpoint is MTLS enabled, S2A address not empty, custom HTTP client",
"no client cert, S2A address not empty, custom HTTP client",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
HTTPClient: http.DefaultClient,
},
validConfigResp,
func() bool { return true },
testRegularEndpoint,
true,
},
Expand All @@ -322,7 +286,6 @@ func TestGetHTTPTransportConfigAndEndpoint_s2a(t *testing.T) {

for _, tc := range testCases {
httpGetMetadataMTLSConfig = tc.S2ARespFunc
mtlsEndpointEnabledForS2A = tc.MTLSEnabled
if tc.InputSettings.ClientCertSource != nil {
os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "true")
} else {
Expand All @@ -344,7 +307,6 @@ func TestGetHTTPTransportConfigAndEndpoint_s2a(t *testing.T) {
}

func setupTest() func() {
oldDefaultMTLSEnabled := mtlsEndpointEnabledForS2A
oldHTTPGet := httpGetMetadataMTLSConfig
oldExpiry := configExpiry
oldUseS2A := os.Getenv(googleAPIUseS2AEnv)
Expand All @@ -355,7 +317,6 @@ func setupTest() func() {

return func() {
httpGetMetadataMTLSConfig = oldHTTPGet
mtlsEndpointEnabledForS2A = oldDefaultMTLSEnabled
configExpiry = oldExpiry
os.Setenv(googleAPIUseS2AEnv, oldUseS2A)
os.Setenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", oldUseClientCert)
Expand Down

0 comments on commit 73fc7fd

Please sign in to comment.