Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

code clean #27

Merged
merged 10 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 28 additions & 140 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"strings"

uuid "github.com/iris-contrib/go.uuid"
"github.com/samber/lo"
openai "github.com/sashabaranov/go-openai"
"github.com/swuecho/chatgpt_backend/sqlc_queries"

swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -61,8 +60,8 @@ type ChatCompletionResponse struct {

type Choice struct {
Message openai.ChatCompletionMessage `json:"message"`
FinishReason interface{} `json:"finish_reason"`
Index int `json:"index"`
FinishReason interface{} `json:"finish_reason"`
Index int `json:"index"`
}

type Message struct {
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -71,12 +70,11 @@ type Message struct {
}

type OpenaiChatRequest struct {
Model string `json:"model"`
Model string `json:"model"`
Messages []openai.ChatCompletionMessage `json:"messages"`

}

func NewUserMessage(content string) openai.ChatCompletionMessage{
func NewUserMessage(content string) openai.ChatCompletionMessage {
return openai.ChatCompletionMessage{Role: "user", Content: content}
}

swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -85,30 +83,27 @@ func (h *ChatHandler) chatHandler(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Invalid request body: %v", err)
json.NewEncoder(w).Encode(map[string]interface{}{"error": "Invalid request body"})
return
}

chatSessionUuid := req.SessionUuid
chatUuid := req.ChatUuid
newQuestion := req.Prompt
log.Printf("Received prompt: %s\n", newQuestion)
defer r.Body.Close()
ctx := r.Context()
userIDStr := ctx.Value(userContextKey).(string)
userIDInt, err := strconv.Atoi(userIDStr)
if err != nil {
http.Error(w, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", http.StatusBadRequest)
userIDStr, ok := ctx.Value(userContextKey).(string)
if !ok {
RespondWithError(w, http.StatusInternalServerError, err.Error(), err)
return
}
answer_msg, err := h.chatService.Chat(chatSessionUuid, chatUuid, newQuestion, int32(userIDInt))

userIDInt, _ := strconv.Atoi(userIDStr)
answerMsg, err := h.chatService.Chat(req.SessionUuid, req.ChatUuid, req.Prompt, int32(userIDInt))

if err != nil {
fmt.Fprintf(w, "problem in chat: %v", err)
RespondWithError(w, http.StatusInternalServerError, err.Error(), err)
return
}

// Send the response as JSON
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{"status": "Success", "text": answer_msg.Content, "chatUuid": answer_msg.Uuid})
json.NewEncoder(w).Encode(map[string]interface{}{"status": "Success", "text": answerMsg.Content, "chatUuid": answerMsg.Uuid})
}

// OpenAIChatCompletionAPIWithStreamHandler is an HTTP handler that sends the stream to the client as Server-Sent Events (SSE)
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -125,13 +120,10 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
chatSessionUuid := req.SessionUuid
chatUuid := req.ChatUuid
newQuestion := req.Prompt

ctx := r.Context()
userIDStr := ctx.Value(userContextKey).(string)
userIDInt, err := strconv.Atoi(userIDStr)
userID := int32(userIDInt)
userID, err := getUserID(ctx)
if err != nil {
RespondWithError(w, http.StatusBadRequest, "Error: '"+userIDStr+"' is not a valid user ID. Please enter a valid user ID.", nil)
RespondWithError(w, http.StatusBadRequest, err.Error(), err)
return
}

Expand All @@ -148,7 +140,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
}

// Send the response as JSON
chatCompletionMessages, err := getAskMessages(h, ctx, w, chat_session, chatUuid, true)
chatCompletionMessages, err := h.chatService.getAskMessages(chat_session, chatUuid, true)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, "Get chat message error", err)
return
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -158,16 +150,12 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
if chat_session.Debug {
log.Printf("%+v\n", chatCompletionMessages)
}
answerText, _, shouldReturn := chat_stream(ctx, chat_session, chatCompletionMessages, w)
answerText, _, shouldReturn := chat_stream(w, chat_session, chatCompletionMessages)
if shouldReturn {
return
}
swuecho marked this conversation as resolved.
Show resolved Hide resolved
// Update the chatMessage content with chatUuid with new answer
err = h.chatService.q.UpdateChatMessageContent(ctx,
sqlc_queries.UpdateChatMessageContentParams{
Uuid: chatUuid,
Content: answerText,
})
err = h.chatService.UpdateChatMessageContent(ctx, chatUuid, answerText)

if err != nil {
RespondWithError(w, http.StatusInternalServerError, "Update chat message error", err)
Expand All @@ -176,22 +164,6 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
return
}

////

// no session exists
//
// if no session chat_created, create new chat_session with $uuid
// create a new prompt with topic = $uuid, role = "system", content= req.Prompt

// if session avaiable,
// GetChatPromptBySessionID and create Message from Prompt
// GetLatestMessagesBySessionID and create Messsage(s) from messages

// Check if the chat session exists

// no session exists
// create session and prompt

chatSession, err := h.chatService.q.GetChatSessionByUUID(ctx, chatSessionUuid)

if err != nil {
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -200,7 +172,6 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr

existingPrompt := true

log.Println(chatSessionUuid)
_, err = h.chatService.q.GetOneChatPromptBySessionUUID(ctx, chatSessionUuid)

if err != nil {
Expand All @@ -212,40 +183,20 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
}

if existingPrompt {
_, err := h.chatService.q.CreateChatMessage(ctx,
sqlc_queries.CreateChatMessageParams{
ChatSessionUuid: chatSession.Uuid,
Uuid: chatUuid,
Role: "user",
Content: newQuestion,
Raw: json.RawMessage([]byte("{}")),
UserID: userID,
CreatedBy: userID,
UpdatedBy: userID,
})
_, err := h.chatService.CreateChatMessageSimple(ctx, chatSession.Uuid, chatUuid, "user", newQuestion, userID)

if err != nil {
http.Error(w, fmt.Errorf("fail to create message: %w", err).Error(), http.StatusInternalServerError)
}
} else {
uuidVar, _ := uuid.NewV4()
chatPrompt, err := h.chatService.q.CreateChatPrompt(ctx,
sqlc_queries.CreateChatPromptParams{
Uuid: uuidVar.String(),
ChatSessionUuid: chatSessionUuid,
Role: "system",
Content: newQuestion,
UserID: userID,
CreatedBy: userID,
UpdatedBy: userID,
})
chatPrompt, err := h.chatService.CreateChatPromptSimple(chatSessionUuid, newQuestion, userID)
if err != nil {
http.Error(w, fmt.Errorf("fail to create prompt: %w", err).Error(), http.StatusInternalServerError)
}
log.Printf("%+v\n", chatPrompt)
}

msgs, err := getAskMessages(h, ctx, w, chatSession, chatUuid, false)
msgs, err := h.chatService.getAskMessages(chatSession, chatUuid, false)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, fmt.Errorf("fail to collect messages: %w", err).Error(), err)
return
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -262,20 +213,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
if shouldReturn {
return
}
// insert ChatMessage into database
chatMessageParams := sqlc_queries.CreateChatMessageParams{
ChatSessionUuid: chatSessionUuid,
Uuid: answerID,
Role: "assistant",
Content: answerText,
UserID: int32(userIDInt),
CreatedBy: int32(userIDInt),
UpdatedBy: int32(userIDInt),
Raw: json.RawMessage([]byte("{}")),
}
log.Println(chatMessageParams)

m, err := h.chatService.q.CreateChatMessage(ctx, chatMessageParams)
m, err := h.chatService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, userID)

