Skip to content

Commit

Permalink
TestWithAuthToken
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAbides committed Dec 5, 2023
1 parent 10f4199 commit dd8b72d
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 29 deletions.
12 changes: 10 additions & 2 deletions github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ func NewClient(httpClient *http.Client) *Client {
}

// WithAuthToken returns a copy of the client configured to use the provided token for the Authorization header.
// When token is empty, the returned client will not set the Authorization header.
func (c *Client) WithAuthToken(token string) *Client {
return c.WithTokenSource(staticTokenSource(token))
}
Expand Down Expand Up @@ -463,9 +462,18 @@ func (c *Client) initialize() {
// copy returns a copy of the current client. It must be initialized before use.
func (c *Client) copy() *Client {
c.clientMu.Lock()
httpClient := c.client
if httpClient == nil {
httpClient = &http.Client{}
}
// can't use *c here because that would copy mutexes by value.
clone := Client{
client: c.client,
client: &http.Client{
Transport: httpClient.Transport,
Jar: httpClient.Jar,
CheckRedirect: httpClient.CheckRedirect,
Timeout: httpClient.Timeout,
},
UserAgent: c.UserAgent,
BaseURL: c.BaseURL,
UploadURL: c.UploadURL,
Expand Down
97 changes: 70 additions & 27 deletions github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,47 +319,90 @@ func TestClient(t *testing.T) {
}
}

func TestWithAuthToken(t *testing.T) {
token := "gh_test_token"

validate := func(t *testing.T, c *http.Client, token string) {
t.Helper()
want := token
if want != "" {
want = "Bearer " + want
}
gotReq := false
headerVal := ""
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotReq = true
headerVal = r.Header.Get("Authorization")
}))
_, err := c.Get(srv.URL)
assertNilError(t, err)
if !gotReq {
t.Error("request not sent")
func validateAuthHeader(t *testing.T, httpClient *http.Client, want, wantErr *string) {
t.Helper()
var gotReq, hasHeader bool
headerVal := ""
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotReq = true
_, hasHeader = r.Header[headerAuthorization]
headerVal = r.Header.Get(headerAuthorization)
}))
_, err := httpClient.Get(srv.URL)
if wantErr != nil {
if err == nil || !strings.Contains(err.Error(), *wantErr) {
t.Fatalf("error does not contain expected string %q: %v", *wantErr, err)
}
if headerVal != want {
t.Errorf("Authorization header is %v, want %v", headerVal, want)
return
}
assertNilError(t, err)
if !gotReq {
t.Error("request not sent")
}
if want == nil {
if hasHeader {
t.Error("Authorization header is set, but should not be")
}
return
}
if !hasHeader {
t.Error("Authorization header is not set, but should be")
}
if headerVal != *want {
t.Errorf("Authorization header is %q, want %q", headerVal, *want)
}
}

func TestWithAuthToken(t *testing.T) {
t.Run("zero-value Client", func(t *testing.T) {
c := new(Client).WithAuthToken(token)
validate(t, c.Client(), token)
client := new(Client).WithAuthToken("gh_test_token")
validateAuthHeader(t, client.Client(), String("Bearer gh_test_token"), nil)
})

t.Run("NewClient", func(t *testing.T) {
httpClient := &http.Client{}
client := NewClient(httpClient).WithAuthToken(token)
validate(t, client.Client(), token)
ogClient := NewClient(httpClient)
client := ogClient.WithAuthToken("gh_test_token")
validateAuthHeader(t, client.Client(), String("Bearer gh_test_token"), nil)
// make sure the original client isn't setting auth headers now
validate(t, httpClient, "")
validateAuthHeader(t, httpClient, nil, nil)
validateAuthHeader(t, ogClient.Client(), nil, nil)
})

t.Run("NewTokenClient", func(t *testing.T) {
validate(t, NewTokenClient(context.Background(), token).Client(), token)
validateAuthHeader(t, NewTokenClient(context.Background(), "gh_test_token").Client(), String("Bearer gh_test_token"), nil)
})

t.Run("empty token", func(t *testing.T) {
httpClient := &http.Client{}
client := NewClient(httpClient).WithAuthToken("")
validateAuthHeader(t, client.Client(), String("Bearer"), nil)
})
}

type tokenSourceFunc func() (string, error)

func (f tokenSourceFunc) Token(context.Context) (string, error) {
return f()
}

func TestWithTokenSource(t *testing.T) {
tokens := []string{"a", "b"}
source := tokenSourceFunc(func() (string, error) {
if len(tokens) == 0 {
return "", errors.New("no more tokens")
}
token := tokens[0]
tokens = tokens[1:]
return token, nil
})
httpClient := &http.Client{}
client := NewClient(httpClient).WithTokenSource(source)
validateAuthHeader(t, client.Client(), String("Bearer a"), nil)
validateAuthHeader(t, httpClient, nil, nil)
validateAuthHeader(t, client.WithTokenSource(nil).Client(), nil, nil)
validateAuthHeader(t, client.Client(), String("Bearer b"), nil)
validateAuthHeader(t, client.Client(), nil, String("could not get token: no more tokens"))
}

func TestWithEnterpriseURLs(t *testing.T) {
Expand Down

0 comments on commit dd8b72d

Please sign in to comment.