diff --git a/github/github.go b/github/github.go index 10bc7007a30..a2bd2a76459 100644 --- a/github/github.go +++ b/github/github.go @@ -346,7 +346,6 @@ func NewClient(httpClient *http.Client) *Client { } // WithAuthToken returns a copy of the client configured to use the provided token for the Authorization header. -// When token is empty, the returned client will not set the Authorization header. func (c *Client) WithAuthToken(token string) *Client { return c.WithTokenSource(staticTokenSource(token)) } @@ -463,9 +462,18 @@ func (c *Client) initialize() { // copy returns a copy of the current client. It must be initialized before use. func (c *Client) copy() *Client { c.clientMu.Lock() + httpClient := c.client + if httpClient == nil { + httpClient = &http.Client{} + } // can't use *c here because that would copy mutexes by value. clone := Client{ - client: c.client, + client: &http.Client{ + Transport: httpClient.Transport, + Jar: httpClient.Jar, + CheckRedirect: httpClient.CheckRedirect, + Timeout: httpClient.Timeout, + }, UserAgent: c.UserAgent, BaseURL: c.BaseURL, UploadURL: c.UploadURL, diff --git a/github/github_test.go b/github/github_test.go index 3e4ea79d26f..83adf1a30cc 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -319,47 +319,90 @@ func TestClient(t *testing.T) { } } -func TestWithAuthToken(t *testing.T) { - token := "gh_test_token" - - validate := func(t *testing.T, c *http.Client, token string) { - t.Helper() - want := token - if want != "" { - want = "Bearer " + want - } - gotReq := false - headerVal := "" - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotReq = true - headerVal = r.Header.Get("Authorization") - })) - _, err := c.Get(srv.URL) - assertNilError(t, err) - if !gotReq { - t.Error("request not sent") +func validateAuthHeader(t *testing.T, httpClient *http.Client, want, wantErr *string) { + t.Helper() + var gotReq, hasHeader bool + headerVal := "" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotReq = true + _, hasHeader = r.Header[headerAuthorization] + headerVal = r.Header.Get(headerAuthorization) + })) + _, err := httpClient.Get(srv.URL) + if wantErr != nil { + if err == nil || !strings.Contains(err.Error(), *wantErr) { + t.Fatalf("error does not contain expected string %q: %v", *wantErr, err) } - if headerVal != want { - t.Errorf("Authorization header is %v, want %v", headerVal, want) + return + } + assertNilError(t, err) + if !gotReq { + t.Error("request not sent") + } + if want == nil { + if hasHeader { + t.Error("Authorization header is set, but should not be") } + return } + if !hasHeader { + t.Error("Authorization header is not set, but should be") + } + if headerVal != *want { + t.Errorf("Authorization header is %q, want %q", headerVal, *want) + } +} +func TestWithAuthToken(t *testing.T) { t.Run("zero-value Client", func(t *testing.T) { - c := new(Client).WithAuthToken(token) - validate(t, c.Client(), token) + client := new(Client).WithAuthToken("gh_test_token") + validateAuthHeader(t, client.Client(), String("Bearer gh_test_token"), nil) }) t.Run("NewClient", func(t *testing.T) { httpClient := &http.Client{} - client := NewClient(httpClient).WithAuthToken(token) - validate(t, client.Client(), token) + ogClient := NewClient(httpClient) + client := ogClient.WithAuthToken("gh_test_token") + validateAuthHeader(t, client.Client(), String("Bearer gh_test_token"), nil) // make sure the original client isn't setting auth headers now - validate(t, httpClient, "") + validateAuthHeader(t, httpClient, nil, nil) + validateAuthHeader(t, ogClient.Client(), nil, nil) }) t.Run("NewTokenClient", func(t *testing.T) { - validate(t, NewTokenClient(context.Background(), token).Client(), token) + validateAuthHeader(t, NewTokenClient(context.Background(), "gh_test_token").Client(), String("Bearer gh_test_token"), nil) + }) + + t.Run("empty token", func(t *testing.T) { + httpClient := &http.Client{} + client := NewClient(httpClient).WithAuthToken("") + validateAuthHeader(t, client.Client(), String("Bearer"), nil) + }) +} + +type tokenSourceFunc func() (string, error) + +func (f tokenSourceFunc) Token(context.Context) (string, error) { + return f() +} + +func TestWithTokenSource(t *testing.T) { + tokens := []string{"a", "b"} + source := tokenSourceFunc(func() (string, error) { + if len(tokens) == 0 { + return "", errors.New("no more tokens") + } + token := tokens[0] + tokens = tokens[1:] + return token, nil }) + httpClient := &http.Client{} + client := NewClient(httpClient).WithTokenSource(source) + validateAuthHeader(t, client.Client(), String("Bearer a"), nil) + validateAuthHeader(t, httpClient, nil, nil) + validateAuthHeader(t, client.WithTokenSource(nil).Client(), nil, nil) + validateAuthHeader(t, client.Client(), String("Bearer b"), nil) + validateAuthHeader(t, client.Client(), nil, String("could not get token: no more tokens")) } func TestWithEnterpriseURLs(t *testing.T) {