log.Println(m)
swuecho marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -285,23 +223,12 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
} else {

// Set up SSE headers
answerText, answerID, shouldReturn := chat_stream(ctx, chatSession, msgs, w)
answerText, answerID, shouldReturn := chat_stream(w, chatSession, msgs)
if shouldReturn {
return
}
swuecho marked this conversation as resolved.
Show resolved Hide resolved
// insert ChatMessage into database
chatMessage := sqlc_queries.CreateChatMessageParams{
Uuid: answerID,
ChatSessionUuid: chatSessionUuid,
Role: "assistant",
UserID: int32(userIDInt),
Content: answerText,
CreatedBy: int32(userIDInt),
UpdatedBy: int32(userIDInt),
Raw: json.RawMessage([]byte("{}")),
}

_, err := h.chatService.q.CreateChatMessage(ctx, chatMessage)
_, err := h.chatService.CreateChatMessageSimple(ctx, chatSessionUuid, answerID, "assistant", answerText, userID)

if err != nil {
RespondWithError(w, http.StatusInternalServerError, fmt.Errorf("fail to create message: %w", err).Error(), nil)
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -310,47 +237,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr

}

func getAskMessages(h *ChatHandler, ctx context.Context, w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chatUuid string, regenerate bool) ([]openai.ChatCompletionMessage, error) {
chatSessionUuid := chatSession.Uuid

lastN := chatSession.MaxLength
if chatSession.MaxLength == 0 {
lastN = 10
}

chat_prompts, err := h.chatService.q.GetChatPromptsBySessionUUID(ctx, chatSessionUuid)

if err != nil {
return nil, fmt.Errorf("fail to get prompt: %w", err)
}

var chat_massages []sqlc_queries.ChatMessage
if regenerate {
chat_massages, err = h.chatService.q.GetLastNChatMessages(ctx,
sqlc_queries.GetLastNChatMessagesParams{
Uuid: chatUuid,
Limit: lastN,
})

} else {
chat_massages, err = h.chatService.q.GetLatestMessagesBySessionUUID(ctx,
sqlc_queries.GetLatestMessagesBySessionUUIDParams{ChatSessionUuid: chatSession.Uuid, Limit: lastN})
}

if err != nil {
return nil, fmt.Errorf("fail to get messages: %w", err)
}
chat_prompt_msgs := lo.Map(chat_prompts, func(m sqlc_queries.ChatPrompt, _ int) openai.ChatCompletionMessage {
return openai.ChatCompletionMessage{Role: m.Role, Content: m.Content}
})
chat_message_msgs := lo.Map(chat_massages, func(m sqlc_queries.ChatMessage, _ int) openai.ChatCompletionMessage {
return openai.ChatCompletionMessage{Role: m.Role, Content: m.Content}
})
msgs := append(chat_prompt_msgs, chat_message_msgs...)
return msgs, nil
}

func chat_stream(ctx context.Context, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []openai.ChatCompletionMessage, w http.ResponseWriter) (string, string, bool) {
func chat_stream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []openai.ChatCompletionMessage) (string, string, bool) {
apiKey := appConfig.OPENAI.API_KEY

client := openai.NewClient(apiKey)
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -366,6 +253,7 @@ func chat_stream(ctx context.Context, chatSession sqlc_queries.ChatSession, chat
// N: n,
Stream: true,
}
ctx := context.Background()
stream, err := client.CreateChatCompletionStream(ctx, openai_req)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, fmt.Sprintf("CompletionStream error: %v", err), nil)
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
88 changes: 88 additions & 0 deletions api/chat_main_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log"
"net/http"
"strings"
"time"

uuid "github.com/iris-contrib/go.uuid"
"github.com/samber/lo"
swuecho marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -235,3 +236,90 @@ func GetAiAnswerOpenApi(msgs []openai.ChatCompletionMessage) (ChatCompletionResp
}
return aiAnswer, nil
}

