Skip to content

Commit

Permalink
fix(internal): support internaloption.WithDefaultUniverseDomain (#2373)
Browse files Browse the repository at this point in the history
  • Loading branch information
quartzmo authored Jan 25, 2024
1 parent ddb3a12 commit b21a1fa
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 43 deletions.
26 changes: 20 additions & 6 deletions internal/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func getClientCertificateSourceAndEndpoint(settings *DialSettings) (cert.Source,
// if settings.DefaultEndpointTemplate == "" {
// return nil, "", errors.New("internaloption.WithDefaultEndpointTemplate is required if option.WithUniverseDomain is not googleapis.com")
// }
endpoint = strings.Replace(settings.DefaultEndpointTemplate, universeDomainPlaceholder, settings.GetUniverseDomain(), 1)
endpoint = resolvedDefaultEndpoint(settings)
}
return clientCertSource, endpoint, nil
}
Expand Down Expand Up @@ -164,27 +164,41 @@ func isClientCertificateEnabled() bool {
// WithDefaultEndpoint("https://foo.com/bar/baz") will return "https://myhost:8080/bar/baz"
func getEndpoint(settings *DialSettings, clientCertSource cert.Source) (string, error) {
if settings.Endpoint == "" {
mtlsMode := getMTLSMode()
if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
if isMTLS(clientCertSource) {
if !settings.IsUniverseDomainGDU() {
return "", errUniverseNotSupportedMTLS
}
return settings.DefaultMTLSEndpoint, nil
}
return settings.DefaultEndpoint, nil
return resolvedDefaultEndpoint(settings), nil
}
if strings.Contains(settings.Endpoint, "://") {
// User passed in a full URL path, use it verbatim.
return settings.Endpoint, nil
}
if settings.DefaultEndpoint == "" {
if resolvedDefaultEndpoint(settings) == "" {
// If DefaultEndpoint is not configured, use the user provided endpoint verbatim.
// This allows a naked "host[:port]" URL to be used with GRPC Direct Path.
return settings.Endpoint, nil
}

// Assume user-provided endpoint is host[:port], merge it with the default endpoint.
return mergeEndpoints(settings.DefaultEndpoint, settings.Endpoint)
return mergeEndpoints(resolvedDefaultEndpoint(settings), settings.Endpoint)
}

func isMTLS(clientCertSource cert.Source) bool {
mtlsMode := getMTLSMode()
return mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto)
}

// resolvedDefaultEndpoint returns the DefaultEndpointTemplate merged with the
// Universe Domain if the DefaultEndpointTemplate is set, otherwise returns the
// deprecated DefaultEndpoint value.
func resolvedDefaultEndpoint(settings *DialSettings) string {
if settings.DefaultEndpointTemplate == "" {
return settings.DefaultEndpoint
}
return strings.Replace(settings.DefaultEndpointTemplate, universeDomainPlaceholder, settings.GetUniverseDomain(), 1)
}

func getMTLSMode() string {
Expand Down
64 changes: 41 additions & 23 deletions internal/cba_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,25 @@ var dummyClientCertSource = func(info *tls.CertificateRequestInfo) (*tls.Certifi

func TestGetEndpoint(t *testing.T) {
testCases := []struct {
UserEndpoint string
DefaultEndpoint string
Want string
WantErr bool
UserEndpoint string
DefaultEndpoint string
DefaultEndpointTemplate string
Want string
WantErr bool
}{
{
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
Want: "https://foo.googleapis.com/bar/baz",
DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
Want: "https://foo.googleapis.com/bar/baz",
},
{
UserEndpoint: "myhost:3999",
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
Want: "https://myhost:3999/bar/baz",
UserEndpoint: "myhost:3999",
DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
Want: "https://myhost:3999/bar/baz",
},
{
UserEndpoint: "https://host/path/to/bar",
DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
Want: "https://host/path/to/bar",
UserEndpoint: "https://host/path/to/bar",
DefaultEndpointTemplate: "https://foo.UNIVERSE_DOMAIN/bar/baz",
Want: "https://host/path/to/bar",
},
{
UserEndpoint: "host:123",
Expand All @@ -63,8 +64,10 @@ func TestGetEndpoint(t *testing.T) {

for _, tc := range testCases {
got, err := getEndpoint(&DialSettings{
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
DefaultEndpointTemplate: tc.DefaultEndpointTemplate,
DefaultUniverseDomain: "googleapis.com",
}, nil)
if tc.WantErr && err == nil {
t.Errorf("want err, got nil err")
Expand All @@ -75,7 +78,7 @@ func TestGetEndpoint(t *testing.T) {
continue
}
if tc.Want != got {
t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpoint, got, tc.Want)
t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpointTemplate, got, tc.Want)
}
}
}
Expand Down Expand Up @@ -118,9 +121,10 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {

for _, tc := range testCases {
got, err := getEndpoint(&DialSettings{
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
Endpoint: tc.UserEndpoint,
DefaultEndpoint: tc.DefaultEndpoint,
DefaultMTLSEndpoint: tc.DefaultMTLSEndpoint,
DefaultUniverseDomain: "googleapis.com",
}, dummyClientCertSource)
if tc.WantErr && err == nil {
t.Errorf("want err, got nil err")
Expand Down Expand Up @@ -174,18 +178,20 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
{
"no client cert, S2A address not empty, override endpoint",
&DialSettings{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpoint: testRegularEndpoint,
Endpoint: testOverrideEndpoint,
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
Endpoint: testOverrideEndpoint,
DefaultUniverseDomain: "googleapis.com",
},
validConfigResp,
testOverrideEndpoint,
},
{
"no client cert, S2A address not empty, DefaultMTLSEndpoint not set",
&DialSettings{
DefaultMTLSEndpoint: "",
DefaultEndpoint: testRegularEndpoint,
DefaultMTLSEndpoint: "",
DefaultEndpointTemplate: testEndpointTemplate,
DefaultUniverseDomain: "googleapis.com",
},
validConfigResp,
testRegularEndpoint,
Expand Down Expand Up @@ -336,6 +342,7 @@ func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultEndpoint: testRegularEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testRegularEndpoint,
},
Expand All @@ -346,6 +353,7 @@ func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
DefaultMTLSEndpoint: testMTLSEndpoint,
ClientCertSource: dummyClientCertSource,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testMTLSEndpoint,
},
Expand All @@ -356,6 +364,7 @@ func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
DefaultMTLSEndpoint: testMTLSEndpoint,
UniverseDomain: testUniverseDomain,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testUniverseDomainEndpoint,
},
Expand All @@ -367,6 +376,7 @@ func TestGetHTTPTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultMTLSEndpoint: testMTLSEndpoint,
UniverseDomain: testUniverseDomain,
ClientCertSource: dummyClientCertSource,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testUniverseDomainEndpoint,
wantErr: errUniverseNotSupportedMTLS,
Expand Down Expand Up @@ -405,6 +415,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultEndpoint: testRegularEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testRegularEndpoint,
},
Expand All @@ -415,6 +426,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
DefaultMTLSEndpoint: testMTLSEndpoint,
Endpoint: testOverrideEndpoint,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testOverrideEndpoint,
},
Expand All @@ -425,6 +437,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
DefaultMTLSEndpoint: testMTLSEndpoint,
ClientCertSource: dummyClientCertSource,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testMTLSEndpoint,
},
Expand All @@ -436,6 +449,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultMTLSEndpoint: testMTLSEndpoint,
ClientCertSource: dummyClientCertSource,
Endpoint: testOverrideEndpoint,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testOverrideEndpoint,
},
Expand All @@ -446,6 +460,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
DefaultMTLSEndpoint: testMTLSEndpoint,
UniverseDomain: testUniverseDomain,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testUniverseDomainEndpoint,
},
Expand All @@ -457,6 +472,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultMTLSEndpoint: testMTLSEndpoint,
UniverseDomain: testUniverseDomain,
Endpoint: testOverrideEndpoint,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testOverrideEndpoint,
},
Expand All @@ -468,6 +484,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
DefaultMTLSEndpoint: testMTLSEndpoint,
UniverseDomain: testUniverseDomain,
ClientCertSource: dummyClientCertSource,
DefaultUniverseDomain: "googleapis.com",
},
wantErr: errUniverseNotSupportedMTLS,
},
Expand All @@ -480,6 +497,7 @@ func TestGetGRPCTransportConfigAndEndpoint_UniverseDomain(t *testing.T) {
UniverseDomain: testUniverseDomain,
ClientCertSource: dummyClientCertSource,
Endpoint: testOverrideEndpoint,
DefaultUniverseDomain: "googleapis.com",
},
wantEndpoint: testOverrideEndpoint,
},
Expand Down
21 changes: 10 additions & 11 deletions internal/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"

