Skip to content

Commit

Permalink
Don't update httpClient passed to NewClient (google#3011)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAbides authored Dec 16, 2023
1 parent ee55955 commit 6d3dfc6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 17 deletions.
8 changes: 7 additions & 1 deletion github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ type service struct {
}

// Client returns the http.Client used by this GitHub client.
// This should only be used for requests to the GitHub API because
// request headers will contain an authorization token.
func (c *Client) Client() *http.Client {
c.clientMu.Lock()
defer c.clientMu.Unlock()
Expand Down Expand Up @@ -315,7 +317,11 @@ func addOptions(s string, opts interface{}) (string, error) {
// an http.Client that will perform the authentication for you (such as that
// provided by the golang.org/x/oauth2 library).
func NewClient(httpClient *http.Client) *Client {
c := &Client{client: httpClient}
if httpClient == nil {
httpClient = &http.Client{}
}
httpClient2 := *httpClient
c := &Client{client: &httpClient2}
c.initialize()
return c
}
Expand Down
51 changes: 35 additions & 16 deletions github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,26 +321,45 @@ func TestClient(t *testing.T) {

func TestWithAuthToken(t *testing.T) {
token := "gh_test_token"
var gotAuthHeaderVals []string
wantAuthHeaderVals := []string{"Bearer " + token}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAuthHeaderVals = r.Header["Authorization"]
}))
validate := func(c *Client) {

validate := func(t *testing.T, c *http.Client, token string) {
t.Helper()
gotAuthHeaderVals = nil
_, err := c.Client().Get(srv.URL)
if err != nil {
t.Fatalf("Get returned unexpected error: %v", err)
want := token
if want != "" {
want = "Bearer " + want
}
gotReq := false
headerVal := ""
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotReq = true
headerVal = r.Header.Get("Authorization")
}))
_, err := c.Get(srv.URL)
assertNilError(t, err)
if !gotReq {
t.Error("request not sent")
}
diff := cmp.Diff(wantAuthHeaderVals, gotAuthHeaderVals)
if diff != "" {
t.Errorf("Authorization header values mismatch (-want +got):\n%s", diff)
if headerVal != want {
t.Errorf("Authorization header is %v, want %v", headerVal, want)
}
}
validate(NewClient(nil).WithAuthToken(token))
validate(new(Client).WithAuthToken(token))
validate(NewTokenClient(context.Background(), token))

t.Run("zero-value Client", func(t *testing.T) {
c := new(Client).WithAuthToken(token)
validate(t, c.Client(), token)
})

t.Run("NewClient", func(t *testing.T) {
httpClient := &http.Client{}
client := NewClient(httpClient).WithAuthToken(token)
validate(t, client.Client(), token)
// make sure the original client isn't setting auth headers now
validate(t, httpClient, "")
})

t.Run("NewTokenClient", func(t *testing.T) {
validate(t, NewTokenClient(context.Background(), token).Client(), token)
})
}

func TestWithEnterpriseURLs(t *testing.T) {
Expand Down

0 comments on commit 6d3dfc6

Please sign in to comment.