Skip to content

Commit

Permalink
feat(transport): add support for setting quota project with envvar (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
codyoss authored Mar 10, 2023
1 parent 225fa6b commit 63c48a6
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 25 deletions.
21 changes: 16 additions & 5 deletions internal/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io/ioutil"
"net"
"net/http"
"os"
"time"

"golang.org/x/oauth2"
Expand All @@ -21,6 +22,8 @@ import (
"golang.org/x/oauth2/google"
)

const quotaProjectEnvVar = "GOOGLE_CLOUD_QUOTA_PROJECT"

// Creds returns credential information obtained from DialSettings, or if none, then
// it returns default credential information.
func Creds(ctx context.Context, ds *DialSettings) (*google.Credentials, error) {
Expand Down Expand Up @@ -152,14 +155,22 @@ func selfSignedJWTTokenSource(data []byte, ds *DialSettings) (oauth2.TokenSource
}
}

// QuotaProjectFromCreds returns the quota project from the JSON blob in the provided credentials.
//
// NOTE(cbro): consider promoting this to a field on google.Credentials.
func QuotaProjectFromCreds(cred *google.Credentials) string {
// GetQuotaProject retrieves quota project with precedence being: client option,
// environment variable, creds file.
func GetQuotaProject(creds *google.Credentials, clientOpt string) string {
if clientOpt != "" {
return clientOpt
}
if env := os.Getenv(quotaProjectEnvVar); env != "" {
return env
}
if creds == nil {
return ""
}
var v struct {
QuotaProject string `json:"quota_project_id"`
}
if err := json.Unmarshal(cred.JSON, &v); err != nil {
if err := json.Unmarshal(creds.JSON, &v); err != nil {
return ""
}
return v.QuotaProject
Expand Down
61 changes: 51 additions & 10 deletions internal/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package internal

import (
"context"
"os"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -199,10 +200,9 @@ const validServiceAccountJSON = `{
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/dumba-504%40appspot.gserviceaccount.com"
}`

func TestQuotaProjectFromCreds(t *testing.T) {
func TestGetQuotaProject(t *testing.T) {
ctx := context.Background()

cred, err := credentialsFromJSON(
emptyCred, err := credentialsFromJSON(
ctx,
[]byte(validServiceAccountJSON),
&DialSettings{
Expand All @@ -212,17 +212,13 @@ func TestQuotaProjectFromCreds(t *testing.T) {
if err != nil {
t.Fatalf("got %v, wanted no error", err)
}
if want, got := "", QuotaProjectFromCreds(cred); want != got {
t.Errorf("QuotaProjectFromCreds(validServiceAccountJSON): want %q, got %q", want, got)
}

quotaProjectJSON := []byte(`
{
"type": "authorized_user",
"quota_project_id": "foobar"
}`)

cred, err = credentialsFromJSON(
quotaCred, err := credentialsFromJSON(
ctx,
[]byte(quotaProjectJSON),
&DialSettings{
Expand All @@ -232,8 +228,53 @@ func TestQuotaProjectFromCreds(t *testing.T) {
if err != nil {
t.Fatalf("got %v, wanted no error", err)
}
if want, got := "foobar", QuotaProjectFromCreds(cred); want != got {
t.Errorf("QuotaProjectFromCreds(quotaProjectJSON): want %q, got %q", want, got)

tests := []struct {
name string
cred *google.Credentials
clientOpt string
env string
want string
}{
{
name: "empty all",
cred: nil,
want: "",
},
{
name: "empty cred",
cred: emptyCred,
want: "",
},
{
name: "from cred",
cred: quotaCred,
want: "foobar",
},
{
name: "from opt",
cred: quotaCred,
clientOpt: "clientopt",
want: "clientopt",
},
{
name: "from env",
cred: quotaCred,
env: "envProject",
want: "envProject",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldEnv := os.Getenv(quotaProjectEnvVar)
if tt.env != "" {
os.Setenv(quotaProjectEnvVar, tt.env)
}
if want, got := tt.want, GetQuotaProject(tt.cred, tt.clientOpt); want != got {
t.Errorf("GetQuotaProject(%v, %q): want %q, got %q", tt.cred, tt.clientOpt, want, got)
}
os.Setenv(quotaProjectEnvVar, oldEnv)
})
}
}

Expand Down
6 changes: 1 addition & 5 deletions transport/grpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,10 @@ func dial(ctx context.Context, insecure bool, o *internal.DialSettings) (*grpc.C
return nil, err
}

if o.QuotaProject == "" {
o.QuotaProject = internal.QuotaProjectFromCreds(creds)
}

grpcOpts = append(grpcOpts,
grpc.WithPerRPCCredentials(grpcTokenSource{
TokenSource: oauth.TokenSource{creds.TokenSource},
quotaProject: o.QuotaProject,
quotaProject: internal.GetQuotaProject(creds, o.QuotaProject),
requestReason: o.RequestReason,
}),
)
Expand Down
7 changes: 2 additions & 5 deletions transport/http/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna
paramTransport := &parameterTransport{
base: base,
userAgent: settings.UserAgent,
quotaProject: settings.QuotaProject,
requestReason: settings.RequestReason,
}
var trans http.RoundTripper = paramTransport
Expand All @@ -74,6 +73,7 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna
case settings.NoAuth:
// Do nothing.
case settings.APIKey != "":
paramTransport.quotaProject = internal.GetQuotaProject(nil, settings.QuotaProject)
trans = &transport.APIKey{
Transport: trans,
Key: settings.APIKey,
Expand All @@ -83,10 +83,7 @@ func newTransport(ctx context.Context, base http.RoundTripper, settings *interna
if err != nil {
return nil, err
}
if paramTransport.quotaProject == "" {
paramTransport.quotaProject = internal.QuotaProjectFromCreds(creds)
}

paramTransport.quotaProject = internal.GetQuotaProject(creds, settings.QuotaProject)
ts := creds.TokenSource
if settings.ImpersonationConfig == nil && settings.TokenSource != nil {
ts = settings.TokenSource
Expand Down

0 comments on commit 63c48a6

Please sign in to comment.