From 911438f3ee1fd2e4fe705e507f593fd9bd4216ed Mon Sep 17 00:00:00 2001 From: Tarun Kantiwal <48859385+tarun-29@users.noreply.github.com> Date: Fri, 4 Oct 2024 22:02:58 +0530 Subject: [PATCH] Add trailing slash middleware to prevent unexpected API crash (#16) --- internal/middleware/trailing_slash_test.go | 64 ++++++++++++++++++++++ internal/middleware/trailingslash.go | 23 ++++++++ internal/server/httpServer.go | 8 ++- 3 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 internal/middleware/trailing_slash_test.go create mode 100644 internal/middleware/trailingslash.go diff --git a/internal/middleware/trailing_slash_test.go b/internal/middleware/trailing_slash_test.go new file mode 100644 index 0000000..203fb1a --- /dev/null +++ b/internal/middleware/trailing_slash_test.go @@ -0,0 +1,64 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "server/internal/middleware" + "testing" +) + +func TestTrailingSlashMiddleware(t *testing.T) { + + handler := middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + requestURL string + expectedCode int + expectedUrl string + }{ + { + name: "url with trailing slash", + requestURL: "/example/", + expectedCode: http.StatusMovedPermanently, + expectedUrl: "/example", + }, + { + name: "url without trailing slash", + requestURL: "/example", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "root url with trailing slash", + requestURL: "/", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "URL with Query Parameters", + requestURL: "/example?query=1", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.requestURL, nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != tt.expectedCode { + t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + } + + if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { + t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + } + }) + } +} diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go new file mode 100644 index 0000000..09ac1b7 --- /dev/null +++ b/internal/middleware/trailingslash.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "net/http" + "strings" +) + +func TrailingSlashMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { + // remove slash + newPath := strings.TrimSuffix(r.URL.Path, "/") + // if query params exist append them + newURL := newPath + if r.URL.RawQuery != "" { + newURL += "?" + r.URL.RawQuery + } + http.Redirect(w, r, newURL, http.StatusMovedPermanently) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index 2e8e1a8..476f823 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -42,9 +42,11 @@ func errorResponse(response string) string { func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Convert the path to lowercase before passing to the underlying mux. - r.URL.Path = strings.ToLower(r.URL.Path) - // Apply rate limiter - cim.rateLimiter(w, r, cim.mux) + middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.ToLower(r.URL.Path) + // Apply rate limiter + cim.rateLimiter(w, r, cim.mux) + })).ServeHTTP(w, r) } func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit int64, window float64) *HTTPServer {