From a6f47c3ba6800dbcb9a1a72d57288865aee504f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Serta=C3=A7=20=C3=96zercan?= <852750+sozercan@users.noreply.github.com> Date: Sun, 10 Mar 2024 14:36:46 -0700 Subject: [PATCH] chore: use toolcalls (#116) Signed-off-by: Sertac Ozercan --- cmd/cli/completion.go | 13 +++++------ cmd/cli/functions.go | 8 +++---- cmd/cli/openai.go | 50 ++++++++++++++++++++++++------------------- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/cmd/cli/completion.go b/cmd/cli/completion.go index 14b36aa..3f3b0a6 100644 --- a/cmd/cli/completion.go +++ b/cmd/cli/completion.go @@ -12,8 +12,11 @@ import ( openai "github.com/sashabaranov/go-openai" "github.com/sethvargo/go-retry" + log "github.com/sirupsen/logrus" ) +const maxRetries = 10 + type oaiClients struct { openAIClient openai.Client } @@ -72,20 +75,18 @@ func gptCompletion(ctx context.Context, client oaiClients, prompts []string) (st var resp string var err error - r := retry.WithMaxRetries(10, retry.NewExponential(1*time.Second)) + r := retry.WithMaxRetries(maxRetries, retry.NewExponential(1*time.Second)) if err := retry.Do(ctx, r, func(ctx context.Context) error { resp, err = client.openaiGptChatCompletion(ctx, &prompt, temp) - requestErr := &openai.RequestError{} + requestErr := &openai.APIError{} if errors.As(err, &requestErr) { switch requestErr.HTTPStatusCode { - case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusServiceUnavailable: + case http.StatusTooManyRequests, http.StatusRequestTimeout, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + log.Debugf("retrying due to status code %d: %s", requestErr.HTTPStatusCode, requestErr.Message) return retry.RetryableError(err) } } - if err != nil { - return err - } return nil }); err != nil { return "", err diff --git a/cmd/cli/functions.go b/cmd/cli/functions.go index 41d4b33..6f1577b 100644 --- a/cmd/cli/functions.go +++ b/cmd/cli/functions.go @@ -69,17 +69,17 @@ func (s *schema) Run() (content string, err error) { return string(schemaBytes), nil } -func funcCall(call *openai.FunctionCall) (string, error) { - switch call.Name { +func callTool(toolCall openai.ToolCall) (string, error) { + switch toolCall.Function.Name { case findSchemaNames.Name: var f schemaNames - if err := json.Unmarshal([]byte(call.Arguments), &f); err != nil { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &f); err != nil { return "", err } return f.Run() case getSchema.Name: var f schema - if err := json.Unmarshal([]byte(call.Arguments), &f); err != nil { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &f); err != nil { return "", err } return f.Run() diff --git a/cmd/cli/openai.go b/cmd/cli/openai.go index 2b9fb59..44fe1ac 100644 --- a/cmd/cli/openai.go +++ b/cmd/cli/openai.go @@ -9,24 +9,24 @@ import ( log "github.com/sirupsen/logrus" ) -type functionCallType string +type toolChoiceType string const ( - fnCallAuto functionCallType = "auto" - fnCallNone functionCallType = "none" + toolChoiceAuto toolChoiceType = "auto" + toolChoiceNone toolChoiceType = "none" ) func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt *strings.Builder, temp float32) (string, error) { var ( - resp openai.ChatCompletionResponse - req openai.ChatCompletionRequest - funcName *openai.FunctionCall - content string - err error + resp openai.ChatCompletionResponse + req openai.ChatCompletionRequest + content string + err error ) // if we are using the k8s API, we need to call the functions - fnCallType := fnCallAuto + toolChoiseType := toolChoiceAuto + for { prompt.WriteString(content) log.Debugf("prompt: %s", prompt.String()) @@ -44,12 +44,17 @@ func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt *string } if *usek8sAPI { - // TODO: migrate to tools api - req.Functions = []openai.FunctionDefinition{ // nolint:staticcheck - findSchemaNames, - getSchema, + req.Tools = []openai.Tool{ + { + Type: "function", + Function: &findSchemaNames, + }, + { + Type: "function", + Function: &getSchema, + }, } - req.FunctionCall = fnCallType // nolint:staticcheck + req.ToolChoice = toolChoiseType } resp, err = c.openAIClient.CreateChatCompletion(ctx, req) @@ -57,17 +62,18 @@ func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt *string return "", err } - funcName = resp.Choices[0].Message.FunctionCall - // if there is no function call, we are done - if funcName == nil { + if len(resp.Choices[0].Message.ToolCalls) == 0 { break } - log.Debugf("calling function: %s", funcName.Name) - // if there is a function call, we need to call it and get the result - content, err = funcCall(funcName) - if err != nil { - return "", err + for _, tool := range resp.Choices[0].Message.ToolCalls { + log.Debugf("calling tool: %s", tool.Function.Name) + + // if there is a tool call, we need to call it and get the result + content, err = callTool(tool) + if err != nil { + return "", err + } } }