Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow blocking until primary rate limit is reset #3117

Merged
merged 11 commits into from
May 1, 2024
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ if _, ok := err.(*github.AbuseRateLimitError); ok {
}
```

Alternatively, you can block until the rate limit is reset by using the `context.WithValue` method:

````go
repos, _, err := client.Repositories.List(context.WithValue(ctx, github.SleepUntilPrimaryRateLimitResetWhenRateLimited, true), "", nil)
```

You can use [go-github-ratelimit](https://github.com/gofri/go-github-ratelimit) to handle
secondary rate limit sleep-and-retry for you.

Expand Down
36 changes: 36 additions & 0 deletions github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@ type requestContext uint8

const (
bypassRateLimitCheck requestContext = iota
SleepUntilPrimaryRateLimitResetWhenRateLimited
)

// BareDo sends an API request and lets you handle the api response. If an error
Expand Down Expand Up @@ -889,6 +890,15 @@ func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, erro
err = aerr
}

rateLimitError, ok := err.(*RateLimitError)
if ok && req.Context().Value(SleepUntilPrimaryRateLimitResetWhenRateLimited) != nil {
if err := sleepUntilResetWithBuffer(req.Context(), rateLimitError.Rate.Reset.Time); err != nil {
return response, err
}
// retry the request once when the rate limit has reset
return c.BareDo(context.WithValue(req.Context(), SleepUntilPrimaryRateLimitResetWhenRateLimited, nil), req)
}

// Update the secondary rate limit if we hit it.
rerr, ok := err.(*AbuseRateLimitError)
if ok && rerr.RetryAfter != nil {
Expand Down Expand Up @@ -950,6 +960,18 @@ func (c *Client) checkRateLimitBeforeDo(req *http.Request, rateLimitCategory Rat
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader("")),
}

if req.Context().Value(SleepUntilPrimaryRateLimitResetWhenRateLimited) != nil {
if err := sleepUntilResetWithBuffer(req.Context(), rate.Reset.Time); err == nil {
return nil
}
return &RateLimitError{
Rate: rate,
Response: resp,
Message: fmt.Sprintf("Context cancelled while waiting for rate limit to reset until %v, not making remote request.", rate.Reset.Time),
}
}

return &RateLimitError{
Rate: rate,
Response: resp,
Expand Down Expand Up @@ -1514,6 +1536,20 @@ func formatRateReset(d time.Duration) string {
return fmt.Sprintf("[rate reset in %v]", timeString)
}

func sleepUntilResetWithBuffer(ctx context.Context, reset time.Time) error {
buffer := time.Second
timer := time.NewTimer(time.Until(reset) + buffer)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ctx.Err()
case <-timer.C:
}
return nil
}

// When using roundTripWithOptionalFollowRedirect, note that it
// is the responsibility of the caller to close the response body.
func (c *Client) roundTripWithOptionalFollowRedirect(ctx context.Context, u string, maxRedirects int, opts ...RequestOption) (*http.Response, error) {
Expand Down
170 changes: 170 additions & 0 deletions github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,176 @@ func TestDo_rateLimit_ignoredFromCache(t *testing.T) {
}
}

// Ensure sleeps until the rate limit is reset when the client is rate limited.
func TestDo_rateLimit_sleepUntilResponseResetLimit(t *testing.T) {
client, mux, _, teardown := setup()
defer teardown()

reset := time.Now().UTC().Add(time.Second)

var firstRequest = true
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if firstRequest {
firstRequest = false
w.Header().Set(headerRateLimit, "60")
w.Header().Set(headerRateRemaining, "0")
w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix()))
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusForbidden)
fmt.Fprintln(w, `{
"message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)",
"documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"
}`)
return
}
w.Header().Set(headerRateLimit, "5000")
w.Header().Set(headerRateRemaining, "5000")
w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix()))
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, `{}`)
})

req, _ := client.NewRequest("GET", ".", nil)
ctx := context.Background()
resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
if err != nil {
t.Errorf("Do returned unexpected error: %v", err)
}
if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("Response status code = %v, want %v", got, want)
}
}