"golang.org/x/oauth2"
"google.golang.org/api/internal/cert"
"google.golang.org/api/internal/impersonate"

"golang.org/x/oauth2/google"
Expand Down Expand Up @@ -90,11 +91,11 @@ func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*g

// Determine configurations for the OAuth2 transport, which is separate from the API transport.
// The OAuth2 transport and endpoint will be configured for mTLS if applicable.
clientCertSource, oauth2Endpoint, err := getClientCertificateSourceAndEndpoint(oauth2DialSettings(ds))
clientCertSource, err := getClientCertificateSource(ds)
if err != nil {
return nil, err
}
params.TokenURL = oauth2Endpoint
params.TokenURL = oAuth2Endpoint(clientCertSource)
if clientCertSource != nil {
tlsConfig := &tls.Config{
GetClientCertificate: clientCertSource,
Expand Down Expand Up @@ -124,6 +125,13 @@ func credentialsFromJSON(ctx context.Context, data []byte, ds *DialSettings) (*g
return cred, err
}

func oAuth2Endpoint(clientCertSource cert.Source) string {
if isMTLS(clientCertSource) {
return google.MTLSTokenURL
}
return google.Endpoint.TokenURL
}

func isSelfSignedJWTFlow(data []byte, ds *DialSettings) (bool, error) {
// For non-GDU universe domains, token exchange is impossible and services
// must support self-signed JWTs with scopes.
Expand Down Expand Up @@ -196,15 +204,6 @@ func impersonateCredentials(ctx context.Context, creds *google.Credentials, ds *
}, nil
}

// oauth2DialSettings returns the settings to be used by the OAuth2 transport, which is separate from the API transport.
func oauth2DialSettings(ds *DialSettings) *DialSettings {
var ods DialSettings
ods.DefaultEndpoint = google.Endpoint.TokenURL
ods.DefaultMTLSEndpoint = google.MTLSTokenURL
ods.ClientCertSource = ds.ClientCertSource
return &ods
}

// customHTTPClient constructs an HTTPClient using the provided tlsConfig, to support mTLS.
func customHTTPClient(tlsConfig *tls.Config) *http.Client {
trans := baseTransport()
Expand Down
18 changes: 15 additions & 3 deletions internal/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,27 @@ func (ds *DialSettings) Validate() error {
return nil
}

// UniverseDomain returns the default service domain for a given Cloud universe.
// GetDefaultUniverseDomain returns the default service domain for a given Cloud
// universe, as configured with internaloption.WithDefaultUniverseDomain.
// The default value is "googleapis.com".
func (ds *DialSettings) GetDefaultUniverseDomain() string {
if ds.DefaultUniverseDomain == "" {
return universeDomainDefault
}
return ds.DefaultUniverseDomain
}

// GetUniverseDomain returns the default service domain for a given Cloud
// universe, as configured with option.WithUniverseDomain.
// The default value is the value of GetDefaultUniverseDomain, as configured
// with internaloption.WithDefaultUniverseDomain.
func (ds *DialSettings) GetUniverseDomain() string {
if ds.UniverseDomain == "" {
return universeDomainDefault
return ds.GetDefaultUniverseDomain()
}
return ds.UniverseDomain
}

func (ds *DialSettings) IsUniverseDomainGDU() bool {
return ds.GetUniverseDomain() == universeDomainDefault
return ds.GetUniverseDomain() == ds.GetDefaultUniverseDomain()
}

0 comments on commit b21a1fa

Please sign in to comment.