func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatUuid string, regenerate bool) ([]openai.ChatCompletionMessage, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

chatSessionUuid := chatSession.Uuid

lastN := chatSession.MaxLength
if chatSession.MaxLength == 0 {
lastN = 10
}

chat_prompts, err := s.q.GetChatPromptsBySessionUUID(ctx, chatSessionUuid)

if err != nil {
return nil, fmt.Errorf("fail to get prompt: %w", err)
}

var chat_massages []sqlc_queries.ChatMessage
if regenerate {
chat_massages, err = s.q.GetLastNChatMessages(ctx,
sqlc_queries.GetLastNChatMessagesParams{
Uuid: chatUuid,
Limit: lastN,
})

} else {
chat_massages, err = s.q.GetLatestMessagesBySessionUUID(ctx,
sqlc_queries.GetLatestMessagesBySessionUUIDParams{ChatSessionUuid: chatSession.Uuid, Limit: lastN})
}

if err != nil {
return nil, fmt.Errorf("fail to get messages: %w", err)
}
chat_prompt_msgs := lo.Map(chat_prompts, func(m sqlc_queries.ChatPrompt, _ int) openai.ChatCompletionMessage {
return openai.ChatCompletionMessage{Role: m.Role, Content: m.Content}
})
chat_message_msgs := lo.Map(chat_massages, func(m sqlc_queries.ChatMessage, _ int) openai.ChatCompletionMessage {
return openai.ChatCompletionMessage{Role: m.Role, Content: m.Content}
})
msgs := append(chat_prompt_msgs, chat_message_msgs...)
return msgs, nil
}

swuecho marked this conversation as resolved.
Show resolved Hide resolved
func (s *ChatService) CreateChatPromptSimple(chatSessionUuid string, newQuestion string, userID int32) (sqlc_queries.ChatPrompt, error) {
uuidVar, _ := uuid.NewV4()
chatPrompt, err := s.q.CreateChatPrompt(context.Background(),
sqlc_queries.CreateChatPromptParams{
Uuid: uuidVar.String(),
ChatSessionUuid: chatSessionUuid,
Role: "system",
Content: newQuestion,
UserID: userID,
CreatedBy: userID,
UpdatedBy: userID,
})
return chatPrompt, err
}

// CreateChatMessage creates a new chat message.
func (s *ChatService) CreateChatMessageSimple(ctx context.Context, sessionUuid, uuid, role, content string, userId int32) (sqlc_queries.ChatMessage, error) {

chatMessage := sqlc_queries.CreateChatMessageParams{
ChatSessionUuid: sessionUuid,
Uuid: uuid,
Role: role,
Content: content,
UserID: userId,
CreatedBy: userId,
UpdatedBy: userId,
Raw: json.RawMessage([]byte("{}")),
}
message, err := s.q.CreateChatMessage(ctx, chatMessage)
if err != nil {
return sqlc_queries.ChatMessage{}, fmt.Errorf("failed to create message %w", err)
}
return message, nil
}

// UpdateChatMessageContent
func (s *ChatService) UpdateChatMessageContent(ctx context.Context, uuid, content string) (error) {
err := s.q.UpdateChatMessageContent(ctx, sqlc_queries.UpdateChatMessageContentParams{
Uuid: uuid,
Content: content,
})
return err
}
Loading