feat(functions): allow to set JSON matcher #2319

Merged · 1 commit · May 14, 2024

13 changes: 10 additions & 3 deletions core/http/endpoints/openai/chat.go
@@ -81,7 +81,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
responses <- initialMessage

result, err := handleQuestion(config, req, ml, startupOptions, results, prompt)
result, err := handleQuestion(config, req, ml, startupOptions, results, result, prompt)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@@ -470,7 +470,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup

switch {
case noActionsToRun:
result, err := handleQuestion(config, input, ml, startupOptions, results, predInput)
result, err := handleQuestion(config, input, ml, startupOptions, results, s, predInput)
if err != nil {
log.Error().Err(err).Msg("error handling question")
return
@@ -550,7 +550,14 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
}
}

func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, prompt string) (string, error) {
func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {

if len(funcResults) == 0 && result != "" {
log.Debug().Msgf("nothing function results but we had a message from the LLM")

return result, nil
}

log.Debug().Msgf("nothing to do, computing a reply")
arg := ""
if len(funcResults) > 0 {
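The net effect of the chat.go changes: the raw LLM output is now threaded into handleQuestion, which returns it as-is when the model answered in plain text instead of emitting tool calls, rather than issuing a second completion. A toy, self-contained sketch of just that decision (types and names are simplified for illustration and are not LocalAI's actual signatures):

```go
package main

import "fmt"

// funcCallResult stands in for functions.FuncCallResults, simplified for this sketch.
type funcCallResult struct {
	Name      string
	Arguments string
}

// answerOrReply mirrors the early return added to handleQuestion in this PR:
// with no tool calls but a non-empty model message, the message itself is the answer.
func answerOrReply(funcResults []funcCallResult, result string) (answer string, done bool) {
	if len(funcResults) == 0 && result != "" {
		return result, true // plain-text answer, no second completion needed
	}
	return "", false // fall through to computing a reply from the function results
}

func main() {
	answer, done := answerOrReply(nil, "The weather in Rome is sunny.")
	fmt.Println(done, answer) // prints: true The weather in Rome is sunny.
}
```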
84 changes: 58 additions & 26 deletions pkg/functions/parse.go
@@ -2,6 +2,7 @@ package functions

import (
"encoding/json"
"fmt"
"regexp"

"github.com/go-skynet/LocalAI/pkg/utils"
@@ -16,6 +17,8 @@ type FunctionsConfig struct {
NoGrammar bool `yaml:"no_grammar"`
ResponseRegex string `yaml:"response_regex"`

JSONRegexMatch string `yaml:"json_regex_match"`

// FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }
// instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }.
// This might be useful for certain models trained with the function name as the first token.
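The new JSONRegexMatch field is user-facing configuration: when set, the parser uses it to cut the JSON tool call out of whatever wrapper the model puts around it. A minimal sketch of a model config using it, where the json_regex_match and no_grammar key names come from the yaml tags in this struct, while the surrounding layout (a top-level function: block and name: key) is an assumption about LocalAI's model YAML:

```yaml
# Hypothetical model config snippet; only no_grammar and json_regex_match are
# taken from this PR's struct tags, the surrounding layout is assumed.
name: my-function-calling-model
function:
  no_grammar: true
  # Extract the JSON object the model wraps in <tool_call> tags (pattern from the PR's tests).
  json_regex_match: "(?s)<tool_call>(.*?)</tool_call>"
```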
@@ -38,6 +41,36 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC

results := []FuncCallResults{}

returnResult := func(s string) (name, arguments string, e error) {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevents newlines from breaking JSON parsing for clients
s = utils.EscapeNewLines(s)
err := json.Unmarshal([]byte(s), &ss)
if err != nil {
log.Error().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
}
log.Debug().Msgf("Function return: %s %+v", s, ss)

// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss[functionNameKey]
if !ok {
return "", "", fmt.Errorf("unable to find function name in result")
}
// Similarly, while arguments is a map[string]interface{} here, OpenAI expects a stringified JSON object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
return "", "", fmt.Errorf("unable to find arguments in result")
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
return "", "", fmt.Errorf("unable to cast function name to string")
}

return funcName, string(d), nil
}

// if no grammar is used, we have to extract function and arguments from the result
if !useGrammars {
// the response is a string that we have to parse
@@ -61,13 +94,32 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
if functionName == "" {
return results
}
} else if functionConfig.JSONRegexMatch != "" {
//re := regexp.MustCompile(`(?s)<tool_call>(.*?)</tool_call>`)
//m:= re.FindStringSubmatch(`<tool_call>{ foo barr }</tool_call>`)

// We use a regex to extract the JSON object from the response
var respRegex = regexp.MustCompile(functionConfig.JSONRegexMatch)
match := respRegex.FindStringSubmatch(llmresult)
if len(match) < 2 {
return results
}

funcName, args, err := returnResult(match[1])
if err != nil {
return results
}

return append(results, FuncCallResults{Name: funcName, Arguments: args})

} else {
// We expect the result to be a JSON object with a function name and arguments
err := json.Unmarshal([]byte(llmresult), &result)

funcName, args, err := returnResult(llmresult)
if err != nil {
log.Error().Err(err).Str("llmresult", llmresult).Msg("unable to unmarshal llm result")
return results
}

return append(results, FuncCallResults{Name: funcName, Arguments: args})
}

return append(results, FuncCallResults{Name: result[functionNameKey], Arguments: result["arguments"]})
@@ -101,32 +153,12 @@ func ParseFunctionCall(llmresult string, functionConfig FunctionsConfig) []FuncC
results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})
}
} else {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
err := json.Unmarshal([]byte(s), &ss)
funcName, args, err := returnResult(llmresult)
if err != nil {
log.Error().Err(err).Str("escapedLLMResult", s).Msg("unable to unmarshal llm result")
}
log.Debug().Msgf("Function return: %s %+v", s, ss)

// The grammar defines the function name as "function", while OpenAI returns "name"
func_name, ok := ss[functionNameKey]
if !ok {
return results
}
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args, ok := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
if !ok {
return results
}
d, _ := json.Marshal(args)
funcName, ok := func_name.(string)
if !ok {
return results
}
results = append(results, FuncCallResults{Name: funcName, Arguments: string(d)})

results = append(results, FuncCallResults{Name: funcName, Arguments: args})
}

return results
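Putting the parse.go changes together: with no_grammar enabled and JSONRegexMatch set, ParseFunctionCall extracts the JSON object from the model's wrapped reply and returns the function name plus the arguments as a JSON string. A small usage sketch; the import path is inferred from the github.com/go-skynet/LocalAI/pkg/utils import visible above, and everything else uses identifiers from this diff:

```go
package main

import (
	"fmt"

	// Import path inferred from the module path seen in this diff; adjust if it differs.
	"github.com/go-skynet/LocalAI/pkg/functions"
)

func main() {
	// A model reply that wraps the tool call in <tool_call> tags, as in the PR's tests.
	llmresult := `<tool_call>
{"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>`

	cfg := functions.FunctionsConfig{
		NoGrammar:      true,
		JSONRegexMatch: `(?s)<tool_call>(.*?)</tool_call>`,
	}

	// With the regex set, ParseFunctionCall pulls the JSON out of the wrapper and
	// returns the function name plus the stringified arguments.
	for _, call := range functions.ParseFunctionCall(llmresult, cfg) {
		fmt.Printf("name=%s arguments=%s\n", call.Name, call.Arguments)
		// Expected output: name=add arguments={"x":5,"y":3}
	}
}
```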
37 changes: 35 additions & 2 deletions pkg/functions/parse_test.go
@@ -87,7 +87,7 @@ var _ = Describe("LocalAI function parse tests", func() {
It("should parse the function name and arguments correctly with the name key", func() {
input := `{"name": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = false
functionConfig.NoGrammar = true
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = true

@@ -100,7 +100,40 @@ var _ = Describe("LocalAI function parse tests", func() {
It("should parse the function name and arguments correctly with the function key", func() {
input := `{"function": "add", "arguments": {"x": 5, "y": 3}}`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = false
functionConfig.NoGrammar = true
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
Expect(results[0].Name).To(Equal("add"))
Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
})

It("Should parse the result by matching the JSONRegexMatch", func() {
input := `
<tool_call>
{"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true
functionConfig.JSONRegexMatch = `(?s)<tool_call>(.*?)</tool_call>`
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false

results := ParseFunctionCall(input, functionConfig)
Expect(results).To(HaveLen(1))
Expect(results[0].Name).To(Equal("add"))
Expect(results[0].Arguments).To(Equal(`{"x":5,"y":3}`))
})

It("Should parse the result by matching the JSONRegexMatch", func() {
input := `
{"function": "add", "arguments": {"x": 5, "y": 3}}
</tool_call>`
functionConfig.ParallelCalls = false
functionConfig.NoGrammar = true
functionConfig.JSONRegexMatch = `(?s)(.*?)</tool_call>`
functionConfig.ResponseRegex = ""
functionConfig.FunctionName = false
