From 6fd9753227f5e518f4254ce58b5a6621685b1d2e Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Fri, 4 Oct 2024 19:26:55 +0530 Subject: [PATCH 01/15] blacklist-added --- go.mod | 3 ++- go.sum | 6 ++++-- internal/server/httpServer.go | 9 ++++++++- pkg/util/blacklist.go | 22 ++++++++++++++++++++++ 4 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 pkg/util/blacklist.go diff --git a/go.mod b/go.mod index 26906be..dbddc32 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,9 @@ module server go 1.22.5 +require github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 + require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 // indirect ) diff --git a/go.sum b/go.sum index af5b6c5..eac4db4 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,9 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 h1:Cqyj9WCtoobN6++bFbDSe27q94SPwJD9Z0wmu+SDRuk= diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index 68e4c86..c899e3f 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -7,11 +7,11 @@ import ( "fmt" "log" "net/http" - "server/internal/middleware" "strings" "sync" "time" + "server/internal/middleware" "server/internal/db" util "server/pkg/util" ) @@ -101,6 +101,13 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { return } + // Check if the command is blacklisted + if err := util.IsBlacklistedCommand(diceCmd.Cmd); err != nil { + // Return the error message in the specified format + http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) + return + } + resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) diff --git a/pkg/util/blacklist.go b/pkg/util/blacklist.go new file mode 100644 index 0000000..0a84628 --- /dev/null +++ b/pkg/util/blacklist.go @@ -0,0 +1,22 @@ +package helpers + +import ( + "errors" + "strings" +) + +var blacklistedCommands = []string{ + "FLUSHALL", "FLUSHDB", "DUMP", "ABORT", "AUTH", "CONFIG", "SAVE", "BGSAVE", + "BGREWRITEAOF", "RESTORE", "MULTI", "EXEC", "DISCARD", "QWATCH", "QUNWATCH", + "LATENCY", "CLIENT", "SLEEP", "PERSIST", +} + +// IsBlacklistedCommand checks if a command is blacklisted +func IsBlacklistedCommand(cmd string) error { + for _, blacklistedCmd := range blacklistedCommands { + if strings.ToUpper(cmd) == blacklistedCmd { + return errors.New("command is blacklisted") + } + } + return nil +} From b4276c548192df0a030c3dd02c6733ef16d26fe6 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Fri, 4 Oct 2024 19:29:54 +0530 Subject: [PATCH 02/15] fixes --- go.mod | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/go.mod b/go.mod index dbddc32..26906be 100644 --- a/go.mod +++ b/go.mod @@ -2,9 +2,8 @@ module server go 1.22.5 -require github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 - require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 // indirect ) From d83c716c5bfa6e629ab06d40e25a08040c74eda4 Mon Sep 17 00:00:00 2001 From: Tarun Kantiwal <48859385+tarun-29@users.noreply.github.com> Date: Fri, 4 Oct 2024 22:02:58 +0530 Subject: [PATCH 03/15] Add trailing slash middleware to prevent unexpected API crash (#16) --- internal/middleware/trailing_slash_test.go | 64 ++++++++++++++++++++++ internal/middleware/trailingslash.go | 23 ++++++++ internal/server/httpServer.go | 8 ++- 3 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 internal/middleware/trailing_slash_test.go create mode 100644 internal/middleware/trailingslash.go diff --git a/internal/middleware/trailing_slash_test.go b/internal/middleware/trailing_slash_test.go new file mode 100644 index 0000000..203fb1a --- /dev/null +++ b/internal/middleware/trailing_slash_test.go @@ -0,0 +1,64 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "server/internal/middleware" + "testing" +) + +func TestTrailingSlashMiddleware(t *testing.T) { + + handler := middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + requestURL string + expectedCode int + expectedUrl string + }{ + { + name: "url with trailing slash", + requestURL: "/example/", + expectedCode: http.StatusMovedPermanently, + expectedUrl: "/example", + }, + { + name: "url without trailing slash", + requestURL: "/example", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "root url with trailing slash", + requestURL: "/", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "URL with Query Parameters", + requestURL: "/example?query=1", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.requestURL, nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != tt.expectedCode { + t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + } + + if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { + t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + } + }) + } +} diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go new file mode 100644 index 0000000..09ac1b7 --- /dev/null +++ b/internal/middleware/trailingslash.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "net/http" + "strings" +) + +func TrailingSlashMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { + // remove slash + newPath := strings.TrimSuffix(r.URL.Path, "/") + // if query params exist append them + newURL := newPath + if r.URL.RawQuery != "" { + newURL += "?" + r.URL.RawQuery + } + http.Redirect(w, r, newURL, http.StatusMovedPermanently) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index f0ebcf7..4bdddec 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -42,9 +42,11 @@ func errorResponse(response string) string { func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Convert the path to lowercase before passing to the underlying mux. - r.URL.Path = strings.ToLower(r.URL.Path) - // Apply rate limiter - cim.rateLimiter(w, r, cim.mux) + middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.ToLower(r.URL.Path) + // Apply rate limiter + cim.rateLimiter(w, r, cim.mux) + })).ServeHTTP(w, r) } func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit int64, window float64) *HTTPServer { From 11565acf350db9e1af085823b52ff08122a3dda1 Mon Sep 17 00:00:00 2001 From: rishav vajpayee <46602331+rishavvajpayee@users.noreply.github.com> Date: Sat, 5 Oct 2024 00:39:57 +0530 Subject: [PATCH 04/15] #21: Refactored repo for consistency (#24) --- config/config.go | 20 +- internal/db/dicedb.go | 14 +- internal/middleware/ratelimiter.go | 17 +- internal/middleware/trailingslash.go | 2 - internal/server/{httpServer.go => http.go} | 39 ++-- .../{mock_dicedb.go => dicedb_mock.go} | 0 .../ratelimiter_integration_test.go | 10 +- .../integration}/trailing_slash_test.go | 14 +- .../tests/stress/ratelimiter_stress_test.go | 10 +- main.go | 12 +- pkg/util/helpers.go | 185 ----------------- {internal => util}/cmds/cmds.go | 0 util/helpers.go | 191 ++++++++++++++++++ 13 files changed, 257 insertions(+), 257 deletions(-) rename internal/server/{httpServer.go => http.go} (74%) rename internal/tests/dbmocks/{mock_dicedb.go => dicedb_mock.go} (100%) rename internal/{middleware => tests/integration}/trailing_slash_test.go (72%) delete mode 100644 pkg/util/helpers.go rename {internal => util}/cmds/cmds.go (100%) create mode 100644 util/helpers.go diff --git a/config/config.go b/config/config.go index 73c1f01..b0f0b9e 100644 --- a/config/config.go +++ b/config/config.go @@ -11,11 +11,11 @@ import ( // Config holds the application configuration type Config struct { - DiceAddr string - ServerPort string - RequestLimit int64 // Field for the request limit - RequestWindow float64 // Field for the time window in float64 - AllowedOrigins []string // Field for the allowed origins + DiceDBAddr string + ServerPort string + RequestLimitPerMin int64 // Field for the request limit + RequestWindowSec float64 // Field for the time window in float64 + AllowedOrigins []string // Field for the allowed origins } // LoadConfig loads the application configuration from environment variables or defaults @@ -26,11 +26,11 @@ func LoadConfig() *Config { } return &Config{ - DiceAddr: getEnv("DICE_ADDR", "localhost:7379"), // Default Dice address - ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port - RequestLimit: getEnvInt("REQUEST_LIMIT", 1000), // Default request limit - RequestWindow: getEnvFloat64("REQUEST_WINDOW", 60), // Default request window in float64 - AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins + DiceDBAddr: getEnv("DICEDB_ADDR", "localhost:7379"), // Default DiceDB address + ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port + RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit + RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64 + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins } } diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index c811aac..1701cf6 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -11,10 +11,10 @@ import ( "log/slog" "os" "server/config" - "server/internal/cmds" + "server/util/cmds" "time" - dice "github.com/dicedb/go-dice" + dicedb "github.com/dicedb/go-dice" ) const ( @@ -22,7 +22,7 @@ const ( ) type DiceDB struct { - Client *dice.Client + Client *dicedb.Client Ctx context.Context } @@ -36,13 +36,13 @@ func (db *DiceDB) CloseDiceDB() { } func InitDiceClient(configValue *config.Config) (*DiceDB, error) { - diceClient := dice.NewClient(&dice.Options{ - Addr: configValue.DiceAddr, + diceClient := dicedb.NewClient(&dicedb.Options{ + Addr: configValue.DiceDBAddr, DialTimeout: 10 * time.Second, MaxRetries: 10, }) - // Ping the dice client to verify the connection + // Ping the dicedb client to verify the connection err := diceClient.Ping(context.Background()).Err() if err != nil { return nil, err @@ -64,7 +64,7 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err val, err := db.getKey(command.Args[0]) switch { - case errors.Is(err, dice.Nil): + case errors.Is(err, dicedb.Nil): return nil, errors.New("key does not exist") case err != nil: return nil, fmt.Errorf("get failed %v", err) diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c8da643..c852a26 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -12,19 +12,17 @@ import ( "strings" "time" - dice "github.com/dicedb/go-dice" + dicedb "github.com/dicedb/go-dice" ) // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Enable CORS for requests enableCors(w, r) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // Skip rate limiting for non-command endpoints - if !strings.Contains(r.URL.Path, "/cli/") { + if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } @@ -36,7 +34,7 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float // Fetch the current request count val, err := client.Client.Get(ctx, key).Result() - if err != nil && !errors.Is(err, dice.Nil) { + if err != nil && !errors.Is(err, dicedb.Nil) { slog.Error("Error fetching request count", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return @@ -74,25 +72,19 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float } } - // Log the successful request increment slog.Info("Request processed", "count", requestCount+1) - - // Call the next handler next.ServeHTTP(w, r) }) } func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Enable CORS for requests enableCors(w, r) - - // Set a request context with a timeout ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Only apply rate limiting for specific paths (e.g., "/cli/") - if !strings.Contains(r.URL.Path, "/cli/") { + if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } @@ -144,7 +136,6 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } } - // Log the successful request and pass control to the next handler slog.Info("Request processed", "count", requestCount) next.ServeHTTP(w, r) }) diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 09ac1b7..80871ce 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -8,9 +8,7 @@ import ( func TrailingSlashMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - // remove slash newPath := strings.TrimSuffix(r.URL.Path, "/") - // if query params exist append them newURL := newPath if r.URL.RawQuery != "" { newURL += "?" + r.URL.RawQuery diff --git a/internal/server/httpServer.go b/internal/server/http.go similarity index 74% rename from internal/server/httpServer.go rename to internal/server/http.go index 4bdddec..b95eb63 100644 --- a/internal/server/httpServer.go +++ b/internal/server/http.go @@ -4,8 +4,7 @@ import ( "context" "encoding/json" "errors" - "fmt" - "log" + "log/slog" "net/http" "strings" "sync" @@ -13,7 +12,7 @@ import ( "server/internal/middleware" "server/internal/db" - util "server/pkg/util" + util "server/util" ) type HTTPServer struct { @@ -37,7 +36,13 @@ type HTTPErrorResponse struct { } func errorResponse(response string) string { - return fmt.Sprintf("{\"error\": %q}", response) + errorMessage := map[string]string{"error": response} + jsonResponse, err := json.Marshal(errorMessage) + if err != nil { + slog.Error("Error marshaling response: %v", slog.Any("err", err)) + return `{"error": "internal server error"}` + } + return string(jsonResponse) } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -73,33 +78,33 @@ func (s *HTTPServer) Run(ctx context.Context) error { wg.Add(1) go func() { defer wg.Done() - log.Printf("Starting server at %s\n", s.httpServer.Addr) + slog.Info("starting server at", slog.String("addr", s.httpServer.Addr)) if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("HTTP server error: %v", err) + slog.Error("http server error: %v", slog.Any("err", err)) } }() <-ctx.Done() - log.Println("Shutting down server...") + slog.Info("shutting down server...") return s.Shutdown() } func (s *HTTPServer) Shutdown() error { if err := s.DiceClient.Client.Close(); err != nil { - log.Printf("Failed to close dice client: %v", err) + slog.Error("failed to close dicedb client: %v", slog.Any("err", err)) } return s.httpServer.Shutdown(context.Background()) } func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Server is running"}) + util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"}) } func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, errorResponse("Error parsing HTTP request"), http.StatusBadRequest) + http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest) return } @@ -112,32 +117,32 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse("Error executing command"), http.StatusBadRequest) + http.Error(w, errorResponse("error executing command"), http.StatusBadRequest) return } respStr, ok := resp.(string) if !ok { - log.Println("Error: response is not a string", "error", err) - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + slog.Error("error: response is not a string", "error", slog.Any("err", err)) + http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError) return } httpResponse := HTTPResponse{Data: respStr} responseJSON, err := json.Marshal(httpResponse) if err != nil { - log.Println("Error marshaling response to JSON", "error", err) - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + slog.Error("error marshaling response to json", "error", slog.Any("err", err)) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } _, err = w.Write(responseJSON) if err != nil { - http.Error(w, errorResponse("Internal Server Error"), http.StatusInternalServerError) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } } func (s *HTTPServer) SearchHandler(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "Search results"}) + util.JSONResponse(w, http.StatusOK, map[string]string{"message": "search results"}) } diff --git a/internal/tests/dbmocks/mock_dicedb.go b/internal/tests/dbmocks/dicedb_mock.go similarity index 100% rename from internal/tests/dbmocks/mock_dicedb.go rename to internal/tests/dbmocks/dicedb_mock.go diff --git a/internal/tests/integration/ratelimiter_integration_test.go b/internal/tests/integration/ratelimiter_integration_test.go index 421dafb..7faca79 100644 --- a/internal/tests/integration/ratelimiter_integration_test.go +++ b/internal/tests/integration/ratelimiter_integration_test.go @@ -4,7 +4,7 @@ import ( "net/http" "net/http/httptest" config "server/config" - util "server/pkg/util" + util "server/util" "testing" "github.com/stretchr/testify/require" @@ -12,8 +12,8 @@ import ( func TestRateLimiterWithinLimit(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec w, r, rateLimiter := util.SetupRateLimiter(limit, window) @@ -25,8 +25,8 @@ func TestRateLimiterWithinLimit(t *testing.T) { func TestRateLimiterExceedsLimit(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec w, r, rateLimiter := util.SetupRateLimiter(limit, window) diff --git a/internal/middleware/trailing_slash_test.go b/internal/tests/integration/trailing_slash_test.go similarity index 72% rename from internal/middleware/trailing_slash_test.go rename to internal/tests/integration/trailing_slash_test.go index 203fb1a..972158d 100644 --- a/internal/middleware/trailing_slash_test.go +++ b/internal/tests/integration/trailing_slash_test.go @@ -45,19 +45,19 @@ func TestTrailingSlashMiddleware(t *testing.T) { }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", tt.requestURL, nil) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req := httptest.NewRequest("GET", test.requestURL, nil) w := httptest.NewRecorder() handler.ServeHTTP(w, req) - if w.Code != tt.expectedCode { - t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + if w.Code != test.expectedCode { + t.Errorf("expected status %d, got %d", test.expectedCode, w.Code) } - if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { - t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + if test.expectedUrl != "" && w.Header().Get("Location") != test.expectedUrl { + t.Errorf("expected location %s, got %s", test.expectedUrl, w.Header().Get("Location")) } }) } diff --git a/internal/tests/stress/ratelimiter_stress_test.go b/internal/tests/stress/ratelimiter_stress_test.go index 7e37054..75c14f1 100644 --- a/internal/tests/stress/ratelimiter_stress_test.go +++ b/internal/tests/stress/ratelimiter_stress_test.go @@ -4,7 +4,7 @@ import ( "net/http" "net/http/httptest" "server/config" - util "server/pkg/util" + util "server/util" "sync" "testing" "time" @@ -14,13 +14,13 @@ import ( func TestRateLimiterUnderStress(t *testing.T) { configValue := config.LoadConfig() - limit := configValue.RequestLimit - window := configValue.RequestWindow + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec _, r, rateLimiter := util.SetupRateLimiter(limit, window) var wg sync.WaitGroup - var numRequests int64 = limit // add some extra requests to ensure we don't hit the limit + var numRequests int64 = limit successCount := int64(0) failCount := int64(0) var mu sync.Mutex @@ -43,5 +43,5 @@ func TestRateLimiterUnderStress(t *testing.T) { }() } wg.Wait() - require.Equal(t, limit, successCount, "Should succeed for exactly limit requests") + require.Equal(t, limit, successCount, "should succeed for exactly limit requests") } diff --git a/main.go b/main.go index 09a082f..b01de76 100644 --- a/main.go +++ b/main.go @@ -2,25 +2,25 @@ package main import ( "context" - "log" + "log/slog" "net/http" "server/config" "server/internal/db" - "server/internal/server" // Import the new package for HTTPServer + "server/internal/server" ) func main() { configValue := config.LoadConfig() diceClient, err := db.InitDiceClient(configValue) if err != nil { - log.Fatalf("Failed to initialize dice client: %v", err) + slog.Error("Failed to initialize DiceDB client: %v", slog.Any("err", err)) } // Create mux and register routes mux := http.NewServeMux() - httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimit, configValue.RequestWindow) + httpServer := server.NewHTTPServer(":8080", mux, diceClient, configValue.RequestLimitPerMin, configValue.RequestWindowSec) mux.HandleFunc("/health", httpServer.HealthCheck) - mux.HandleFunc("/cli/{cmd}", httpServer.CliHandler) + mux.HandleFunc("/shell/exec/{cmd}", httpServer.CliHandler) mux.HandleFunc("/search", httpServer.SearchHandler) // Graceful shutdown context @@ -29,6 +29,6 @@ func main() { // Run the HTTP Server if err := httpServer.Run(ctx); err != nil { - log.Printf("Server failed: %v\n", err) + slog.Error("server failed: %v\n", slog.Any("err", err)) } } diff --git a/pkg/util/helpers.go b/pkg/util/helpers.go deleted file mode 100644 index c86cef0..0000000 --- a/pkg/util/helpers.go +++ /dev/null @@ -1,185 +0,0 @@ -package helpers - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/http/httptest" - "server/internal/cmds" - "server/internal/middleware" - db "server/internal/tests/dbmocks" - "strings" -) - -const ( - Key = "key" - Keys = "keys" - KeyPrefix = "key_prefix" - Field = "field" - Path = "path" - Value = "value" - Values = "values" - User = "user" - Password = "password" - Seconds = "seconds" - KeyValues = "key_values" - True = "true" - QwatchQuery = "query" - Offset = "offset" - Member = "member" - Members = "members" - - JSONIngest string = "JSON.INGEST" -) - -func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { - command := strings.TrimPrefix(r.URL.Path, "/cli/") - if command == "" { - return nil, errors.New("invalid command") - } - - command = strings.ToUpper(command) - var args []string - - // Extract query parameters - queryParams := r.URL.Query() - keyPrefix := queryParams.Get(KeyPrefix) - - if keyPrefix != "" && command == JSONIngest { - args = append(args, keyPrefix) - } - // Step 1: Handle JSON body if present - if r.Body != nil { - body, err := io.ReadAll(r.Body) - if err != nil { - return nil, err - } - - if len(body) > 0 { - var jsonBody map[string]interface{} - if err := json.Unmarshal(body, &jsonBody); err != nil { - return nil, err - } - - if len(jsonBody) == 0 { - return nil, fmt.Errorf("empty JSON object") - } - - // Define keys to exclude and process their values first - // Update as we support more commands - var priorityKeys = []string{ - Key, - Keys, - Field, - Path, - Value, - Values, - Seconds, - User, - Password, - KeyValues, - QwatchQuery, - Offset, - Member, - Members, - } - for _, key := range priorityKeys { - if val, exists := jsonBody[key]; exists { - if key == Keys { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - if key == Values { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - // MultiKey operations - if key == KeyValues { - // Handle KeyValues separately - for k, v := range val.(map[string]interface{}) { - args = append(args, k, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - if key == Members { - for _, v := range val.([]interface{}) { - args = append(args, fmt.Sprintf("%v", v)) - } - delete(jsonBody, key) - continue - } - args = append(args, fmt.Sprintf("%v", val)) - delete(jsonBody, key) - } - } - - // Process remaining keys in the JSON body - for key, val := range jsonBody { - switch v := val.(type) { - case string: - // Handle unary operations like 'nx' where value is "true" - args = append(args, key) - if !strings.EqualFold(v, True) { - args = append(args, v) - } - case map[string]interface{}, []interface{}: - // Marshal nested JSON structures back into a string - jsonValue, err := json.Marshal(v) - if err != nil { - return nil, err - } - args = append(args, string(jsonValue)) - default: - args = append(args, key) - // Append other types as strings - value := fmt.Sprintf("%v", v) - if !strings.EqualFold(value, True) { - args = append(args, value) - } - } - } - } - } - - // Step 2: Return the constructed Redis command - return &cmds.CommandRequest{ - Cmd: command, - Args: args, - }, nil -} - -func JSONResponse(w http.ResponseWriter, status int, data interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - if err := json.NewEncoder(w).Encode(data); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } -} - -func MockHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("OK")); err != nil { - log.Fatalf("Failed to write response: %v", err) - } -} - -func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { - mockClient := db.NewDiceDBMock() - - r := httptest.NewRequest("GET", "/cli/somecommand", http.NoBody) - w := httptest.NewRecorder() - - rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window) - - return w, r, rateLimiter -} diff --git a/internal/cmds/cmds.go b/util/cmds/cmds.go similarity index 100% rename from internal/cmds/cmds.go rename to util/cmds/cmds.go diff --git a/util/helpers.go b/util/helpers.go new file mode 100644 index 0000000..20b3c2b --- /dev/null +++ b/util/helpers.go @@ -0,0 +1,191 @@ +package utils + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "server/internal/middleware" + db "server/internal/tests/dbmocks" + cmds "server/util/cmds" + "strings" +) + +const ( + Key = "key" + Keys = "keys" + KeyPrefix = "key_prefix" + Field = "field" + Path = "path" + Value = "value" + Values = "values" + User = "user" + Password = "password" + Seconds = "seconds" + KeyValues = "key_values" + True = "true" + QwatchQuery = "query" + Offset = "offset" + Member = "member" + Members = "members" + + JSONIngest string = "JSON.INGEST" +) + +var priorityKeys = []string{ + Key, Keys, Field, Path, Value, Values, Seconds, User, Password, KeyValues, QwatchQuery, Offset, Member, Members, +} + +// ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands +func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { + command := extractCommand(r.URL.Path) + if command == "" { + return nil, errors.New("invalid command") + } + + args, err := extractArgsFromRequest(r, command) + if err != nil { + return nil, err + } + + return &cmds.CommandRequest{ + Cmd: command, + Args: args, + }, nil +} + +func extractCommand(path string) string { + command := strings.TrimPrefix(path, "/shell/exec/") + return strings.ToUpper(command) +} + +func extractArgsFromRequest(r *http.Request, command string) ([]string, error) { + var args []string + queryParams := r.URL.Query() + keyPrefix := queryParams.Get(KeyPrefix) + + if keyPrefix != "" && command == JSONIngest { + args = append(args, keyPrefix) + } + + if r.Body != nil { + bodyArgs, err := parseRequestBody(r.Body) + if err != nil { + return nil, err + } + args = append(args, bodyArgs...) + } + + return args, nil +} + +func parseRequestBody(body io.ReadCloser) ([]string, error) { + var args []string + bodyContent, err := io.ReadAll(body) + if err != nil { + return nil, err + } + + if len(bodyContent) == 0 { + return args, nil + } + + var jsonBody map[string]interface{} + if err := json.Unmarshal(bodyContent, &jsonBody); err != nil { + return nil, err + } + + if len(jsonBody) == 0 { + return nil, fmt.Errorf("empty JSON object") + } + + args = append(args, extractPriorityArgs(jsonBody)...) + args = append(args, extractRemainingArgs(jsonBody)...) + + return args, nil +} + +func extractPriorityArgs(jsonBody map[string]interface{}) []string { + var args []string + for _, key := range priorityKeys { + if val, exists := jsonBody[key]; exists { + switch key { + case Keys, Values, Members: + args = append(args, convertListToStrings(val.([]interface{}))...) + case KeyValues: + args = append(args, convertMapToStrings(val.(map[string]interface{}))...) + default: + args = append(args, fmt.Sprintf("%v", val)) + } + delete(jsonBody, key) + } + } + return args +} + +func extractRemainingArgs(jsonBody map[string]interface{}) []string { + var args []string + for key, val := range jsonBody { + switch v := val.(type) { + case string: + args = append(args, key) + if !strings.EqualFold(v, True) { + args = append(args, v) + } + case map[string]interface{}, []interface{}: + jsonValue, _ := json.Marshal(v) + args = append(args, string(jsonValue)) + default: + args = append(args, key, fmt.Sprintf("%v", v)) + } + } + return args +} + +func convertListToStrings(list []interface{}) []string { + var result []string + for _, v := range list { + result = append(result, fmt.Sprintf("%v", v)) + } + return result +} + +func convertMapToStrings(m map[string]interface{}) []string { + var result []string + for k, v := range m { + result = append(result, k, fmt.Sprintf("%v", v)) + } + return result +} + +// JSONResponse sends a JSON response to the client +func JSONResponse(w http.ResponseWriter, status int, data interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// MockHandler is a basic mock handler for testing +func MockHandler(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("OK")); err != nil { + slog.Error("Failed to write response: %v", slog.Any("err", err)) + } +} + +// SetupRateLimiter sets up a rate limiter for testing purposes +func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, *http.Request, http.Handler) { + mockClient := db.NewDiceDBMock() + + r := httptest.NewRequest("GET", "/shell/exec/get", http.NoBody) + w := httptest.NewRecorder() + + rateLimiter := middleware.MockRateLimiter(mockClient, http.HandlerFunc(MockHandler), limit, window) + + return w, r, rateLimiter +} From 094a94d17e0b214b293ca0cad9214847ddbff0c7 Mon Sep 17 00:00:00 2001 From: Prashant Shubham Date: Sat, 5 Oct 2024 23:22:39 +0530 Subject: [PATCH 05/15] Adding support for generic command execution (#26) --- config/config.go | 2 +- internal/db/commands.go | 16 ---- internal/db/dicedb.go | 68 +++++++---------- internal/middleware/cors.go | 13 +++- internal/middleware/ratelimiter.go | 10 ++- internal/server/http.go | 3 +- main.go | 2 + util/cmds/cmds.go | 4 +- util/helpers.go | 114 +++-------------------------- 9 files changed, 62 insertions(+), 170 deletions(-) delete mode 100644 internal/db/commands.go diff --git a/config/config.go b/config/config.go index b0f0b9e..433c5f9 100644 --- a/config/config.go +++ b/config/config.go @@ -30,7 +30,7 @@ func LoadConfig() *Config { ServerPort: getEnv("SERVER_PORT", ":8080"), // Default server port RequestLimitPerMin: getEnvInt("REQUEST_LIMIT_PER_MIN", 1000), // Default request limit RequestWindowSec: getEnvFloat64("REQUEST_WINDOW_SEC", 60), // Default request window in float64 - AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:8080"}), // Default allowed origins + AllowedOrigins: getEnvArray("ALLOWED_ORIGINS", []string{"http://localhost:3000"}), // Default allowed origins } } diff --git a/internal/db/commands.go b/internal/db/commands.go deleted file mode 100644 index f78b616..0000000 --- a/internal/db/commands.go +++ /dev/null @@ -1,16 +0,0 @@ -package db - -func (db *DiceDB) getKey(key string) (string, error) { - val, err := db.Client.Get(db.Ctx, key).Result() - return val, err -} - -func (db *DiceDB) setKey(key, value string) error { - err := db.Client.Set(db.Ctx, key, value, 0).Err() - return err -} - -func (db *DiceDB) deleteKeys(keys []string) error { - err := db.Client.Del(db.Ctx, keys...).Err() - return err -} diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index 1701cf6..71d4901 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -17,10 +17,6 @@ import ( dicedb "github.com/dicedb/go-dice" ) -const ( - RespOK = "OK" -) - type DiceDB struct { Client *dicedb.Client Ctx context.Context @@ -56,47 +52,35 @@ func InitDiceClient(configValue *config.Config) (*DiceDB, error) { // ExecuteCommand executes a command based on the input func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, error) { - switch command.Cmd { - case "GET": - if len(command.Args) != 1 { - return nil, errors.New("invalid args") - } - - val, err := db.getKey(command.Args[0]) - switch { - case errors.Is(err, dicedb.Nil): - return nil, errors.New("key does not exist") - case err != nil: - return nil, fmt.Errorf("get failed %v", err) - } - - return val, nil - - case "SET": - if len(command.Args) < 2 { - return nil, errors.New("key is required") - } - - err := db.setKey(command.Args[0], command.Args[1]) - if err != nil { - return nil, errors.New("failed to set key") - } - - return RespOK, nil - - case "DEL": - if len(command.Args) == 0 { - return nil, errors.New("at least one key is required") - } + args := make([]interface{}, 0, len(command.Args)+1) + args = append(args, command.Cmd) + for _, arg := range command.Args { + args = append(args, arg) + } - err := db.deleteKeys(command.Args) - if err != nil { - return nil, errors.New("failed to delete keys") - } + res, err := db.Client.Do(db.Ctx, args...).Result() + if errors.Is(err, dicedb.Nil) { + return nil, errors.New("(nil)") + } - return RespOK, nil + if err != nil { + return nil, fmt.Errorf("(error) %v", err) + } + // Print the result based on its type + switch v := res.(type) { + case string: + return v, nil + case []byte: + return string(v), nil + case []interface{}: + case int64: + return fmt.Sprintf("%v", v), nil + case nil: + return "(nil)", nil default: - return nil, errors.New("unknown command") + return fmt.Sprintf("%v", v), nil } + + return nil, fmt.Errorf("(error) invalid result type") } diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 63bc5d9..3e75120 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -5,30 +5,37 @@ import ( "server/config" ) -func enableCors(w http.ResponseWriter, r *http.Request) { +// Updated enableCors function to return a boolean indicating if OPTIONS was handled +func handleCors(w http.ResponseWriter, r *http.Request) bool { configValue := config.LoadConfig() allAllowedOrigins := configValue.AllowedOrigins origin := r.Header.Get("Origin") allowed := false + for _, allowedOrigin := range allAllowedOrigins { if origin == allowedOrigin || allowedOrigin == "*" || origin == "" { allowed = true break } } + if !allowed { http.Error(w, "CORS: origin not allowed", http.StatusForbidden) - return + return true } w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE, PATCH") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Content-Length") + // If the request is an OPTIONS request, handle it and stop further processing if r.Method == http.MethodOptions { w.Header().Set("Access-Control-Max-Age", "86400") w.WriteHeader(http.StatusOK) - return + return true } + + // Continue processing other requests w.Header().Set("Content-Type", "application/json") + return false } diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c852a26..c2b43e3 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -18,7 +18,10 @@ import ( // RateLimiter middleware to limit requests based on a specified limit and duration func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - enableCors(w, r) + if handleCors(w, r) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -79,7 +82,10 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - enableCors(w, r) + if handleCors(w, r) { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/internal/server/http.go b/internal/server/http.go index b95eb63..8ebf97a 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -42,6 +42,7 @@ func errorResponse(response string) string { slog.Error("Error marshaling response: %v", slog.Any("err", err)) return `{"error": "internal server error"}` } + return string(jsonResponse) } @@ -117,7 +118,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { - http.Error(w, errorResponse("error executing command"), http.StatusBadRequest) + http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) return } diff --git a/main.go b/main.go index b01de76..99c98e3 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "net/http" + "os" "server/config" "server/internal/db" "server/internal/server" @@ -14,6 +15,7 @@ func main() { diceClient, err := db.InitDiceClient(configValue) if err != nil { slog.Error("Failed to initialize DiceDB client: %v", slog.Any("err", err)) + os.Exit(1) } // Create mux and register routes diff --git a/util/cmds/cmds.go b/util/cmds/cmds.go index bb7a275..1e68d08 100644 --- a/util/cmds/cmds.go +++ b/util/cmds/cmds.go @@ -1,6 +1,6 @@ package cmds type CommandRequest struct { - Cmd string `json:"cmd"` - Args []string + Cmd string `json:"cmd"` + Args []string `json:"args"` } diff --git a/util/helpers.go b/util/helpers.go index 20b3c2b..20a605a 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -10,35 +10,10 @@ import ( "net/http/httptest" "server/internal/middleware" db "server/internal/tests/dbmocks" - cmds "server/util/cmds" + "server/util/cmds" "strings" ) -const ( - Key = "key" - Keys = "keys" - KeyPrefix = "key_prefix" - Field = "field" - Path = "path" - Value = "value" - Values = "values" - User = "user" - Password = "password" - Seconds = "seconds" - KeyValues = "key_values" - True = "true" - QwatchQuery = "query" - Offset = "offset" - Member = "member" - Members = "members" - - JSONIngest string = "JSON.INGEST" -) - -var priorityKeys = []string{ - Key, Keys, Field, Path, Value, Values, Seconds, User, Password, KeyValues, QwatchQuery, Offset, Member, Members, -} - // ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { command := extractCommand(r.URL.Path) @@ -46,7 +21,7 @@ func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { return nil, errors.New("invalid command") } - args, err := extractArgsFromRequest(r, command) + args, err := newExtractor(r) if err != nil { return nil, err } @@ -62,29 +37,9 @@ func extractCommand(path string) string { return strings.ToUpper(command) } -func extractArgsFromRequest(r *http.Request, command string) ([]string, error) { +func newExtractor(r *http.Request) ([]string, error) { var args []string - queryParams := r.URL.Query() - keyPrefix := queryParams.Get(KeyPrefix) - - if keyPrefix != "" && command == JSONIngest { - args = append(args, keyPrefix) - } - - if r.Body != nil { - bodyArgs, err := parseRequestBody(r.Body) - if err != nil { - return nil, err - } - args = append(args, bodyArgs...) - } - - return args, nil -} - -func parseRequestBody(body io.ReadCloser) ([]string, error) { - var args []string - bodyContent, err := io.ReadAll(body) + bodyContent, err := io.ReadAll(r.Body) if err != nil { return nil, err } @@ -93,7 +48,7 @@ func parseRequestBody(body io.ReadCloser) ([]string, error) { return args, nil } - var jsonBody map[string]interface{} + var jsonBody []interface{} if err := json.Unmarshal(bodyContent, &jsonBody); err != nil { return nil, err } @@ -102,63 +57,16 @@ func parseRequestBody(body io.ReadCloser) ([]string, error) { return nil, fmt.Errorf("empty JSON object") } - args = append(args, extractPriorityArgs(jsonBody)...) - args = append(args, extractRemainingArgs(jsonBody)...) - - return args, nil -} - -func extractPriorityArgs(jsonBody map[string]interface{}) []string { - var args []string - for _, key := range priorityKeys { - if val, exists := jsonBody[key]; exists { - switch key { - case Keys, Values, Members: - args = append(args, convertListToStrings(val.([]interface{}))...) - case KeyValues: - args = append(args, convertMapToStrings(val.(map[string]interface{}))...) - default: - args = append(args, fmt.Sprintf("%v", val)) - } - delete(jsonBody, key) - } - } - return args -} - -func extractRemainingArgs(jsonBody map[string]interface{}) []string { - var args []string - for key, val := range jsonBody { - switch v := val.(type) { - case string: - args = append(args, key) - if !strings.EqualFold(v, True) { - args = append(args, v) - } - case map[string]interface{}, []interface{}: - jsonValue, _ := json.Marshal(v) - args = append(args, string(jsonValue)) - default: - args = append(args, key, fmt.Sprintf("%v", v)) + for _, val := range jsonBody { + s, ok := val.(string) + if !ok { + return nil, fmt.Errorf("invalid input") } - } - return args -} -func convertListToStrings(list []interface{}) []string { - var result []string - for _, v := range list { - result = append(result, fmt.Sprintf("%v", v)) + args = append(args, s) } - return result -} -func convertMapToStrings(m map[string]interface{}) []string { - var result []string - for k, v := range m { - result = append(result, k, fmt.Sprintf("%v", v)) - } - return result + return args, nil } // JSONResponse sends a JSON response to the client From 99ce7fe03eaf13bbff6dc67ccc872b7cb0f43713 Mon Sep 17 00:00:00 2001 From: Prashant Shubham Date: Sat, 5 Oct 2024 23:34:06 +0530 Subject: [PATCH 06/15] Adding support for generic command execution (#27) --- internal/db/dicedb.go | 33 ++++++++++++++++++++++++++++++--- internal/server/http.go | 1 + 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/internal/db/dicedb.go b/internal/db/dicedb.go index 71d4901..7da8f9f 100644 --- a/internal/db/dicedb.go +++ b/internal/db/dicedb.go @@ -12,11 +12,14 @@ import ( "os" "server/config" "server/util/cmds" + "strings" "time" dicedb "github.com/dicedb/go-dice" ) +const RespNil = "(nil)" + type DiceDB struct { Client *dicedb.Client Ctx context.Context @@ -60,7 +63,7 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err res, err := db.Client.Do(db.Ctx, args...).Result() if errors.Is(err, dicedb.Nil) { - return nil, errors.New("(nil)") + return RespNil, nil } if err != nil { @@ -74,13 +77,37 @@ func (db *DiceDB) ExecuteCommand(command *cmds.CommandRequest) (interface{}, err case []byte: return string(v), nil case []interface{}: + return renderListResponse(v) case int64: return fmt.Sprintf("%v", v), nil case nil: - return "(nil)", nil + return RespNil, nil default: return fmt.Sprintf("%v", v), nil } +} + +func renderListResponse(items []interface{}) (string, error) { + if len(items)%2 != 0 { + return "", fmt.Errorf("(error) invalid result format") + } + + var builder strings.Builder + for i := 0; i < len(items); i += 2 { + field, ok1 := items[i].(string) + value, ok2 := items[i+1].(string) + + // Check if both field and value are valid strings + if !ok1 || !ok2 { + return "", fmt.Errorf("(error) invalid result type") + } + + // Append the formatted field and value + _, err := fmt.Fprintf(&builder, "%d) \"%s\"\n%d) \"%s\"\n", i+1, field, i+2, value) + if err != nil { + return "", err + } + } - return nil, fmt.Errorf("(error) invalid result type") + return builder.String(), nil } diff --git a/internal/server/http.go b/internal/server/http.go index 8ebf97a..a9e8c2e 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -118,6 +118,7 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { + slog.Error("error: failure in executing command", "error", slog.Any("err", err)) http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) return } From 6c7383f8e6b86209045127a8ad5001b27a692742 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Sun, 6 Oct 2024 18:40:45 +0530 Subject: [PATCH 07/15] conflicts-resolved-2 --- go.sum | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.sum b/go.sum index 40633dc..36368eb 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= -github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgN9itIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= From ec40c243f1071b766311fa0f283e79c0bfe619e8 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Sun, 6 Oct 2024 18:50:30 +0530 Subject: [PATCH 08/15] name-changed --- internal/server/http.go | 12 ++++++------ pkg/util/blacklist.go | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/server/http.go b/internal/server/http.go index 77e3e4f..ebc0637 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -110,12 +110,12 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { return } - // Check if the command is blacklisted - if err := util.IsBlacklistedCommand(diceCmd.Cmd); err != nil { - // Return the error message in the specified format - http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) - return - } +// Check if the command is blocklisted +if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { + http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) + return +} + resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { diff --git a/pkg/util/blacklist.go b/pkg/util/blacklist.go index 0a84628..7c7b19d 100644 --- a/pkg/util/blacklist.go +++ b/pkg/util/blacklist.go @@ -11,11 +11,11 @@ var blacklistedCommands = []string{ "LATENCY", "CLIENT", "SLEEP", "PERSIST", } -// IsBlacklistedCommand checks if a command is blacklisted -func IsBlacklistedCommand(cmd string) error { +// BlockListedCommand checks if a command is blocklisted +func BlockListedCommand(cmd string) error { for _, blacklistedCmd := range blacklistedCommands { if strings.ToUpper(cmd) == blacklistedCmd { - return errors.New("command is blacklisted") + return errors.New("command is blocklisted") } } return nil From dffb2329e0aadc55bd7cfac62f3a402f5d56a140 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Sun, 6 Oct 2024 18:58:22 +0530 Subject: [PATCH 09/15] bool-applied --- pkg/util/blacklist.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pkg/util/blacklist.go b/pkg/util/blacklist.go index 7c7b19d..ad72a62 100644 --- a/pkg/util/blacklist.go +++ b/pkg/util/blacklist.go @@ -5,18 +5,17 @@ import ( "strings" ) -var blacklistedCommands = []string{ - "FLUSHALL", "FLUSHDB", "DUMP", "ABORT", "AUTH", "CONFIG", "SAVE", "BGSAVE", - "BGREWRITEAOF", "RESTORE", "MULTI", "EXEC", "DISCARD", "QWATCH", "QUNWATCH", - "LATENCY", "CLIENT", "SLEEP", "PERSIST", +var blocklistedCommands = map[string]bool{ + "FLUSHALL": true, "FLUSHDB": true, "DUMP": true, "ABORT": true, "AUTH": true, + "CONFIG": true, "SAVE": true, "BGSAVE": true, "BGREWRITEAOF": true, "RESTORE": true, + "MULTI": true, "EXEC": true, "DISCARD": true, "QWATCH": true, "QUNWATCH": true, + "LATENCY": true, "CLIENT": true, "SLEEP": true, "PERSIST": true, } // BlockListedCommand checks if a command is blocklisted func BlockListedCommand(cmd string) error { - for _, blacklistedCmd := range blacklistedCommands { - if strings.ToUpper(cmd) == blacklistedCmd { - return errors.New("command is blocklisted") - } + if blocklistedCommands[strings.ToUpper(cmd)] { + return errors.New("command is blocklisted") } return nil } From 6ed4d065e602ba1d8a4fa16ccb09b0fc143cb383 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Sun, 6 Oct 2024 19:13:32 +0530 Subject: [PATCH 10/15] checks-fixed --- internal/server/http.go | 1 - pkg/util/blacklist.go | 21 --------------------- util/helpers.go | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 22 deletions(-) delete mode 100644 pkg/util/blacklist.go diff --git a/internal/server/http.go b/internal/server/http.go index ebc0637..f4ece8a 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -116,7 +116,6 @@ if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { return } - resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { slog.Error("error: failure in executing command", "error", slog.Any("err", err)) diff --git a/pkg/util/blacklist.go b/pkg/util/blacklist.go deleted file mode 100644 index ad72a62..0000000 --- a/pkg/util/blacklist.go +++ /dev/null @@ -1,21 +0,0 @@ -package helpers - -import ( - "errors" - "strings" -) - -var blocklistedCommands = map[string]bool{ - "FLUSHALL": true, "FLUSHDB": true, "DUMP": true, "ABORT": true, "AUTH": true, - "CONFIG": true, "SAVE": true, "BGSAVE": true, "BGREWRITEAOF": true, "RESTORE": true, - "MULTI": true, "EXEC": true, "DISCARD": true, "QWATCH": true, "QUNWATCH": true, - "LATENCY": true, "CLIENT": true, "SLEEP": true, "PERSIST": true, -} - -// BlockListedCommand checks if a command is blocklisted -func BlockListedCommand(cmd string) error { - if blocklistedCommands[strings.ToUpper(cmd)] { - return errors.New("command is blocklisted") - } - return nil -} diff --git a/util/helpers.go b/util/helpers.go index 20a605a..141ebcf 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -14,6 +14,37 @@ import ( "strings" ) +// Map of blocklisted commands +var blocklistedCommands = map[string]bool{ + "FLUSHALL": true, + "FLUSHDB": true, + "DUMP": true, + "ABORT": true, + "AUTH": true, + "CONFIG": true, + "SAVE": true, + "BGSAVE": true, + "BGREWRITEAOF": true, + "RESTORE": true, + "MULTI": true, + "EXEC": true, + "DISCARD": true, + "QWATCH": true, + "QUNWATCH": true, + "LATENCY": true, + "CLIENT": true, + "SLEEP": true, + "PERSIST": true, +} + +// BlockListedCommand checks if a command is blocklisted +func BlockListedCommand(cmd string) error { + if _, exists := blocklistedCommands[strings.ToUpper(cmd)]; exists { + return errors.New("command is blocklisted") + } + return nil +} + // ParseHTTPRequest parses an incoming HTTP request and converts it into a CommandRequest for Redis commands func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { command := extractCommand(r.URL.Path) @@ -21,6 +52,11 @@ func ParseHTTPRequest(r *http.Request) (*cmds.CommandRequest, error) { return nil, errors.New("invalid command") } + // Check if the command is blocklisted + if err := BlockListedCommand(command); err != nil { + return nil, err + } + args, err := newExtractor(r) if err != nil { return nil, err From 9877cbd8f2f1ad1c645ff1fc44a7f1ec83868a42 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Sun, 6 Oct 2024 19:14:55 +0530 Subject: [PATCH 11/15] comments-removed --- internal/server/http.go | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/internal/server/http.go b/internal/server/http.go index f4ece8a..3a08f4a 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "errors" - "fmt" // Added to format error messages + "fmt" "log/slog" "net/http" "strings" @@ -21,8 +21,6 @@ type HTTPServer struct { DiceClient *db.DiceDB } -// HandlerMux wraps ServeMux and forces REST paths to lowercase -// and attaches a rate limiter with the handler type HandlerMux struct { mux *http.ServeMux rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) @@ -48,10 +46,8 @@ func errorResponse(response string) string { } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Convert the path to lowercase before passing to the underlying mux. middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.ToLower(r.URL.Path) - // Apply rate limiter cim.rateLimiter(w, r, cim.mux) })).ServeHTTP(w, r) } @@ -110,11 +106,10 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { return } -// Check if the command is blocklisted -if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { - http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) - return -} + if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { + http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) + return + } resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { From 66bea7ef261cb3df1e2305cdf43aac80f96b1dc2 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Sun, 6 Oct 2024 19:17:26 +0530 Subject: [PATCH 12/15] comments --- internal/server/http.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/internal/server/http.go b/internal/server/http.go index 3a08f4a..191c83a 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "errors" - "fmt" + "fmt" "log/slog" "net/http" "strings" @@ -21,6 +21,8 @@ type HTTPServer struct { DiceClient *db.DiceDB } +// HandlerMux wraps ServeMux and forces REST paths to lowercase +// and attaches a rate limiter with the handler type HandlerMux struct { mux *http.ServeMux rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) @@ -46,8 +48,10 @@ func errorResponse(response string) string { } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Convert the path to lowercase before passing to the underlying mux. middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.ToLower(r.URL.Path) + // Apply rate limiter cim.rateLimiter(w, r, cim.mux) })).ServeHTTP(w, r) } @@ -106,10 +110,11 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { return } - if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { - http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) - return - } +// Check if the command is blocklisted +if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { + http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) + return +} resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { From 97a84d9a66352a9972a446c89da315b801e94566 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Sun, 6 Oct 2024 19:27:47 +0530 Subject: [PATCH 13/15] error_updated --- util/helpers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/helpers.go b/util/helpers.go index 141ebcf..1ed8890 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -40,7 +40,7 @@ var blocklistedCommands = map[string]bool{ // BlockListedCommand checks if a command is blocklisted func BlockListedCommand(cmd string) error { if _, exists := blocklistedCommands[strings.ToUpper(cmd)]; exists { - return errors.New("command is blocklisted") + return errors.New("ERR unknown command '" + cmd + "'") } return nil } From 76464b3e7628f9408f6f2756b8271c58cc9fdb70 Mon Sep 17 00:00:00 2001 From: Yash Budhia Date: Sun, 6 Oct 2024 19:59:20 +0530 Subject: [PATCH 14/15] minors --- internal/server/http.go | 20 ++++++++------------ util/helpers.go | 30 +++++++++++++++--------------- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/internal/server/http.go b/internal/server/http.go index 191c83a..9dab2c5 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -4,15 +4,15 @@ import ( "context" "encoding/json" "errors" - "fmt" + "fmt" "log/slog" "net/http" "strings" "sync" "time" - "server/internal/middleware" "server/internal/db" + "server/internal/middleware" util "server/util" ) @@ -21,8 +21,6 @@ type HTTPServer struct { DiceClient *db.DiceDB } -// HandlerMux wraps ServeMux and forces REST paths to lowercase -// and attaches a rate limiter with the handler type HandlerMux struct { mux *http.ServeMux rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) @@ -48,10 +46,8 @@ func errorResponse(response string) string { } func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Convert the path to lowercase before passing to the underlying mux. middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.ToLower(r.URL.Path) - // Apply rate limiter cim.rateLimiter(w, r, cim.mux) })).ServeHTTP(w, r) } @@ -110,11 +106,11 @@ func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { return } -// Check if the command is blocklisted -if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { - http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) - return -} + // Check if the command is blocklisted + if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { + http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) + return + } resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { @@ -126,7 +122,7 @@ if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { respStr, ok := resp.(string) if !ok { slog.Error("error: response is not a string", "error", slog.Any("err", err)) - http.Error(w, errorResponse("internal Server Error"), http.StatusInternalServerError) + http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) return } diff --git a/util/helpers.go b/util/helpers.go index 1ed8890..3e3aea8 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -16,24 +16,24 @@ import ( // Map of blocklisted commands var blocklistedCommands = map[string]bool{ - "FLUSHALL": true, - "FLUSHDB": true, - "DUMP": true, - "ABORT": true, + "FLUSHALL": true, + "FLUSHDB": true, + "DUMP": true, + "ABORT": true, "AUTH": true, - "CONFIG": true, - "SAVE": true, - "BGSAVE": true, - "BGREWRITEAOF": true, + "CONFIG": true, + "SAVE": true, + "BGSAVE": true, + "BGREWRITEAOF": true, "RESTORE": true, - "MULTI": true, - "EXEC": true, - "DISCARD": true, - "QWATCH": true, + "MULTI": true, + "EXEC": true, + "DISCARD": true, + "QWATCH": true, "QUNWATCH": true, - "LATENCY": true, - "CLIENT": true, - "SLEEP": true, + "LATENCY": true, + "CLIENT": true, + "SLEEP": true, "PERSIST": true, } From 6d27e8f7a18c9b7969575fe82a1671c3b44af3be Mon Sep 17 00:00:00 2001 From: pshubham Date: Sun, 6 Oct 2024 20:13:08 +0530 Subject: [PATCH 15/15] Addressing review comments --- internal/server/http.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/internal/server/http.go b/internal/server/http.go index 9dab2c5..876f959 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "log/slog" "net/http" "strings" @@ -102,13 +101,7 @@ func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, errorResponse("error parsing http request"), http.StatusBadRequest) - return - } - - // Check if the command is blocklisted - if err := util.BlockListedCommand(diceCmd.Cmd); err != nil { - http.Error(w, errorResponse(fmt.Sprintf("ERR unknown command '%s'", diceCmd.Cmd)), http.StatusForbidden) + http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) return }