Skip to content

Commit

Permalink
Add RateLimitError type, detect and return it when appropriate.
Browse files Browse the repository at this point in the history
Fixes #152.

Add tests for RateLimitError.

Refactor populateRate into a more general parseRate.
  • Loading branch information
dmitshur authored and willnorris committed Feb 11, 2016
1 parent 1219390 commit c3f5683
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 10 deletions.
7 changes: 7 additions & 0 deletions github/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ been some time since the last API call and other clients have made subsequent
requests since then. You can always call RateLimits() directly to get the most
up-to-date rate limit data for the client.
To detect an API rate limit error, you can check if its type is *github.RateLimitError:
repos, _, err := client.Repositories.List("", nil)
if _, ok := err.(*github.RateLimitError); ok {
log.Println("hit rate limit")
}
Learn more about GitHub rate limiting at
http://developer.github.com/v3/#rate-limiting.
Expand Down
40 changes: 32 additions & 8 deletions github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ type Response struct {
func newResponse(r *http.Response) *Response {
response := &Response{Response: r}
response.populatePageValues()
response.populateRate()
response.Rate = parseRate(r)
return response
}

Expand Down Expand Up @@ -284,19 +284,21 @@ func (r *Response) populatePageValues() {
}
}

// populateRate parses the rate related headers and populates the response Rate.
func (r *Response) populateRate() {
// parseRate parses the rate related headers.
func parseRate(r *http.Response) Rate {
var rate Rate
if limit := r.Header.Get(headerRateLimit); limit != "" {
r.Rate.Limit, _ = strconv.Atoi(limit)
rate.Limit, _ = strconv.Atoi(limit)
}
if remaining := r.Header.Get(headerRateRemaining); remaining != "" {
r.Rate.Remaining, _ = strconv.Atoi(remaining)
rate.Remaining, _ = strconv.Atoi(remaining)
}
if reset := r.Header.Get(headerRateReset); reset != "" {
if v, _ := strconv.ParseInt(reset, 10, 64); v != 0 {
r.Rate.Reset = Timestamp{time.Unix(v, 0)}
rate.Reset = Timestamp{time.Unix(v, 0)}
}
}
return rate
}

// Rate specifies the current rate limit for the client as determined by the
Expand Down Expand Up @@ -373,6 +375,20 @@ type TwoFactorAuthError ErrorResponse

func (r *TwoFactorAuthError) Error() string { return (*ErrorResponse)(r).Error() }

// RateLimitError occurs when GitHub returns 403 Forbidden response with a rate limit
// remaining value of 0, and error message starts with "API rate limit exceeded for ".
type RateLimitError struct {
Rate Rate // Rate specifies last known rate limit for the client
Response *http.Response // HTTP response that caused this error
Message string `json:"message"` // error message
}

func (r *RateLimitError) Error() string {
return fmt.Sprintf("%v %v: %d %v; rate reset in %v",
r.Response.Request.Method, sanitizeURL(r.Response.Request.URL),
r.Response.StatusCode, r.Message, r.Rate.Reset.Time.Sub(time.Now()))
}

// sanitizeURL redacts the client_secret parameter from the URL which may be
// exposed to the user, specifically in the ErrorResponse error message.
func sanitizeURL(uri *url.URL) *url.URL {
Expand Down Expand Up @@ -427,10 +443,18 @@ func CheckResponse(r *http.Response) error {
if err == nil && data != nil {
json.Unmarshal(data, errorResponse)
}
if r.StatusCode == http.StatusUnauthorized && strings.HasPrefix(r.Header.Get(headerOTP), "required") {
switch {
case r.StatusCode == http.StatusUnauthorized && strings.HasPrefix(r.Header.Get(headerOTP), "required"):
return (*TwoFactorAuthError)(errorResponse)
case r.StatusCode == http.StatusForbidden && r.Header.Get(headerRateRemaining) == "0" && strings.HasPrefix(errorResponse.Message, "API rate limit exceeded for "):
return &RateLimitError{
Rate: parseRate(r),
Response: errorResponse.Response,
Message: errorResponse.Message,
}
default:
return errorResponse
}
return errorResponse
}

// parseBoolResponse determines the boolean result from a GitHub API response.
Expand Down
52 changes: 50 additions & 2 deletions github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,11 @@ func TestDo_rateLimit(t *testing.T) {
}

req, _ := client.NewRequest("GET", "/", nil)
client.Do(req, nil)
_, err := client.Do(req, nil)

if err != nil {
t.Errorf("Do returned unexpected error: %v", err)
}
if got, want := client.Rate().Limit, 60; got != want {
t.Errorf("Client rate limit = %v, want %v", got, want)
}
Expand All @@ -416,8 +419,14 @@ func TestDo_rateLimit_errorResponse(t *testing.T) {
})

req, _ := client.NewRequest("GET", "/", nil)
client.Do(req, nil)
_, err := client.Do(req, nil)

if err == nil {
t.Error("Expected error to be returned.")
}
if _, ok := err.(*RateLimitError); ok {
t.Errorf("Did not expect a *RateLimitError error; got %#v.", err)
}
if got, want := client.Rate().Limit, 60; got != want {
t.Errorf("Client rate limit = %v, want %v", got, want)
}
Expand All @@ -430,6 +439,45 @@ func TestDo_rateLimit_errorResponse(t *testing.T) {
}
}

// Ensure *RateLimitError is returned when API rate limit is exceeded.
func TestDo_rateLimit_rateLimitError(t *testing.T) {
setup()
defer teardown()

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(headerRateLimit, "60")
w.Header().Add(headerRateRemaining, "0")
w.Header().Add(headerRateReset, "1372700873")
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(http.StatusForbidden)
fmt.Fprintln(w, `{
"message": "API rate limit exceeded for xxx.xxx.xxx.xxx. (But here's the good news: Authenticated requests get a higher rate limit. Check out the documentation for more details.)",
"documentation_url": "https://developer.github.com/v3/#rate-limiting"
}`)
})

req, _ := client.NewRequest("GET", "/", nil)
_, err := client.Do(req, nil)

if err == nil {
t.Error("Expected error to be returned.")
}
rateLimitErr, ok := err.(*RateLimitError)
if !ok {
t.Fatalf("Expected a *RateLimitError error; got %#v.", err)
}
if got, want := rateLimitErr.Rate.Limit, 60; got != want {
t.Errorf("rateLimitErr rate limit = %v, want %v", got, want)
}
if got, want := rateLimitErr.Rate.Remaining, 0; got != want {
t.Errorf("rateLimitErr rate remaining = %v, want %v", got, want)
}
reset := time.Date(2013, 7, 1, 17, 47, 53, 0, time.UTC)
if rateLimitErr.Rate.Reset.UTC() != reset {
t.Errorf("rateLimitErr rate reset = %v, want %v", client.Rate().Reset, reset)
}
}

func TestDo_noContent(t *testing.T) {
setup()
defer teardown()
Expand Down

0 comments on commit c3f5683

Please sign in to comment.