Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rate limit 100 requests / 10mins #14

Merged
merged 2 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/chat_message_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,4 @@ func (h *ChatMessageHandler) DeleteChatMessagesBySesionUUID(w http.ResponseWrite
return
}
w.WriteHeader(http.StatusOK)
}
}
12 changes: 10 additions & 2 deletions api/chat_message_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,19 @@ func (s *ChatMessageService) GetLastNChatMessages(ctx context.Context, uuid stri
return message, nil
}

//DeleteChatMessagesBySesionUUID deletes chat messages by session uuid.
// DeleteChatMessagesBySesionUUID deletes chat messages by session uuid.
func (s *ChatMessageService) DeleteChatMessagesBySesionUUID(ctx context.Context, uuid string) error {
err := s.q.DeleteChatMessagesBySesionUUID(ctx, uuid)
if err != nil {
return errors.New("failed to delete message")
}
return nil
}
}

func (s *ChatMessageService) GetChatMessagesCount(ctx context.Context, userID int32) (int32, error) {
count, err := s.q.GetChatMessagesCount(ctx, userID)
if err != nil {
return 0, err
}
return int32(count), nil
}
6 changes: 4 additions & 2 deletions api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func main() {
}
OPENAI_API_KEY = os.Getenv("OPENAI_API_KEY")


if JWT_SECRET, exists = os.LookupEnv("JWT_SECRET"); !exists {
log.Fatal("JWT_SECRET not set")
}
Expand All @@ -36,7 +35,7 @@ func main() {
if JWT_AUD, exists = os.LookupEnv("JWT_AUD"); !exists {
log.Fatal("JWT_AUD not set")
}
JWT_AUD= os.Getenv("JWT_AUD")
JWT_AUD = os.Getenv("JWT_AUD")

// Create a new logger instance, configure it as desired
logger = log.New()
Expand Down Expand Up @@ -165,6 +164,9 @@ func main() {
router.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
router.Use(IsAuthorizedMiddleware)
// Wrap the router with the logging middleware
// 10 min < 100 requests
limitedRouter := RateLimitByUserID(sqlc_q)
router.Use(limitedRouter)
// loggedMux := loggingMiddleware(router, logger)
loggedRouter := handlers.LoggingHandler(logger.Out, router)
err = http.ListenAndServe(":8077", loggedRouter)
Expand Down
1 change: 1 addition & 0 deletions api/middleware_authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func IsAuthorizedMiddleware(handler http.Handler) http.Handler {
}
ctx := context.WithValue(r.Context(), userContextKey, userID)
ctx = context.WithValue(ctx, roleContextKey, role)

// TODO: get trace id and add it to context
//traceID := r.Header.Get("X-Request-Id")
//if len(traceID) > 0 {
Expand Down
56 changes: 0 additions & 56 deletions api/middleware_log.go

This file was deleted.

37 changes: 37 additions & 0 deletions api/middleware_rateLimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

import (
"net/http"
"strconv"

"github.com/swuecho/chatgpt_backend/sqlc_queries"
)

// This function returns a middleware that limits requests from each user by their ID.
func RateLimitByUserID(q *sqlc_queries.Queries) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the user ID from the request, e.g. from a JWT token.
ctx := r.Context()
userIDStr := ctx.Value(userContextKey).(string)
userIDInt, err := strconv.Atoi(userIDStr)
if err != nil {
http.Error(w, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", http.StatusBadRequest)
return
}
messageCount, err := q.GetChatMessagesCount(r.Context(), int32(userIDInt))
if err != nil {
http.Error(w, "Error: Could not get message count.", http.StatusInternalServerError)
return
}

// Check if the request exceeds the rate limit.
if messageCount > 100 {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
// Call the next handler.
next.ServeHTTP(w, r)
})
}
}
10 changes: 9 additions & 1 deletion api/sqlc/queries/chat_message.sql
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,12 @@ WHERE uuid = $1 ;

-- name: DeleteChatMessagesBySesionUUID :exec
DELETE FROM chat_message
WHERE chat_session_uuid = $1;
WHERE chat_session_uuid = $1;


-- name: GetChatMessagesCount :one
-- Get total chat message count for user in last 10 minutes
SELECT COUNT(*)
FROM chat_message
WHERE user_id = $1
AND created_at >= NOW() - INTERVAL '10 minutes';
15 changes: 15 additions & 0 deletions api/sqlc_queries/chat_message.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions chat.code-workspace
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"folders": [
{
"path": "web"
},
{
"path": "api"
},
{
"path": "e2e"
}
],
"settings": {}
}