Skip to content

Commit

Permalink
feat: implement -u, --user-instead-of-system: Use the user role inste…
Browse files Browse the repository at this point in the history
…ad of the system role for the pattern. It is needed for Open AI o1 models for now.
  • Loading branch information
eugeis committed Sep 15, 2024
1 parent 19a0b8a commit 329c843
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 55 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ Application Options:
-T, --topp= Set top P (default: 0.9)
-s, --stream Stream
-P, --presencepenalty= Set presence penalty (default: 0.0)
-u, --user-instead-of-system Use the user role instead of the system role for the pattern
-F, --frequencypenalty= Set frequency penalty (default: 0.0)
-l, --listpatterns List all patterns
-L, --listmodels List all available models
Expand Down
10 changes: 6 additions & 4 deletions cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Flags struct {
TopP float64 `short:"T" long:"topp" description:"Set top P" default:"0.9"`
Stream bool `short:"s" long:"stream" description:"Stream"`
PresencePenalty float64 `short:"P" long:"presencepenalty" description:"Set presence penalty" default:"0.0"`
UserInsteadOfSystemRole bool `short:"u" long:"user-instead-of-system" description:"Use the user role instead of the system role for the pattern"`
FrequencyPenalty float64 `short:"F" long:"frequencypenalty" description:"Set frequency penalty" default:"0.0"`
ListPatterns bool `short:"l" long:"listpatterns" description:"List all patterns"`
ListAllModels bool `short:"L" long:"listmodels" description:"List all available models"`
Expand Down Expand Up @@ -89,10 +90,11 @@ func readStdin() (string, error) {

func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
ret = &common.ChatOptions{
Temperature: o.Temperature,
TopP: o.TopP,
PresencePenalty: o.PresencePenalty,
FrequencyPenalty: o.FrequencyPenalty,
Temperature: o.Temperature,
TopP: o.TopP,
PresencePenalty: o.PresencePenalty,
FrequencyPenalty: o.FrequencyPenalty,
UserInsteadOfSystemRole: o.UserInsteadOfSystemRole,
}
return
}
Expand Down
9 changes: 5 additions & 4 deletions cli/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ func TestBuildChatOptions(t *testing.T) {
}

expectedOptions := &common.ChatOptions{
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
Temperature: 0.8,
TopP: 0.9,
PresencePenalty: 0.1,
FrequencyPenalty: 0.2,
UserInsteadOfSystemRole: false,
}
options := flags.BuildChatOptions()
assert.Equal(t, expectedOptions, options)
Expand Down
17 changes: 10 additions & 7 deletions common/domain.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package common

import goopenai "github.com/sashabaranov/go-openai"

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Expand All @@ -14,11 +16,12 @@ type ChatRequest struct {
}

type ChatOptions struct {
Model string
Temperature float64
TopP float64
PresencePenalty float64
FrequencyPenalty float64
Model string
Temperature float64
TopP float64
PresencePenalty float64
FrequencyPenalty float64
UserInsteadOfSystemRole bool
}

// NormalizeMessages remove empty messages and ensure messages order user-assist-user
Expand All @@ -32,8 +35,8 @@ func NormalizeMessages(msgs []*Message, defaultUserMessage string) (ret []*Messa
}

// Ensure, that each odd position shall be a user message
if fullMessageIndex%2 == 0 && message.Role != "user" {
ret = append(ret, &Message{Role: "user", Content: defaultUserMessage})
if fullMessageIndex%2 == 0 && message.Role != goopenai.ChatMessageRoleUser {
ret = append(ret, &Message{Role: goopenai.ChatMessageRoleUser, Content: defaultUserMessage})
fullMessageIndex++
}
ret = append(ret, message)
Expand Down
17 changes: 9 additions & 8 deletions common/domain_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
package common

import (
goopenai "github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
"testing"
)

func TestNormalizeMessages(t *testing.T) {
msgs := []*Message{
{Role: "user", Content: "Hello"},
{Role: "bot", Content: "Hi there!"},
{Role: "bot", Content: ""},
{Role: "user", Content: ""},
{Role: "user", Content: "How are you?"},
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
{Role: goopenai.ChatMessageRoleUser, Content: ""},
{Role: goopenai.ChatMessageRoleUser, Content: ""},
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
}

expected := []*Message{
{Role: "user", Content: "Hello"},
{Role: "bot", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
{Role: goopenai.ChatMessageRoleUser, Content: "Hello"},
{Role: goopenai.ChatMessageRoleAssistant, Content: "Hi there!"},
{Role: goopenai.ChatMessageRoleUser, Content: "How are you?"},
}

actual := NormalizeMessages(msgs, "default")
Expand Down
6 changes: 3 additions & 3 deletions core/chatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package core
import (
"context"
"fmt"

"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/db"
"github.com/danielmiessler/fabric/vendors"
goopenai "github.com/sashabaranov/go-openai"
)

type Chatter struct {
Expand All @@ -26,7 +26,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
}

var session *db.Session
if session, err = chatRequest.BuildChatSession(); err != nil {
if session, err = chatRequest.BuildChatSession(opts.UserInsteadOfSystemRole); err != nil {
return
}

Expand All @@ -53,7 +53,7 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (m
}

if chatRequest.Session != nil && message != "" {
chatRequest.Session.Append(&common.Message{Role: "system", Content: message})
chatRequest.Session.Append(&common.Message{Role: goopenai.ChatMessageRoleAssistant, Content: message})
err = o.db.Sessions.SaveSession(chatRequest.Session)
}
return
Expand Down
2 changes: 1 addition & 1 deletion core/chatter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestBuildChatSession(t *testing.T) {
Pattern: "test pattern",
Message: "test message",
}
session, err := chat.BuildChatSession()
session, err := chat.BuildChatSession(false)
if err != nil {
t.Fatalf("BuildChatSession() error = %v", err)
}
Expand Down
23 changes: 15 additions & 8 deletions core/fabric.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"github.com/danielmiessler/fabric/vendors/groq"
goopenai "github.com/sashabaranov/go-openai"
"os"
"strconv"
"strings"
Expand Down Expand Up @@ -236,7 +237,7 @@ func (o *Fabric) CreateOutputFile(message string, fileName string) (err error) {
return
}

func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
func (o *Chat) BuildChatSession(userInsteadOfSystemRole bool) (ret *db.Session, err error) {
// new messages will be appended to the session and used to send the message
if o.Session != nil {
ret = o.Session
Expand All @@ -245,14 +246,20 @@ func (o *Chat) BuildChatSession() (ret *db.Session, err error) {
}

systemMessage := strings.TrimSpace(o.Context) + strings.TrimSpace(o.Pattern)

if systemMessage != "" {
ret.Append(&common.Message{Role: "system", Content: systemMessage})
}

userMessage := strings.TrimSpace(o.Message)
if userMessage != "" {
ret.Append(&common.Message{Role: "user", Content: userMessage})

if userInsteadOfSystemRole {
message := systemMessage + userMessage
if message != "" {
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: message})
}
} else {
if systemMessage != "" {
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleSystem, Content: systemMessage})
}
if userMessage != "" {
ret.Append(&common.Message{Role: goopenai.ChatMessageRoleUser, Content: userMessage})
}
}

if ret.IsEmpty() {
Expand Down
5 changes: 2 additions & 3 deletions vendors/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
goopenai "github.com/sashabaranov/go-openai"

"github.com/danielmiessler/fabric/common"
"github.com/liushuangls/go-anthropic/v2"
Expand Down Expand Up @@ -121,10 +122,8 @@ func (an *Client) toMessages(msgs []*common.Message) (ret []anthropic.Message) {
for _, msg := range normalizedMessages {
var message anthropic.Message
switch msg.Role {
case "user":
case goopenai.ChatMessageRoleUser:
message = anthropic.NewUserTextMessage(msg.Content)
case "system":
message = anthropic.NewAssistantTextMessage(msg.Content)
default:
message = anthropic.NewAssistantTextMessage(msg.Content)
}
Expand Down
17 changes: 11 additions & 6 deletions vendors/dryrun/dryrun.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
goopenai "github.com/sashabaranov/go-openai"

"github.com/danielmiessler/fabric/common"
)
Expand Down Expand Up @@ -35,9 +36,11 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch

for _, msg := range msgs {
switch msg.Role {
case "system":
case goopenai.ChatMessageRoleSystem:
output += fmt.Sprintf("System:\n%s\n\n", msg.Content)
case "user":
case goopenai.ChatMessageRoleAssistant:
output += fmt.Sprintf("Assistant:\n%s\n\n", msg.Content)
case goopenai.ChatMessageRoleUser:
output += fmt.Sprintf("User:\n%s\n\n", msg.Content)
default:
output += fmt.Sprintf("%s:\n%s\n\n", msg.Role, msg.Content)
Expand All @@ -56,14 +59,16 @@ func (c *Client) SendStream(msgs []*common.Message, opts *common.ChatOptions, ch
return nil
}

func (c *Client) Send(ctx context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
func (c *Client) Send(_ context.Context, msgs []*common.Message, opts *common.ChatOptions) (string, error) {
fmt.Println("Dry run: Would send the following request:")

for _, msg := range msgs {
switch msg.Role {
case "system":
case goopenai.ChatMessageRoleSystem:
fmt.Printf("System:\n%s\n\n", msg.Content)
case "user":
case goopenai.ChatMessageRoleAssistant:
fmt.Printf("Assistant:\n%s\n\n", msg.Content)
case goopenai.ChatMessageRoleUser:
fmt.Printf("User:\n%s\n\n", msg.Content)
default:
fmt.Printf("%s:\n%s\n\n", msg.Role, msg.Content)
Expand All @@ -84,6 +89,6 @@ func (c *Client) Setup() error {
return nil
}

func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) {
func (c *Client) SetupFillEnvFileContent(_ *bytes.Buffer) {
// No environment variables needed for dry run
}
12 changes: 1 addition & 11 deletions vendors/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,7 @@ func (o *Client) buildChatCompletionRequest(
msgs []*common.Message, opts *common.ChatOptions,
) (ret goopenai.ChatCompletionRequest) {
messages := lo.Map(msgs, func(message *common.Message, _ int) goopenai.ChatCompletionMessage {
var role string

switch message.Role {
case "user":
role = goopenai.ChatMessageRoleUser
case "system":
role = goopenai.ChatMessageRoleSystem
default:
role = goopenai.ChatMessageRoleSystem
}
return goopenai.ChatCompletionMessage{Role: role, Content: message.Content}
return goopenai.ChatCompletionMessage{Role: message.Role, Content: message.Content}
})

ret = goopenai.ChatCompletionRequest{
Expand Down

0 comments on commit 329c843

Please sign in to comment.