diff --git a/internal/server/http.go b/internal/server/http.go index 3a08f4a..191c83a 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "errors" - "fmt" + "fmt" "log/slog" "net/http" "strings" @@ -21,6 +21,8 @@ type HTTPServer struct { DiceClient *db.DiceDB } +// HandlerMux wraps ServeMux and forces REST paths to lowercase +// and attaches a rate limiter with the handler type HandlerMux struct { mux *http.ServeMux rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) @@ -46,8 +48,10 @@ func errorResponse(response string) string { } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Convert the path to lowercase before passing to the underlying mux. middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.ToLower(r.URL.Path) + // Apply rate limiter cim.rateLimiter(w, r, cim.mux) })).ServeHTTP(w, r) } @@ -106,10 +110,11 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { return } - if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { - http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) - return - } +// Check if the command is blocklisted +if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { + http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) + return +} resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil {