From 2586ae49a80184c69a7a4abcbce7bcd87682377d Mon Sep 17 00:00:00 2001 From: Tim Reddehase Date: Thu, 25 Nov 2021 15:12:06 +0100 Subject: [PATCH] trim trailing slashes from base url When using a custom Github URL (e.g. for a Github Enterprise instance) one might specify it with a trailing slash. This can lead to URLs containing two successive slashes. Github will not ignore this, but instead return a HTTP 401 Forbidden response. It might be nice to address the issue by removing trailing slashes when joining with a string that has a leading slash. --- transport.go | 4 +- transport_test.go | 95 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/transport.go b/transport.go index 5534d79..482a347 100644 --- a/transport.go +++ b/transport.go @@ -8,6 +8,7 @@ import ( "io" "io/ioutil" "net/http" + "strings" "sync" "time" @@ -159,7 +160,8 @@ func (t *Transport) refreshToken(ctx context.Context) error { return fmt.Errorf("could not convert installation token parameters into json: %s", err) } - req, err := http.NewRequest("POST", fmt.Sprintf("%s/app/installations/%v/access_tokens", t.BaseURL, t.installationID), body) + requestURL := fmt.Sprintf("%s/app/installations/%v/access_tokens", strings.TrimRight(t.BaseURL, "/"), t.installationID) + req, err := http.NewRequest("POST", requestURL, body) if err != nil { return fmt.Errorf("could not create request: %s", err) } diff --git a/transport_test.go b/transport_test.go index eae5e1a..8b1653b 100644 --- a/transport_test.go +++ b/transport_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "time" @@ -249,3 +250,97 @@ func TestRefreshTokenWithParameters(t *testing.T) { t.Fatalf("error calling RoundTrip: %v", err) } } + +func TestRefreshTokenWithTrailingSlashBaseURL(t *testing.T) { + installationTokenOptions := &github.InstallationTokenOptions{ + RepositoryIDs: []int64{1234}, + Permissions: &github.InstallationPermissions{ + Contents: github.String("write"), + Issues: github.String("read"), + }, + } + + tokenToBe := "token_string" + + // Convert io.ReadWriter to String without deleting body data. + wantBody, _ := GetReadWriter(installationTokenOptions) + wantBodyBytes := new(bytes.Buffer) + wantBodyBytes.ReadFrom(wantBody) + wantBodyString := wantBodyBytes.String() + + roundTripper := RoundTrip{ + rt: func(req *http.Request) (*http.Response, error) { + if strings.Contains(req.URL.Path, "//") { + return &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("Forbidden\n")), + StatusCode: 401, + }, fmt.Errorf("Got simulated 401 Github Forbidden response") + } + + if req.URL.Path == "test_endpoint/" && req.Header.Get("Authorization") == fmt.Sprintf("token %s", tokenToBe) { + return &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("Beautiful\n")), + StatusCode: 200, + }, nil + } + + // Convert io.ReadCloser to String without deleting body data. + var gotBodyBytes []byte + gotBodyBytes, _ = ioutil.ReadAll(req.Body) + req.Body = ioutil.NopCloser(bytes.NewBuffer(gotBodyBytes)) + gotBodyString := string(gotBodyBytes) + + // Compare request sent with request received. + if diff := cmp.Diff(wantBodyString, gotBodyString); diff != "" { + t.Errorf("HTTP body want->got: %s", diff) + } + + // Return acceptable access token. + accessToken := accessToken{ + Token: tokenToBe, + ExpiresAt: time.Now(), + Repositories: []github.Repository{{ + ID: github.Int64(1234), + }}, + Permissions: github.InstallationPermissions{ + Contents: github.String("write"), + Issues: github.String("read"), + }, + } + tokenReadWriter, err := GetReadWriter(accessToken) + if err != nil { + return nil, fmt.Errorf("error converting token into io.ReadWriter: %+v", err) + } + tokenBody := ioutil.NopCloser(tokenReadWriter) + return &http.Response{ + Body: tokenBody, + StatusCode: 200, + }, nil + }, + } + + tr, err := New(roundTripper, appID, installationID, key) + if err != nil { + t.Fatal("unexpected error:", err) + } + tr.InstallationTokenOptions = installationTokenOptions + tr.BaseURL = "http://localhost/github/api/v3/" + + // Convert InstallationTokenOptions into a ReadWriter to pass as an argument to http.NewRequest. + body, err := GetReadWriter(installationTokenOptions) + if err != nil { + t.Fatalf("error calling GetReadWriter: %v", err) + } + + req, err := http.NewRequest("POST", "http://localhost/test_endpoint/", body) + if err != nil { + t.Fatal("unexpected error:", err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatalf("error calling RoundTrip: %v", err) + } + if res.StatusCode != 200 { + t.Fatalf("Unexpected RoundTrip response code: %d", res.StatusCode) + } +}