Skip to content

Commit

Permalink
add rate limit and fix middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
denpeshkov committed Feb 1, 2024
1 parent 0c69bde commit 1f1d8da
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 9 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ module github.com/denpeshkov/greenlight
go 1.22

require github.com/lib/pq v1.10.9

require golang.org/x/time v0.5.0
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
14 changes: 14 additions & 0 deletions internal/greenlight/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,17 @@ func NewConflictError(format string, args ...any) error {
func (e ConflictError) Error() string {
return e.Msg
}

type RateLimitError struct {
Msg string
}

func NewRateLimitError(format string, args ...any) error {
return RateLimitError{
Msg: fmt.Sprintf(format, args...),
}
}

func (e RateLimitError) Error() string {
return e.Msg
}
5 changes: 5 additions & 0 deletions internal/http/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ func ErrorStatusCode(err error) int {
return http.StatusUnprocessableEntity
case errors.As(err, &greenlight.ConflictError{}):
return http.StatusConflict
case errors.As(err, &greenlight.RateLimitError{}):
return http.StatusTooManyRequests
case errors.As(err, &greenlight.InternalError{}):
fallthrough
default:
Expand All @@ -44,6 +46,7 @@ func ErrorBody(err error) any {
var invErr greenlight.InvalidError
var intErr greenlight.InternalError
var cftErr greenlight.ConflictError
var rateErr greenlight.RateLimitError

switch {
case errors.As(err, &nfErr):
Expand All @@ -57,6 +60,8 @@ func ErrorBody(err error) any {
return ValidationErrorResponse{Msg: invErr.Msg, Fields: m}
case errors.As(err, &cftErr):
return ErrorResponse{Msg: cftErr.Msg}
case errors.As(err, &rateErr):
return ErrorResponse{Msg: rateErr.Msg}
case errors.As(err, &intErr):
fallthrough
default:
Expand Down
15 changes: 15 additions & 0 deletions internal/http/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"encoding/json"
"fmt"
"net/http"

"github.com/denpeshkov/greenlight/internal/greenlight"
"golang.org/x/time/rate"
)

// hijackResponseWriter records status of the HTTP response.
Expand Down Expand Up @@ -75,3 +78,15 @@ func (s *Server) recoverPanic(h http.Handler) http.Handler {
h.ServeHTTP(w, r)
})
}

func (s *Server) rateLimit(h http.Handler) http.Handler {
lim := rate.NewLimiter(2, 4)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !lim.Allow() {
s.Error(w, r, greenlight.NewRateLimitError("Rate limit exceeded."))
return
}

h.ServeHTTP(w, r)
})
}
12 changes: 3 additions & 9 deletions internal/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"log/slog"
"net/http"
"os"
"time"

"github.com/denpeshkov/greenlight/internal/greenlight"
)
Expand Down Expand Up @@ -47,7 +46,9 @@ func NewServer(addr string, opts ...Option) *Server {
func (s *Server) Start() error {
op := "http.Server.Start"

s.server.Handler = s
handler := s.recoverPanic(s.rateLimit(s.notFound(s.methodNotAllowed(s.router))))

s.server.Handler = handler
s.server.ErrorLog = slog.NewLogLogger(s.logger.Handler(), slog.LevelError)

s.server.IdleTimeout = s.opts.idleTimeout
Expand Down Expand Up @@ -75,13 +76,6 @@ func (s *Server) Close() error {
return nil
}

// ServerHTTP handles an HTTP request.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h := http.TimeoutHandler(s.router, 2*time.Second, "TIMEOUT!!!")
h = s.recoverPanic(s.notFound(s.methodNotAllowed(h)))
h.ServeHTTP(w, r)
}

func newLogger() *slog.Logger {
opts := slog.HandlerOptions{Level: slog.LevelDebug}
handler := slog.NewJSONHandler(os.Stderr, &opts)
Expand Down

0 comments on commit 1f1d8da

Please sign in to comment.