diff --git a/internal/server/http.go b/internal/server/http.go index c665277..876f959 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -6,12 +6,12 @@ import ( "errors" "log/slog" "net/http" - "server/internal/middleware" "strings" "sync" "time" "server/internal/db" + "server/internal/middleware" util "server/util" ) @@ -20,8 +20,6 @@ type HTTPServer struct { DiceClient *db.DiceDB } -// HandlerMux wraps ServeMux and forces REST paths to lowercase -// and attaches a rate limiter with the handler type HandlerMux struct { mux *http.ServeMux rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) @@ -47,10 +45,8 @@ func errorResponse(response string) string { } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Convert the path to lowercase before passing to the underlying mux. middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.ToLower(r.URL.Path) - // Apply rate limiter cim.rateLimiter(w, r, cim.mux) })).ServeHTTP(w, r) } @@ -105,7 +101,7 @@ func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest) + http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) return } @@ -119,7 +115,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { respStr, ok := resp.(string) if !ok { slog.Error("error: response is not a string", "error", slog.Any("err", err)) - http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } diff --git a/util/helpers.go b/util/helpers.go index 20a605a..3e3aea8 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -14,6 +14,37 @@ import ( "strings" ) +// Map of blocklisted commands +var blocklistedCommands = map[string]bool{ + "FLUSHALL": true, + "FLUSHDB": true, + "DUMP": true, + "ABORT": true, + "AUTH": true, + "CONFIG": true, + "SAVE": true, + "BGSAVE": true, + "BGREWRITEAOF": true, + "RESTORE": true, + "MULTI": true, + "EXEC": true, + "DISCARD": true, + "QWATCH": true, + "QUNWATCH": true, + "LATENCY": true, + "CLIENT": true, + "SLEEP": true, + "PERSIST": true, +} + +// BlockListedCommand checks if a command is blocklisted +func BlockListedCommand(cmd string) error { + if _, exists := blocklistedCommands[strings.ToUpper(cmd)]; exists { + return errors.New("ERR unknown command '" + cmd + "'") + } + return nil +} + // ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { command := extractCommand(r.URL.Path) @@ -21,6 +52,11 @@ func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { return nil, errors.New("invalid command") } + // Check if the command is blocklisted + if err := BlockListedCommand(command); err != nil { + return nil, err + } + args, err := newExtractor(r) if err != nil { return nil, err