Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: context deadline is too short #31

Merged
merged 6 commits into from
Mar 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
return
}

// Set up SSE headers
if chat_session.Debug {
log.Printf("%+v\n", chatCompletionMessages)
}
Expand Down Expand Up @@ -210,7 +209,7 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
http.Error(w, "No messages found", http.StatusNotFound)
}
if msgs[0].Content == "test_demo_bestqa" || msgs[len(msgs)-1].Content == "test_demo_bestqa" {
answerText, answerID, shouldReturn := test_replay(w)
answerText, answerID, shouldReturn := test_replay(w, chatSession, msgs)
if shouldReturn {
return
}
Expand Down Expand Up @@ -241,18 +240,8 @@ func (h *ChatHandler) OpenAIChatCompletionAPIWithStreamHandler(w http.ResponseWr
func chat_stream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []openai.ChatCompletionMessage) (string, string, bool) {
client := openai.NewClient(appConfig.OPENAI.API_KEY)

openai_req := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: chat_compeletion_messages,
MaxTokens: int(chatSession.MaxTokens),
Temperature: float32(chatSession.Temperature),
TopP: float32(chatSession.TopP),
// PresencePenalty: presencePenalty,
// FrequencyPenalty: frequencyPenalty,
// N: n,
Stream: true,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
openai_req := newChatCompletionRequest(chatSession, chat_compeletion_messages)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second)
defer cancel()
stream, err := client.CreateChatCompletionStream(ctx, openai_req)
if err != nil {
Expand Down Expand Up @@ -315,7 +304,7 @@ func chat_stream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, ch
return answer, answer_id, false
}

func test_replay(w http.ResponseWriter) (string, string, bool) {
func test_replay(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []openai.ChatCompletionMessage) (string, string, bool) {
//message := Message{Role: "assitant", Content:}
uuid, _ := uuid.NewV4()
setSSEHeader(w)
Expand All @@ -332,9 +321,36 @@ func test_replay(w http.ResponseWriter) (string, string, bool) {
data, _ := json.Marshal(resp)
fmt.Fprintf(w, "data: %v\n\n", string(data))
flusher.Flush()

if chatSession.Debug {
// PresencePenalty: presencePenalty,
// FrequencyPenalty: frequencyPenalty,
// N: n,
openai_req := newChatCompletionRequest(chatSession, chat_compeletion_messages )
req_j, _ := json.Marshal(openai_req)
log.Println(string(req_j))
answer = answer + "\n" + string(req_j)
req_as_resp := constructChatCompletionStreamReponse(answer_id, answer)
data, _ := json.Marshal(req_as_resp)
fmt.Fprintf(w, "data: %s\n\n", string(data))
flusher.Flush()
}
return answer, answer_id, false
}

func newChatCompletionRequest(chatSession sqlc_queries.ChatSession, chat_compeletion_messages []openai.ChatCompletionMessage) openai.ChatCompletionRequest {
openai_req := openai.ChatCompletionRequest{
Model: openai.GPT3Dot5Turbo,
Messages: chat_compeletion_messages,
MaxTokens: int(chatSession.MaxTokens),
Temperature: float32(chatSession.Temperature),
TopP: float32(chatSession.TopP),

Stream: true,
}
return openai_req
}

func constructChatCompletionStreamReponse(answer_id string, answer string) openai.ChatCompletionStreamResponse {
resp := openai.ChatCompletionStreamResponse{
ID: answer_id,
Expand Down
1 change: 1 addition & 0 deletions api/chat_main_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ func (s *ChatService) getAskMessages(chatSession sqlc_queries.ChatSession, chatU
sqlc_queries.GetLastNChatMessagesParams{
Uuid: chatUuid,
Limit: lastN,
ChatSessionUuid: chatSessionUuid,
})

} else {
Expand Down
31 changes: 7 additions & 24 deletions api/chat_message_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func (s *ChatMessageService) CreateChatMessage(ctx context.Context, message_para
return message, nil
}


// GetChatMessageByID returns a chat message by ID.
func (s *ChatMessageService) GetChatMessageByID(ctx context.Context, id int32) (sqlc_queries.ChatMessage, error) {
message, err := s.q.GetChatMessageByID(ctx, id)
Expand All @@ -43,7 +42,7 @@ func (s *ChatMessageService) GetChatMessageByID(ctx context.Context, id int32) (
func (s *ChatMessageService) UpdateChatMessage(ctx context.Context, message_params sqlc_queries.UpdateChatMessageParams) (sqlc_queries.ChatMessage, error) {
message_u, err := s.q.UpdateChatMessage(ctx, message_params)
if err != nil {
return sqlc_queries.ChatMessage{}, errors.New("failed to update message")
return sqlc_queries.ChatMessage{}, fmt.Errorf("failed to update message %w", err)
}
return message_u, nil
}
Expand All @@ -52,7 +51,7 @@ func (s *ChatMessageService) UpdateChatMessage(ctx context.Context, message_para
func (s *ChatMessageService) DeleteChatMessage(ctx context.Context, id int32) error {
err := s.q.DeleteChatMessage(ctx, id)
if err != nil {
return errors.New("failed to delete message")
return fmt.Errorf("failed to delete message %w", err)
}
return nil
}
Expand All @@ -61,7 +60,7 @@ func (s *ChatMessageService) DeleteChatMessage(ctx context.Context, id int32) er
func (s *ChatMessageService) DeleteChatMessageByUUID(ctx context.Context, uuid string) error {
err := s.q.DeleteChatMessageByUUID(ctx, uuid)
if err != nil {
return errors.New("failed to delete message")
return fmt.Errorf("failed to delete message %w", err)
}
return nil
}
Expand All @@ -70,7 +69,7 @@ func (s *ChatMessageService) DeleteChatMessageByUUID(ctx context.Context, uuid s
func (s *ChatMessageService) GetAllChatMessages(ctx context.Context) ([]sqlc_queries.ChatMessage, error) {
messages, err := s.q.GetAllChatMessages(ctx)
if err != nil {
return nil, errors.New("failed to retrieve messages")
return nil, fmt.Errorf("failed to retrieve messages %w", err)
}
return messages, nil
}
Expand Down Expand Up @@ -120,7 +119,7 @@ func (s *ChatMessageService) GetChatMessageByUUID(ctx context.Context, uuid stri
func (s *ChatMessageService) UpdateChatMessageByUUID(ctx context.Context, message_params sqlc_queries.UpdateChatMessageByUUIDParams) (sqlc_queries.ChatMessage, error) {
message_u, err := s.q.UpdateChatMessageByUUID(ctx, message_params)
if err != nil {
return sqlc_queries.ChatMessage{}, errors.New("failed to update message")
return sqlc_queries.ChatMessage{}, fmt.Errorf("failed to update message %w", err)
}
return message_u, nil
}
Expand All @@ -134,7 +133,7 @@ func (s *ChatMessageService) GetChatMessagesBySessionUUID(ctx context.Context, u
}
message, err := s.q.GetChatMessagesBySessionUUID(ctx, param)
if err != nil {
return []sqlc_queries.ChatMessage{}, errors.New("failed to retrieve message")
return []sqlc_queries.ChatMessage{}, fmt.Errorf("failed to retrieve message %w", err)
}
return message, nil
}
Expand Down Expand Up @@ -190,26 +189,10 @@ func (s *ChatMessageService) GetChatHistoryBySessionUUID(ctx context.Context, uu
return msgs, nil
}

// GetLastNChatMessagesByUUID returns last N chat message related by uuid.
func (s *ChatMessageService) GetLastNChatMessages(ctx context.Context, uuid string, n int32) ([]sqlc_queries.ChatMessage, error) {
param := sqlc_queries.GetLastNChatMessagesParams{
Uuid: uuid,
Limit: n,
}
message, err := s.q.GetLastNChatMessages(ctx, param)
if err != nil {
return []sqlc_queries.ChatMessage{}, errors.New("failed to retrieve message")
}
return message, nil
}

// DeleteChatMessagesBySesionUUID deletes chat messages by session uuid.
func (s *ChatMessageService) DeleteChatMessagesBySesionUUID(ctx context.Context, uuid string) error {
err := s.q.DeleteChatMessagesBySesionUUID(ctx, uuid)
if err != nil {
return errors.New("failed to delete message")
}
return nil
return err
}

func (s *ChatMessageService) GetChatMessagesCount(ctx context.Context, userID int32) (int32, error) {
Expand Down
4 changes: 2 additions & 2 deletions api/sqlc/queries/chat_message.sql
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ FROM chat_message
WHERE chat_message.id in (
SELECT id
FROM chat_message cm
WHERE cm.chat_session_uuid = (SELECT chat_session_uuid FROM chat_message WHERE chat_message.uuid = $1)
AND cm.created_at < (SELECT created_at FROM chat_message WHERE chat_message.uuid = $1)
WHERE cm.chat_session_uuid = $3
AND cm.id < (SELECT id FROM chat_message WHERE chat_message.uuid = $1)
ORDER BY cm.created_at DESC
LIMIT $2
)
Expand Down
11 changes: 6 additions & 5 deletions api/sqlc_queries/chat_message.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions web/src/views/chat/components/Session/SessionConfig.vue
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,19 @@ watch([slider, temperature, maxTokens, topP, debug], ([newValueSlider, newValueT
</script>

<template>
<!-- https://platform.openai.com/playground?mode=chat -->
<NCard style="width: 600px" title="会话设置" :bordered="false" size="huge" role="dialog" aria-modal="true">
<div>{{ $t('chat.slider') }}: {{ slider }}</div>
<NSlider v-model:value="slider" :min="1" :max="20" :tooltip="false" />

<div>{{ $t('chat.temperature') }}: {{ temperature }}</div>
<NSlider v-model:value="temperature" :min="0.1" :max="2" :step="0.1" :tooltip="false" />
<NSlider v-model:value="temperature" :min="0.1" :max="1" :step="0.01" :tooltip="false" />

<div>{{ $t('chat.topP') }}: {{ topP }}</div>
<NSlider v-model:value="topP" :min="0" :max="1" :step="0.1" :tooltip="false" />
<NSlider v-model:value="topP" :min="0" :max="1" :step="0.01" :tooltip="false" />

<div>{{ $t('chat.maxTokens') }}: {{ maxTokens }}</div>
<NSlider v-model:value="maxTokens" :min="256" :max="4096" :step="16" :tooltip="false" />
<NSlider v-model:value="maxTokens" :min="256" :max="2048" :step="16" :tooltip="false" />
<div> {{ $t('chat.debug') }}</div>
<NSwitch v-model:value="debug">
<template #checked>
Expand Down
43 changes: 23 additions & 20 deletions web/src/views/chat/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,28 @@ async function onConversationStream() {
// Check if the chunk is not empty
if (chunk) {
// Parse the JSON data chunk
const data = JSON.parse(chunk)
const answer = data.choices[0].delta.content
const answer_uuid = data.id.replace('chatcmpl-', '') // use answer id as uuid
updateChat(
sessionUuid,
dataSources.value.length - 1,
{
uuid: answer_uuid,
dateTime: new Date().toLocaleString(),
text: answer,
inversion: false,
error: false,
loading: false,
conversationOptions: { conversationId: data.conversationId, parentMessageId: data.id },
requestOptions: { prompt: message, options: { ...options } },
},
)
try {
const data = JSON.parse(chunk)
const answer = data.choices[0].delta.content
const answer_uuid = data.id.replace('chatcmpl-', '') // use answer id as uuid
updateChat(
sessionUuid,
dataSources.value.length - 1,
{
uuid: answer_uuid,
dateTime: new Date().toLocaleString(),
text: answer,
inversion: false,
error: false,
loading: false,
conversationOptions: { conversationId: data.conversationId, parentMessageId: data.id },
requestOptions: { prompt: message, options: { ...options } },
},
)
}
catch (error) {
console.log(error)
}
}
},
)
Expand Down Expand Up @@ -422,9 +427,7 @@ onUnmounted(() => {
/>
<main class="flex-1 overflow-hidden">
<NModal ref="sessionConfigModal" v-model:show="showModal">
<SessionConfig
ref="sessionConfig" :uuid="sessionUuid"
/>
<SessionConfig ref="sessionConfig" :uuid="sessionUuid" />
</NModal>
<div id="scrollRef" ref="scrollRef" class="h-full overflow-hidden overflow-y-auto">
<div
Expand Down