Skip to content

Commit

Permalink
chore: use toolcalls (#116)
Browse files Browse the repository at this point in the history
Signed-off-by: Sertac Ozercan <[email protected]>
  • Loading branch information
sozercan committed Mar 10, 2024
1 parent 0043908 commit a6f47c3
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
13 changes: 7 additions & 6 deletions cmd/cli/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ import (

openai "github.com/sashabaranov/go-openai"
"github.com/sethvargo/go-retry"
log "github.com/sirupsen/logrus"
)

const maxRetries = 10

type oaiClients struct {
openAIClient openai.Client
}
Expand Down Expand Up @@ -72,20 +75,18 @@ func gptCompletion(ctx context.Context, client oaiClients, prompts []string) (st

var resp string
var err error
r := retry.WithMaxRetries(10, retry.NewExponential(1*time.Second))
r := retry.WithMaxRetries(maxRetries, retry.NewExponential(1*time.Second))
if err := retry.Do(ctx, r, func(ctx context.Context) error {
resp, err = client.openaiGptChatCompletion(ctx, &prompt, temp)

requestErr := &openai.RequestError{}
requestErr := &openai.APIError{}
if errors.As(err, &requestErr) {
switch requestErr.HTTPStatusCode {
case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusServiceUnavailable:
case http.StatusTooManyRequests, http.StatusRequestTimeout, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
log.Debugf("retrying due to status code %d: %s", requestErr.HTTPStatusCode, requestErr.Message)
return retry.RetryableError(err)
}
}
if err != nil {
return err
}
return nil
}); err != nil {
return "", err
Expand Down
8 changes: 4 additions & 4 deletions cmd/cli/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,17 @@ func (s *schema) Run() (content string, err error) {
return string(schemaBytes), nil
}

func funcCall(call *openai.FunctionCall) (string, error) {
switch call.Name {
func callTool(toolCall openai.ToolCall) (string, error) {
switch toolCall.Function.Name {
case findSchemaNames.Name:
var f schemaNames
if err := json.Unmarshal([]byte(call.Arguments), &f); err != nil {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &f); err != nil {
return "", err
}
return f.Run()
case getSchema.Name:
var f schema
if err := json.Unmarshal([]byte(call.Arguments), &f); err != nil {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &f); err != nil {
return "", err
}
return f.Run()
Expand Down
50 changes: 28 additions & 22 deletions cmd/cli/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@ import (
log "github.com/sirupsen/logrus"
)

type functionCallType string
type toolChoiceType string

const (
fnCallAuto functionCallType = "auto"
fnCallNone functionCallType = "none"
toolChoiceAuto toolChoiceType = "auto"
toolChoiceNone toolChoiceType = "none"
)

func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt *strings.Builder, temp float32) (string, error) {
var (
resp openai.ChatCompletionResponse
req openai.ChatCompletionRequest
funcName *openai.FunctionCall
content string
err error
resp openai.ChatCompletionResponse
req openai.ChatCompletionRequest
content string
err error
)

// if we are using the k8s API, we need to call the functions
fnCallType := fnCallAuto
toolChoiseType := toolChoiceAuto

for {
prompt.WriteString(content)
log.Debugf("prompt: %s", prompt.String())
Expand All @@ -44,30 +44,36 @@ func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt *string
}

if *usek8sAPI {
// TODO: migrate to tools api
req.Functions = []openai.FunctionDefinition{ // nolint:staticcheck
findSchemaNames,
getSchema,
req.Tools = []openai.Tool{
{
Type: "function",
Function: &findSchemaNames,
},
{
Type: "function",
Function: &getSchema,
},
}
req.FunctionCall = fnCallType // nolint:staticcheck
req.ToolChoice = toolChoiseType
}

resp, err = c.openAIClient.CreateChatCompletion(ctx, req)
if err != nil {
return "", err
}

funcName = resp.Choices[0].Message.FunctionCall
// if there is no function call, we are done
if funcName == nil {
if len(resp.Choices[0].Message.ToolCalls) == 0 {
break
}
log.Debugf("calling function: %s", funcName.Name)

// if there is a function call, we need to call it and get the result
content, err = funcCall(funcName)
if err != nil {
return "", err
for _, tool := range resp.Choices[0].Message.ToolCalls {
log.Debugf("calling tool: %s", tool.Function.Name)

// if there is a tool call, we need to call it and get the result
content, err = callTool(tool)
if err != nil {
return "", err
}
}
}

Expand Down

0 comments on commit a6f47c3

Please sign in to comment.