diff --git a/internal/server/http.go b/internal/server/http.go index 876f959..3aa5def 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -1,8 +1,9 @@ +//modified http + package server import ( "context" - "encoding/json" "errors" "log/slog" "net/http" @@ -25,24 +26,6 @@ type HandlerMux struct { rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) } -type HTTPResponse struct { - Data interface{} `json:"data"` -} - -type HTTPErrorResponse struct { - Error interface{} `json:"error"` -} - -func errorResponse(response string) string { - errorMessage := map[string]string{"error": response} - jsonResponse, err := json.Marshal(errorMessage) - if err != nil { - slog.Error("Error marshaling response: %v", slog.Any("err", err)) - return `{"error": "internal server error"}` - } - - return string(jsonResponse) -} func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -95,43 +78,23 @@ func (s *HTTPServer) Shutdown() error { } func (s *HTTPServer) HealthCheck(w http.ResponseWriter, request *http.Request) { - util.JSONResponse(w, http.StatusOK, map[string]string{"message": "server is running"}) + util.HttpResponseJSON(w, http.StatusOK, map[string]string{"message": "Server is running"}) } func (s *HTTPServer) CliHandler(w http.ResponseWriter, r *http.Request) { diceCmd, err := util.ParseHTTPRequest(r) if err != nil { - http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) + util.HttpResponseException(w,http.StatusBadRequest,err) return } resp, err := s.DiceClient.ExecuteCommand(diceCmd) if err != nil { slog.Error("error: failure in executing command", "error", slog.Any("err", err)) - http.Error(w, errorResponse(err.Error()), http.StatusBadRequest) - return - } - - respStr, ok := resp.(string) - if !ok { - slog.Error("error: response is not a string", "error", slog.Any("err", err)) - http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) - return - } - - httpResponse := HTTPResponse{Data: respStr} - responseJSON, err := json.Marshal(httpResponse) - if err != nil { - slog.Error("error marshaling response to json", "error", slog.Any("err", err)) - http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) - return - } - - _, err = w.Write(responseJSON) - if err != nil { - http.Error(w, errorResponse("internal server error"), http.StatusInternalServerError) + util.HttpResponseException(w,http.StatusBadRequest,err) return } + util.HttpResponseJSON(w, http.StatusOK, resp) } func (s *HTTPServer) SearchHandler(w http.ResponseWriter, request *http.Request) { diff --git a/util/helpers.go b/util/helpers.go index 3ce7f74..25b4793 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -1,3 +1,5 @@ +// modifed helpers.go + package utils import ( @@ -8,6 +10,8 @@ import ( "log/slog" "net/http" "net/http/httptest" + "os" + "runtime/debug" "server/config" "server/internal/middleware" db "server/internal/tests/dbmocks" @@ -15,6 +19,19 @@ import ( "strings" ) +type HttpResponse struct { + Data interface{} `json:"data"` + Error *ErrorDetails `json:"error"` + HasError bool `json:"hasError"` + HasData bool `json:"hasData"` + StackTrace *string `json:"stackTrace,omitempty"` +} + +type ErrorDetails struct { + Message *string `json:"message"` + StackTrace *string `json:"stackTrace,omitempty"` +} + // Map of blocklisted commands var blocklistedCommands = map[string]bool{ "FLUSHALL": true, @@ -131,3 +148,47 @@ func SetupRateLimiter(limit int64, window float64) (*httptest.ResponseRecorder, return w, r, rateLimiter } + +func generateHttpResponse(w http.ResponseWriter, statusCode int, data interface{}, err *string) { + response := HttpResponse{ + HasData: data != nil, + HasError: err != nil, + Data: data, + } + + if err != nil { + errorDetails := &ErrorDetails{ + Message: err, + } + if os.Getenv("ENV") == "development" { + stackTrace := string(debug.Stack()) + errorDetails.StackTrace = &stackTrace + } + response.Error = errorDetails + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + if encodeErr := json.NewEncoder(w).Encode(response); encodeErr != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } +} + +func HttpResponseJSON(w http.ResponseWriter,statusCode int, data interface{}) { + generateHttpResponse(w, http.StatusOK, data, nil) +} + +func HttpResponseException(w http.ResponseWriter, statusCode int, err interface{}) { + var errorStr string + switch e := err.(type) { + case error: + errorStr = e.Error() + case string: + errorStr = e + default: + errorStr = "Unknown error type" + } + generateHttpResponse(w, statusCode, nil, &errorStr) +} +