// Ensure tries to sleep until the rate limit is reset when the client is rate limited, but only once.
func TestDo_rateLimit_sleepUntilResponseResetLimitRetryOnce(t *testing.T) {
client, mux, _, teardown := setup()
defer teardown()

reset := time.Now().UTC().Add(time.Second)

requestCount := 0
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
requestCount++
w.Header().Set(headerRateLimit, "60")
w.Header().Set(headerRateRemaining, "0")
w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix()))
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusForbidden)
fmt.Fprintln(w, `{
"message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)",
"documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"
}`)
})

req, _ := client.NewRequest("GET", ".", nil)
ctx := context.Background()
_, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
if err == nil {
t.Error("Expected error to be returned.")
}
if got, want := requestCount, 2; got != want {
t.Errorf("Expected 2 requests, got %d", got)
}
}

// Ensure a network call is not made when it's known that API rate limit is still exceeded.
func TestDo_rateLimit_sleepUntilClientResetLimit(t *testing.T) {
client, mux, _, teardown := setup()
defer teardown()

reset := time.Now().UTC().Add(time.Second)
client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}}
requestCount := 0
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
requestCount++
w.Header().Set(headerRateLimit, "5000")
w.Header().Set(headerRateRemaining, "5000")
w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix()))
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, `{}`)
})
req, _ := client.NewRequest("GET", ".", nil)
ctx := context.Background()
resp, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
if err != nil {
t.Errorf("Do returned unexpected error: %v", err)
}
if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("Response status code = %v, want %v", got, want)
}
if got, want := requestCount, 1; got != want {
t.Errorf("Expected 1 request, got %d", got)
}
}

// Ensure sleep is aborted when the context is cancelled.
func TestDo_rateLimit_abortSleepContextCancelled(t *testing.T) {
client, mux, _, teardown := setup()
defer teardown()

// We use a 1 minute reset time to ensure the sleep is not completed.
reset := time.Now().UTC().Add(time.Minute)
requestCount := 0
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
requestCount++
w.Header().Set(headerRateLimit, "60")
w.Header().Set(headerRateRemaining, "0")
w.Header().Set(headerRateReset, fmt.Sprint(reset.Unix()))
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusForbidden)
fmt.Fprintln(w, `{
"message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)",
"documentation_url": "https://docs.github.com/en/rest/overview/resources-in-the-rest-api#abuse-rate-limits"
}`)
})

req, _ := client.NewRequest("GET", ".", nil)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
if !errors.Is(err, context.DeadlineExceeded) {
t.Error("Expected context deadline exceeded error.")
}
if got, want := requestCount, 1; got != want {
t.Errorf("Expected 1 requests, got %d", got)
}
}

// Ensure sleep is aborted when the context is cancelled on initial request.
func TestDo_rateLimit_abortSleepContextCancelledClientLimit(t *testing.T) {
client, mux, _, teardown := setup()
defer teardown()

reset := time.Now().UTC().Add(time.Minute)
client.rateLimits[CoreCategory] = Rate{Limit: 5000, Remaining: 0, Reset: Timestamp{reset}}
requestCount := 0
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
requestCount++
w.Header().Set(headerRateLimit, "5000")
w.Header().Set(headerRateRemaining, "5000")
w.Header().Set(headerRateReset, fmt.Sprint(reset.Add(time.Hour).Unix()))
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusOK)
fmt.Fprintln(w, `{}`)
})
req, _ := client.NewRequest("GET", ".", nil)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, err := client.Do(context.WithValue(ctx, SleepUntilPrimaryRateLimitResetWhenRateLimited, true), req, nil)
rateLimitError, ok := err.(*RateLimitError)
if !ok {
t.Fatalf("Expected a *rateLimitError error; got %#v.", err)
}
if got, wantSuffix := rateLimitError.Message, "Context cancelled while waiting for rate limit to reset until"; !strings.HasPrefix(got, wantSuffix) {
t.Errorf("Expected request to be prevented because context cancellation, got: %v.", got)
}
if got, want := requestCount, 0; got != want {
t.Errorf("Expected 1 requests, got %d", got)
}
}

// Ensure *AbuseRateLimitError is returned when the response indicates that
// the client has triggered an abuse detection mechanism.
func TestDo_rateLimit_abuseRateLimitError(t *testing.T) {
Expand Down
Loading