From 6c859de188183bbd288a914abf1824e590bacafe Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 15:01:02 +0800 Subject: [PATCH 01/10] update --- api/chat_main_handler.go | 72 ++++------------------------------------ api/chat_main_service.go | 42 +++++++++++++++++++++++ 2 files changed, 48 insertions(+), 66 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index daa21e3a..b6d70e11 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -13,7 +13,6 @@ import ( "strings" uuid "github.com/iris-contrib/go.uuid" - "github.com/samber/lo" openai "github.com/sashabaranov/go-openai" "github.com/swuecho/chatgpt_backend/sqlc_queries" @@ -61,8 +60,8 @@ type ChatCompletionResponse struct { type Choice struct { Message openai.ChatCompletionMessage `json:"message"` - FinishReason interface{} `json:"finish_reason"` - Index int `json:"index"` + FinishReason interface{} `json:"finish_reason"` + Index int `json:"index"` } type Message struct { @@ -71,12 +70,11 @@ type Message struct { } type OpenaiChatRequest struct { - Model string `json:"model"` + Model string `json:"model"` Messages []openai.ChatCompletionMessage `json:"messages"` - } -func NewUserMessage(content string) openai.ChatCompletionMessage{ +func NewUserMessage(content string) openai.ChatCompletionMessage { return openai.ChatCompletionMessage{Role: "user", Content: content} } @@ -148,7 +146,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr } // Send the response as JSON - chatCompletionMessages, err := getAskMessages(h, ctx, w, chat_session, chatUuid, true) + chatCompletionMessages, err := h.chatService.getAskMessages(chat_session, chatUuid, true) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Get chat message error", err) return @@ -176,22 +174,6 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr return } - //// - - // no session exists - // - // if no session chat_created, create new chat_session with $uuid - // create a new prompt with topic = $uuid, role = "system", content= req.Prompt - - // if session avaiable, - // GetChatPromptBySessionID and create Message from Prompt - // GetLatestMessagesBySessionID and create Messsage(s) from messages - - // Check if the chat session exists - - // no session exists - // create session and prompt - chatSession, err := h.chatService.q.GetChatSessionByUUID(ctx, chatSessionUuid) if err != nil { @@ -245,7 +227,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr log.Printf("%+v\n", chatPrompt) } - msgs, err := getAskMessages(h, ctx, w, chatSession, chatUuid, false) + msgs, err := h.chatService.getAskMessages(chatSession, chatUuid, false) if err != nil { RespondWithError(w, http.StatusInternalServerError, fmt.Errorf("fail to collect messages: %w", err).Error(), err) return @@ -273,8 +255,6 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr UpdatedBy: int32(userIDInt), Raw: json.RawMessage([]byte("{}")), } - log.Println(chatMessageParams) - m, err := h.chatService.q.CreateChatMessage(ctx, chatMessageParams) log.Println(m) @@ -310,46 +290,6 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr } -func getAskMessages(h *ChatHandler, ctx context.Context, w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chatUuid string, regenerate bool) ([]openai.ChatCompletionMessage, error) { - chatSessionUuid := chatSession.Uuid - - lastN := chatSession.MaxLength - if chatSession.MaxLength == 0 { - lastN = 10 - } - - chat_prompts, err := h.chatService.q.GetChatPromptsBySessionUUID(ctx, chatSessionUuid) - - if err != nil { - return nil, fmt.Errorf("fail to get prompt: %w", err) - } - - var chat_massages []sqlc_queries.ChatMessage - if regenerate { - chat_massages, err = h.chatService.q.GetLastNChatMessages(ctx, - sqlc_queries.GetLastNChatMessagesParams{ - Uuid: chatUuid, - Limit: lastN, - }) - - } else { - chat_massages, err = h.chatService.q.GetLatestMessagesBySessionUUID(ctx, - sqlc_queries.GetLatestMessagesBySessionUUIDParams{ChatSessionUuid: chatSession.Uuid, Limit: lastN}) - } - - if err != nil { - return nil, fmt.Errorf("fail to get messages: %w", err) - } - chat_prompt_msgs := lo.Map(chat_prompts, func(m sqlc_queries.ChatPrompt, _ int) openai.ChatCompletionMessage { - return openai.ChatCompletionMessage{Role: m.Role, Content: m.Content} - }) - chat_message_msgs := lo.Map(chat_massages, func(m sqlc_queries.ChatMessage, _ int) openai.ChatCompletionMessage { - return openai.ChatCompletionMessage{Role: m.Role, Content: m.Content} - }) - msgs := append(chat_prompt_msgs, chat_message_msgs...) - return msgs, nil -} - func chat_stream(ctx context.Context, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []openai.ChatCompletionMessage, w http.ResponseWriter) (string, string, bool) { apiKey := appConfig.OPENAI.API_KEY diff --git a/api/chat_main_service.go b/api/chat_main_service.go index 60d888c4..e6364a0b 100644 --- a/api/chat_main_service.go +++ b/api/chat_main_service.go @@ -235,3 +235,45 @@ func GetAiAnswerOpenApi(msgs []openai.ChatCompletionMessage) (ChatCompletionResp } return aiAnswer, nil } + + +func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatUuid string, regenerate bool) ([]openai.ChatCompletionMessage, error) { + ctx := context.Background() + chatSessionUuid := chatSession.Uuid + + lastN := chatSession.MaxLength + if chatSession.MaxLength == 0 { + lastN = 10 + } + + chat_prompts, err := s.q.GetChatPromptsBySessionUUID(ctx, chatSessionUuid) + + if err != nil { + return nil, fmt.Errorf("fail to get prompt: %w", err) + } + + var chat_massages []sqlc_queries.ChatMessage + if regenerate { + chat_massages, err = s.q.GetLastNChatMessages(ctx, + sqlc_queries.GetLastNChatMessagesParams{ + Uuid: chatUuid, + Limit: lastN, + }) + + } else { + chat_massages, err = s.q.GetLatestMessagesBySessionUUID(ctx, + sqlc_queries.GetLatestMessagesBySessionUUIDParams{ChatSessionUuid: chatSession.Uuid, Limit: lastN}) + } + + if err != nil { + return nil, fmt.Errorf("fail to get messages: %w", err) + } + chat_prompt_msgs := lo.Map(chat_prompts, func(m sqlc_queries.ChatPrompt, _ int) openai.ChatCompletionMessage { + return openai.ChatCompletionMessage{Role: m.Role, Content: m.Content} + }) + chat_message_msgs := lo.Map(chat_massages, func(m sqlc_queries.ChatMessage, _ int) openai.ChatCompletionMessage { + return openai.ChatCompletionMessage{Role: m.Role, Content: m.Content} + }) + msgs := append(chat_prompt_msgs, chat_message_msgs...) + return msgs, nil +} \ No newline at end of file From 5f09266049d4e1eef180e8168f4be7d6b59978ea Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 15:13:57 +0800 Subject: [PATCH 02/10] better --- api/chat_main_handler.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index b6d70e11..ee728b7f 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -156,7 +156,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr if chat_session.Debug { log.Printf("%+v\n", chatCompletionMessages) } - answerText, _, shouldReturn := chat_stream(ctx, chat_session, chatCompletionMessages, w) + answerText, _, shouldReturn := chat_stream(w, chat_session, chatCompletionMessages) if shouldReturn { return } @@ -265,7 +265,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr } else { // Set up SSE headers - answerText, answerID, shouldReturn := chat_stream(ctx, chatSession, msgs, w) + answerText, answerID, shouldReturn := chat_stream(w, chatSession, msgs) if shouldReturn { return } @@ -290,7 +290,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr } -func chat_stream(ctx context.Context, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []openai.ChatCompletionMessage, w http.ResponseWriter) (string, string, bool) { +func chat_stream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []openai.ChatCompletionMessage) (string, string, bool) { apiKey := appConfig.OPENAI.API_KEY client := openai.NewClient(apiKey) @@ -306,6 +306,7 @@ func chat_stream(ctx context.Context, chatSession sqlc_queries.ChatSession, chat // N: n, Stream: true, } + ctx := context.Background() stream, err := client.CreateChatCompletionStream(ctx, openai_req) if err != nil { RespondWithError(w, http.StatusInternalServerError, fmt.Sprintf("CompletionStream error: %v", err), nil) From 52b1d111c71548059229bfaf87221a9fe105f7f7 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 15:28:24 +0800 Subject: [PATCH 03/10] update --- api/chat_main_handler.go | 28 ++++++++++++---------------- api/chat_main_service.go | 28 ++++++++++++++++++++++++++-- api/errors.go | 2 ++ 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index ee728b7f..90d74ef7 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -83,30 +83,26 @@ func (h *ChatHandler) chatHandler(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(&req) if err != nil { w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Invalid request body: %v", err) + json.NewEncoder(w).Encode(map[string]interface{}{"error": "Invalid request body"}) return } + defer r.Body.Close() - chatSessionUuid := req.SessionUuid - chatUuid := req.ChatUuid - newQuestion := req.Prompt - log.Printf("Received prompt: %s\n", newQuestion) - ctx := r.Context() - userIDStr := ctx.Value(userContextKey).(string) - userIDInt, err := strconv.Atoi(userIDStr) - if err != nil { - http.Error(w, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", http.StatusBadRequest) - return - } - answer_msg, err := h.chatService.Chat(chatSessionUuid, chatUuid, newQuestion, int32(userIDInt)) + answerMsg, err := h.chatService.chatServiceX(r.Context(), &req) if err != nil { - fmt.Fprintf(w, "problem in chat: %v", err) + statusCode := http.StatusInternalServerError + if errors.Is(err, ErrInvalidUserID) { + statusCode = http.StatusBadRequest + } else { + statusCode = http.StatusNotFound + } + w.WriteHeader(statusCode) + RespondWithError(w, statusCode, err.Error(), err) return } - // Send the response as JSON w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{"status": "Success", "text": answer_msg.Content, "chatUuid": answer_msg.Uuid}) + json.NewEncoder(w).Encode(map[string]interface{}{"status": "Success", "text": answerMsg.Content, "chatUuid": answerMsg.Uuid}) } // OpenAIChatCompletionAPIWithStreamHandler is an HTTP handler that sends the stream to the client as Server-Sent Events (SSE) diff --git a/api/chat_main_service.go b/api/chat_main_service.go index e6364a0b..ca6f5ff1 100644 --- a/api/chat_main_service.go +++ b/api/chat_main_service.go @@ -9,6 +9,7 @@ import ( "fmt" "log" "net/http" + "strconv" "strings" uuid "github.com/iris-contrib/go.uuid" @@ -236,7 +237,6 @@ func GetAiAnswerOpenApi(msgs []openai.ChatCompletionMessage) (ChatCompletionResp return aiAnswer, nil } - func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatUuid string, regenerate bool) ([]openai.ChatCompletionMessage, error) { ctx := context.Background() chatSessionUuid := chatSession.Uuid @@ -276,4 +276,28 @@ func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatU }) msgs := append(chat_prompt_msgs, chat_message_msgs...) return msgs, nil -} \ No newline at end of file +} + +func (s *ChatService) chatServiceX(ctx context.Context, req *ChatRequest) (*sqlc_queries.ChatMessage, error) { + chatSessionUuid := req.SessionUuid + chatUuid := req.ChatUuid + newQuestion := req.Prompt + log.Printf("Received prompt: %s\n", newQuestion) + + userIDStr, ok := ctx.Value(userContextKey).(string) + if !ok { + return nil, ErrInvalidUserID + } + + userIDInt, err := strconv.Atoi(userIDStr) + if err != nil { + return nil, ErrInvalidUserID + } + + answerMsg, err := s.Chat(chatSessionUuid, chatUuid, newQuestion, int32(userIDInt)) + if err != nil { + return nil, err + } + + return answerMsg, nil +} diff --git a/api/errors.go b/api/errors.go index 6f59363d..35da70e0 100644 --- a/api/errors.go +++ b/api/errors.go @@ -5,9 +5,11 @@ import "errors" // ErrUsageLimitExceeded is returned when the usage limit is exceeded. var ErrUsageLimitExceeded = errors.New("usage limit exceeded") + // auth related var ErrInvalidCredentials = errors.New("invalid credentials") +var ErrInvalidUserID = errors.New("invalid user id") /// token related From e20f47a3e032df17a39cb35edf49d873bd6511ca Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 15:35:58 +0800 Subject: [PATCH 04/10] update --- api/chat_main_handler.go | 22 +++++++++++++--------- api/chat_main_service.go | 24 ------------------------ 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 90d74ef7..3ee6433c 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -87,17 +87,21 @@ func (h *ChatHandler) chatHandler(w http.ResponseWriter, r *http.Request) { return } defer r.Body.Close() + ctx := r.Context() + userIDStr, ok := ctx.Value(userContextKey).(string) + if !ok { + RespondWithError(w, http.StatusInternalServerError, err.Error(), err) + } - answerMsg, err := h.chatService.chatServiceX(r.Context(), &req) + userIDInt, err := strconv.Atoi(userIDStr) if err != nil { - statusCode := http.StatusInternalServerError - if errors.Is(err, ErrInvalidUserID) { - statusCode = http.StatusBadRequest - } else { - statusCode = http.StatusNotFound - } - w.WriteHeader(statusCode) - RespondWithError(w, statusCode, err.Error(), err) + RespondWithError(w, http.StatusInternalServerError, err.Error(), err) + return + } + answerMsg, err := h.chatService.Chat(req.SessionUuid, req.ChatUuid, req.Prompt, int32(userIDInt)) + + if err != nil { + RespondWithError(w, http.StatusInternalServerError, err.Error(), err) return } diff --git a/api/chat_main_service.go b/api/chat_main_service.go index ca6f5ff1..ec70624d 100644 --- a/api/chat_main_service.go +++ b/api/chat_main_service.go @@ -9,7 +9,6 @@ import ( "fmt" "log" "net/http" - "strconv" "strings" uuid "github.com/iris-contrib/go.uuid" @@ -278,26 +277,3 @@ func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatU return msgs, nil } -func (s *ChatService) chatServiceX(ctx context.Context, req *ChatRequest) (*sqlc_queries.ChatMessage, error) { - chatSessionUuid := req.SessionUuid - chatUuid := req.ChatUuid - newQuestion := req.Prompt - log.Printf("Received prompt: %s\n", newQuestion) - - userIDStr, ok := ctx.Value(userContextKey).(string) - if !ok { - return nil, ErrInvalidUserID - } - - userIDInt, err := strconv.Atoi(userIDStr) - if err != nil { - return nil, ErrInvalidUserID - } - - answerMsg, err := s.Chat(chatSessionUuid, chatUuid, newQuestion, int32(userIDInt)) - if err != nil { - return nil, err - } - - return answerMsg, nil -} From f304b62e30739cb95ee2c0c4a5804e83375072be Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 15:39:45 +0800 Subject: [PATCH 05/10] update --- api/chat_main_handler.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 3ee6433c..8788764b 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -91,13 +91,10 @@ func (h *ChatHandler) chatHandler(w http.ResponseWriter, r *http.Request) { userIDStr, ok := ctx.Value(userContextKey).(string) if !ok { RespondWithError(w, http.StatusInternalServerError, err.Error(), err) - } - - userIDInt, err := strconv.Atoi(userIDStr) - if err != nil { - RespondWithError(w, http.StatusInternalServerError, err.Error(), err) return } + + userIDInt, _ := strconv.Atoi(userIDStr) answerMsg, err := h.chatService.Chat(req.SessionUuid, req.ChatUuid, req.Prompt, int32(userIDInt)) if err != nil { From c75b7872f470237517e5101683ceb9e6f4c11133 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 15:47:51 +0800 Subject: [PATCH 06/10] update --- api/chat_main_service.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/chat_main_service.go b/api/chat_main_service.go index ec70624d..bd097b87 100644 --- a/api/chat_main_service.go +++ b/api/chat_main_service.go @@ -10,6 +10,7 @@ import ( "log" "net/http" "strings" + "time" uuid "github.com/iris-contrib/go.uuid" "github.com/samber/lo" @@ -237,7 +238,9 @@ func GetAiAnswerOpenApi(msgs []openai.ChatCompletionMessage) (ChatCompletionResp } func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatUuid string, regenerate bool) ([]openai.ChatCompletionMessage, error) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + chatSessionUuid := chatSession.Uuid lastN := chatSession.MaxLength @@ -276,4 +279,3 @@ func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatU msgs := append(chat_prompt_msgs, chat_message_msgs...) return msgs, nil } - From 53663ee7bf13d1e97b3042ef7b6ffd77653f7e92 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 16:14:02 +0800 Subject: [PATCH 07/10] update --- api/chat_main_handler.go | 59 ++++++------------------------------- api/chat_main_service.go | 15 ++++++++++ api/chat_message_service.go | 20 +++++++++++++ api/main.go | 2 +- 4 files changed, 45 insertions(+), 51 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 8788764b..89dcf0af 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -20,12 +20,14 @@ import ( ) type ChatHandler struct { - chatService *ChatService + chatService *ChatService + chatMessageService *ChatMessageService } -func NewChatHandler(chatService *ChatService) *ChatHandler { +func NewChatHandler(chatService *ChatService, chat_msg *ChatMessageService) *ChatHandler { return &ChatHandler{ - chatService: chatService, + chatService: chatService, + chatMessageService: chat_msg, } } @@ -179,7 +181,6 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr existingPrompt := true - log.Println(chatSessionUuid) _, err = h.chatService.q.GetOneChatPromptBySessionUUID(ctx, chatSessionUuid) if err != nil { @@ -191,33 +192,13 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr } if existingPrompt { - _, err := h.chatService.q.CreateChatMessage(ctx, - sqlc_queries.CreateChatMessageParams{ - ChatSessionUuid: chatSession.Uuid, - Uuid: chatUuid, - Role: "user", - Content: newQuestion, - Raw: json.RawMessage([]byte("{}")), - UserID: userID, - CreatedBy: userID, - UpdatedBy: userID, - }) + _, err := h.chatMessageService.CreateChatMessageSimple(ctx, chatSession.Uuid, chatUuid, "user", newQuestion, int32(userIDInt)) if err != nil { http.Error(w, fmt.Errorf("fail to create message: %w", err).Error(), http.StatusInternalServerError) } } else { - uuidVar, _ := uuid.NewV4() - chatPrompt, err := h.chatService.q.CreateChatPrompt(ctx, - sqlc_queries.CreateChatPromptParams{ - Uuid: uuidVar.String(), - ChatSessionUuid: chatSessionUuid, - Role: "system", - Content: newQuestion, - UserID: userID, - CreatedBy: userID, - UpdatedBy: userID, - }) + chatPrompt, err := h.chatService.CreateChatPromptSimple(chatSessionUuid, newQuestion, userID) if err != nil { http.Error(w, fmt.Errorf("fail to create prompt: %w", err).Error(), http.StatusInternalServerError) } @@ -241,18 +222,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr if shouldReturn { return } - // insert ChatMessage into database - chatMessageParams := sqlc_queries.CreateChatMessageParams{ - ChatSessionUuid: chatSessionUuid, - Uuid: answerID, - Role: "assistant", - Content: answerText, - UserID: int32(userIDInt), - CreatedBy: int32(userIDInt), - UpdatedBy: int32(userIDInt), - Raw: json.RawMessage([]byte("{}")), - } - m, err := h.chatService.q.CreateChatMessage(ctx, chatMessageParams) + m, err := h.chatMessageService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, int32(userIDInt)) log.Println(m) if err != nil { @@ -267,18 +237,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr return } // insert ChatMessage into database - chatMessage := sqlc_queries.CreateChatMessageParams{ - Uuid: answerID, - ChatSessionUuid: chatSessionUuid, - Role: "assistant", - UserID: int32(userIDInt), - Content: answerText, - CreatedBy: int32(userIDInt), - UpdatedBy: int32(userIDInt), - Raw: json.RawMessage([]byte("{}")), - } - - _, err := h.chatService.q.CreateChatMessage(ctx, chatMessage) + _, err := h.chatMessageService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, int32(userIDInt)) if err != nil { RespondWithError(w, http.StatusInternalServerError, fmt.Errorf("fail to create message: %w", err).Error(), nil) diff --git a/api/chat_main_service.go b/api/chat_main_service.go index bd097b87..4920091e 100644 --- a/api/chat_main_service.go +++ b/api/chat_main_service.go @@ -279,3 +279,18 @@ func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatU msgs := append(chat_prompt_msgs, chat_message_msgs...) return msgs, nil } + +func (s *ChatService) CreateChatPromptSimple(chatSessionUuid string, newQuestion string, userID int32) (sqlc_queries.ChatPrompt, error) { + uuidVar, _ := uuid.NewV4() + chatPrompt, err := s.q.CreateChatPrompt(context.Background(), + sqlc_queries.CreateChatPromptParams{ + Uuid: uuidVar.String(), + ChatSessionUuid: chatSessionUuid, + Role: "system", + Content: newQuestion, + UserID: userID, + CreatedBy: userID, + UpdatedBy: userID, + }) + return chatPrompt, err +} diff --git a/api/chat_message_service.go b/api/chat_message_service.go index 31cd52bf..fabdd4ea 100644 --- a/api/chat_message_service.go +++ b/api/chat_message_service.go @@ -29,6 +29,26 @@ func (s *ChatMessageService) CreateChatMessage(ctx context.Context, message_para return message, nil } +// CreateChatMessage creates a new chat message. +func (s *ChatMessageService) CreateChatMessageSimple(ctx context.Context, sessionUuid, uuid, role, content string, userId int32) (sqlc_queries.ChatMessage, error) { + + chatMessage := sqlc_queries.CreateChatMessageParams{ + ChatSessionUuid: sessionUuid, + Uuid: uuid, + Role: role, + Content: content, + UserID: userId, + CreatedBy: userId, + UpdatedBy: userId, + Raw: json.RawMessage([]byte("{}")), + } + message, err := s.q.CreateChatMessage(ctx, chatMessage) + if err != nil { + return sqlc_queries.ChatMessage{}, fmt.Errorf("failed to create message %w", err) + } + return message, nil +} + // GetChatMessageByID returns a chat message by ID. func (s *ChatMessageService) GetChatMessageByID(ctx context.Context, id int32) (sqlc_queries.ChatMessage, error) { message, err := s.q.GetChatMessageByID(ctx, id) diff --git a/api/main.go b/api/main.go index 21ec7267..a72a20f9 100644 --- a/api/main.go +++ b/api/main.go @@ -181,7 +181,7 @@ func main() { chatService := NewChatService(sqlc_q) // create a new ChatHandler instance - chatHandler := NewChatHandler(chatService) + chatHandler := NewChatHandler(chatService, chatMessageService) // regiser the ChatHandler with the router chatHandler.Register(router) From a2923e444fe86eebf0473cbb033aa4591ec5e328 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 16:15:43 +0800 Subject: [PATCH 08/10] update --- api/chat_main_handler.go | 10 ++++------ api/chat_main_service.go | 21 +++++++++++++++++++++ api/chat_message_service.go | 19 ------------------- api/main.go | 2 +- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 89dcf0af..49c0194a 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -21,13 +21,11 @@ import ( type ChatHandler struct { chatService *ChatService - chatMessageService *ChatMessageService } -func NewChatHandler(chatService *ChatService, chat_msg *ChatMessageService) *ChatHandler { +func NewChatHandler(chatService *ChatService ) *ChatHandler { return &ChatHandler{ chatService: chatService, - chatMessageService: chat_msg, } } @@ -192,7 +190,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr } if existingPrompt { - _, err := h.chatMessageService.CreateChatMessageSimple(ctx, chatSession.Uuid, chatUuid, "user", newQuestion, int32(userIDInt)) + _, err := h.chatService.CreateChatMessageSimple(ctx, chatSession.Uuid, chatUuid, "user", newQuestion, int32(userIDInt)) if err != nil { http.Error(w, fmt.Errorf("fail to create message: %w", err).Error(), http.StatusInternalServerError) @@ -222,7 +220,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr if shouldReturn { return } - m, err := h.chatMessageService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, int32(userIDInt)) + m, err := h.chatService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, int32(userIDInt)) log.Println(m) if err != nil { @@ -237,7 +235,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr return } // insert ChatMessage into database - _, err := h.chatMessageService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, int32(userIDInt)) + _, err := h.chatService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, int32(userIDInt)) if err != nil { RespondWithError(w, http.StatusInternalServerError, fmt.Errorf("fail to create message: %w", err).Error(), nil) diff --git a/api/chat_main_service.go b/api/chat_main_service.go index 4920091e..87153909 100644 --- a/api/chat_main_service.go +++ b/api/chat_main_service.go @@ -294,3 +294,24 @@ func (s *ChatService) CreateChatPromptSimple(chatSessionUuid string, newQuestion }) return chatPrompt, err } + + +// CreateChatMessage creates a new chat message. +func (s *ChatService) CreateChatMessageSimple(ctx context.Context, sessionUuid, uuid, role, content string, userId int32) (sqlc_queries.ChatMessage, error) { + + chatMessage := sqlc_queries.CreateChatMessageParams{ + ChatSessionUuid: sessionUuid, + Uuid: uuid, + Role: role, + Content: content, + UserID: userId, + CreatedBy: userId, + UpdatedBy: userId, + Raw: json.RawMessage([]byte("{}")), + } + message, err := s.q.CreateChatMessage(ctx, chatMessage) + if err != nil { + return sqlc_queries.ChatMessage{}, fmt.Errorf("failed to create message %w", err) + } + return message, nil +} \ No newline at end of file diff --git a/api/chat_message_service.go b/api/chat_message_service.go index fabdd4ea..7c64cdc1 100644 --- a/api/chat_message_service.go +++ b/api/chat_message_service.go @@ -29,25 +29,6 @@ func (s *ChatMessageService) CreateChatMessage(ctx context.Context, message_para return message, nil } -// CreateChatMessage creates a new chat message. -func (s *ChatMessageService) CreateChatMessageSimple(ctx context.Context, sessionUuid, uuid, role, content string, userId int32) (sqlc_queries.ChatMessage, error) { - - chatMessage := sqlc_queries.CreateChatMessageParams{ - ChatSessionUuid: sessionUuid, - Uuid: uuid, - Role: role, - Content: content, - UserID: userId, - CreatedBy: userId, - UpdatedBy: userId, - Raw: json.RawMessage([]byte("{}")), - } - message, err := s.q.CreateChatMessage(ctx, chatMessage) - if err != nil { - return sqlc_queries.ChatMessage{}, fmt.Errorf("failed to create message %w", err) - } - return message, nil -} // GetChatMessageByID returns a chat message by ID. func (s *ChatMessageService) GetChatMessageByID(ctx context.Context, id int32) (sqlc_queries.ChatMessage, error) { diff --git a/api/main.go b/api/main.go index a72a20f9..21ec7267 100644 --- a/api/main.go +++ b/api/main.go @@ -181,7 +181,7 @@ func main() { chatService := NewChatService(sqlc_q) // create a new ChatHandler instance - chatHandler := NewChatHandler(chatService, chatMessageService) + chatHandler := NewChatHandler(chatService) // regiser the ChatHandler with the router chatHandler.Register(router) From 35cb422bca0250ea939a1eda6db7e83baeeacdb0 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 16:23:14 +0800 Subject: [PATCH 09/10] update --- api/chat_main_handler.go | 19 ++++++++----------- api/util.go | 10 ++++++++++ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index 49c0194a..fc963f3b 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -20,12 +20,12 @@ import ( ) type ChatHandler struct { - chatService *ChatService + chatService *ChatService } -func NewChatHandler(chatService *ChatService ) *ChatHandler { +func NewChatHandler(chatService *ChatService) *ChatHandler { return &ChatHandler{ - chatService: chatService, + chatService: chatService, } } @@ -120,13 +120,10 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr chatSessionUuid := req.SessionUuid chatUuid := req.ChatUuid newQuestion := req.Prompt - ctx := r.Context() - userIDStr := ctx.Value(userContextKey).(string) - userIDInt, err := strconv.Atoi(userIDStr) - userID := int32(userIDInt) + userID, err := getUserID(ctx) if err != nil { - RespondWithError(w, http.StatusBadRequest, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", nil) + RespondWithError(w, http.StatusBadRequest, err.Error(), err) return } @@ -190,7 +187,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr } if existingPrompt { - _, err := h.chatService.CreateChatMessageSimple(ctx, chatSession.Uuid, chatUuid, "user", newQuestion, int32(userIDInt)) + _, err := h.chatService.CreateChatMessageSimple(ctx, chatSession.Uuid, chatUuid, "user", newQuestion, userID) if err != nil { http.Error(w, fmt.Errorf("fail to create message: %w", err).Error(), http.StatusInternalServerError) @@ -220,7 +217,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr if shouldReturn { return } - m, err := h.chatService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, int32(userIDInt)) + m, err := h.chatService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, userID) log.Println(m) if err != nil { @@ -235,7 +232,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr return } // insert ChatMessage into database - _, err := h.chatService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, int32(userIDInt)) + _, err := h.chatService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, userID) if err != nil { RespondWithError(w, http.StatusInternalServerError, fmt.Errorf("fail to create message: %w", err).Error(), nil) diff --git a/api/util.go b/api/util.go index d9c6840a..1c27bb92 100644 --- a/api/util.go +++ b/api/util.go @@ -1,8 +1,11 @@ package main import ( + "context" "encoding/json" + "fmt" "net/http" + "strconv" ) // allocation free version @@ -17,6 +20,13 @@ func firstN(s string, n int) string { return s } +func getUserID(ctx context.Context) (int32, error) { + userIDStr := ctx.Value(userContextKey).(string) + userIDInt, err := strconv.Atoi(userIDStr) + userID := int32(userIDInt) + return userID, fmt.Errorf("Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID %w", err) +} + func RespondWithError(w http.ResponseWriter, code int, message string, details interface{}) { w.WriteHeader(code) json.NewEncoder(w).Encode(ErrorResponse{Code: code, Message: message, Details: details}) From 99b8f160cffc87796499215abf7c7b50a997e63c Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 23 Mar 2023 16:34:13 +0800 Subject: [PATCH 10/10] update --- api/chat_main_handler.go | 6 +----- api/chat_main_service.go | 12 ++++++++++-- api/util.go | 5 ++++- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/api/chat_main_handler.go b/api/chat_main_handler.go index fc963f3b..4e0504ec 100644 --- a/api/chat_main_handler.go +++ b/api/chat_main_handler.go @@ -155,11 +155,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr return } // Update the chatMessage content with chatUuid with new answer - err = h.chatService.q.UpdateChatMessageContent(ctx, - sqlc_queries.UpdateChatMessageContentParams{ - Uuid: chatUuid, - Content: answerText, - }) + err = h.chatService.UpdateChatMessageContent(ctx, chatUuid, answerText) if err != nil { RespondWithError(w, http.StatusInternalServerError, "Update chat message error", err) diff --git a/api/chat_main_service.go b/api/chat_main_service.go index 87153909..97f4ffa7 100644 --- a/api/chat_main_service.go +++ b/api/chat_main_service.go @@ -295,7 +295,6 @@ func (s *ChatService) CreateChatPromptSimple(chatSessionUuid string, newQuestion return chatPrompt, err } - // CreateChatMessage creates a new chat message. func (s *ChatService) CreateChatMessageSimple(ctx context.Context, sessionUuid, uuid, role, content string, userId int32) (sqlc_queries.ChatMessage, error) { @@ -314,4 +313,13 @@ func (s *ChatService) CreateChatMessageSimple(ctx context.Context, sessionUuid, return sqlc_queries.ChatMessage{}, fmt.Errorf("failed to create message %w", err) } return message, nil -} \ No newline at end of file +} + +// UpdateChatMessageContent +func (s *ChatService) UpdateChatMessageContent(ctx context.Context, uuid, content string) (error) { + err := s.q.UpdateChatMessageContent(ctx, sqlc_queries.UpdateChatMessageContentParams{ + Uuid: uuid, + Content: content, + }) + return err +} diff --git a/api/util.go b/api/util.go index 1c27bb92..39da8433 100644 --- a/api/util.go +++ b/api/util.go @@ -23,8 +23,11 @@ func firstN(s string, n int) string { func getUserID(ctx context.Context) (int32, error) { userIDStr := ctx.Value(userContextKey).(string) userIDInt, err := strconv.Atoi(userIDStr) + if err != nil { + return 0, fmt.Errorf("Error: '"+userIDStr+"' is not a valid user ID. should be a numeric value: %w", err) + } userID := int32(userIDInt) - return userID, fmt.Errorf("Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID %w", err) + return userID, nil } func RespondWithError(w http.ResponseWriter, code int, message string, details interface{}) {