Skip to content

Commit

Permalink
Make the role configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
kardolus committed Sep 15, 2023
1 parent 85b85d8 commit 1c073b1
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 20 deletions.
23 changes: 12 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,18 @@ values, the `config.yaml` file, and environment variables, in that respective or

Configuration variables:

| Variable | Description | Default |
|--------------------|-----------------------------------------------------------------------------------|--------------------------|
| `name` | The prefix for environment variable overrides. | 'openai' |
| `api_key` | Your OpenAI API key. | (none for security) |
| `model` | The GPT model used by the application. | 'gpt-3.5-turbo' |
| `max_tokens` | The maximum number of tokens that can be used in a single API call. | 4096 |
| `thread` | The name of the current chat thread. Each unique thread name has its own context. | 'default' |
| `omit_history` | If true, the chat history will not be used to provide context for the GPT model. | false |
| `url` | The base URL for the OpenAI API. | 'https://api.openai.com' |
| `completions_path` | The API endpoint for completions. | '/v1/chat/completions' |
| `models_path` | The API endpoint for accessing model information. | '/v1/models' |
| Variable | Description | Default |
|--------------------|-----------------------------------------------------------------------------------|--------------------------------|
| `name` | The prefix for environment variable overrides. | 'openai' |
| `api_key` | Your OpenAI API key. | (none for security) |
| `model` | The GPT model used by the application. | 'gpt-3.5-turbo' |
| `max_tokens` | The maximum number of tokens that can be used in a single API call. | 4096 |
| `role` | The system role | 'You are a helpful assistant.' |
| `thread` | The name of the current chat thread. Each unique thread name has its own context. | 'default' |
| `omit_history` | If true, the chat history will not be used to provide context for the GPT model. | false |
| `url` | The base URL for the OpenAI API. | 'https://api.openai.com' |
| `completions_path` | The API endpoint for completions. | '/v1/chat/completions' |
| `models_path` | The API endpoint for accessing model information. | '/v1/models' |

The defaults can be overridden by providing your own values in the user configuration file,
named `.chatgpt-cli/config.yaml`, located in your home directory.
Expand Down
6 changes: 3 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
)

const (
AssistantContent = "You are a helpful assistant."
AssistantRole = "assistant"
ErrEmptyResponse = "empty response"
MaxTokenBufferPercentage = 20
Expand Down Expand Up @@ -180,10 +179,11 @@ func (c *Client) initHistory() {

if len(c.History) == 0 {
c.History = []types.Message{{
Role: SystemRole,
Content: AssistantContent,
Role: SystemRole,
}}
}

c.History[0].Content = c.Config.Role
}

func (c *Client) addQuery(query string) {
Expand Down
12 changes: 7 additions & 5 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const (
defaultCompletionsPath = "/default/completions"
defaultModelsPath = "/default/models"
defaultThread = "default-thread"
defaultRole = "You are a great default-role"
envApiKey = "api-key"
)

Expand Down Expand Up @@ -209,7 +210,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
history := []types.Message{
{
Role: client.SystemRole,
Content: client.AssistantContent,
Content: defaultRole,
},
{
Role: client.UserRole,
Expand Down Expand Up @@ -252,7 +253,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
history := []types.Message{
{
Role: client.SystemRole,
Content: client.AssistantContent,
Content: defaultRole,
},
{
Role: client.UserRole,
Expand Down Expand Up @@ -352,7 +353,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
history := []types.Message{
{
Role: client.SystemRole,
Content: client.AssistantContent,
Content: defaultRole,
},
{
Role: client.UserRole,
Expand Down Expand Up @@ -433,7 +434,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {

systemMessage := subject.History[0]
Expect(systemMessage.Role).To(Equal(client.SystemRole))
Expect(systemMessage.Content).To(Equal("You are a helpful assistant."))
Expect(systemMessage.Content).To(Equal(defaultRole))

contextMessage := subject.History[1]
Expect(contextMessage.Role).To(Equal(client.UserRole))
Expand All @@ -458,7 +459,7 @@ func createMessages(history []types.Message, query string) []types.Message {
if len(history) == 0 {
messages = append(messages, types.Message{
Role: client.SystemRole,
Content: client.AssistantContent,
Content: defaultRole,
})
} else {
messages = history
Expand Down Expand Up @@ -486,6 +487,7 @@ func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStor
URL: defaultURL,
CompletionsPath: defaultCompletionsPath,
ModelsPath: defaultModelsPath,
Role: defaultRole,
Thread: defaultThread,
}).Times(1)

Expand Down
2 changes: 2 additions & 0 deletions config/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
openAIURL = "https://api.openai.com"
openAICompletionsPath = "/v1/chat/completions"
openAIModelsPath = "/v1/models"
openAIRole = "You are a helpful assistant."
openAIThread = "default"
)

Expand Down Expand Up @@ -51,6 +52,7 @@ func (f *FileIO) ReadDefaults() types.Config {
return types.Config{
Name: openAIName,
Model: openAIModel,
Role: openAIRole,
MaxTokens: openAIModelMaxTokens,
URL: openAIURL,
CompletionsPath: openAICompletionsPath,
Expand Down
65 changes: 65 additions & 0 deletions configmanager/configmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
defaultName = "default-name"
defaultURL = "default-url"
defaultModel = "default-model"
defaultRole = "default-role"
defaultApiKey = "default-api-key"
defaultThread = "default-thread"
defaultCompletionsPath = "default-completions-path"
Expand Down Expand Up @@ -55,6 +56,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
CompletionsPath: defaultCompletionsPath,
ModelsPath: defaultModelsPath,
OmitHistory: defaultOmitHistory,
Role: defaultRole,
Thread: defaultThread,
}

Expand Down Expand Up @@ -87,6 +89,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the user provided model", func() {
Expand All @@ -105,6 +108,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the user provided name", func() {
Expand All @@ -123,6 +127,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the user provided max-tokens", func() {
Expand All @@ -141,6 +146,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))

})
Expand All @@ -160,6 +166,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the user provided completions-path", func() {
Expand All @@ -178,6 +185,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the user provided models-path", func() {
Expand All @@ -196,6 +204,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(modelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the user provided api-key", func() {
Expand All @@ -214,6 +223,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(apiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the user provided omit-history", func() {
Expand All @@ -231,6 +241,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
Expect(subject.Config.OmitHistory).To(Equal(omitHistory))
})
Expand All @@ -250,8 +261,28 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(userThread))
})
it("gives precedence to the user provided role", func() {
userRole := "user-role"

mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{Role: userRole}, nil).Times(1)

subject := configmanager.New(mockConfigStore).WithEnvironment()

Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Thread).To(Equal(defaultThread))
Expect(subject.Config.Role).To(Equal(userRole))
})
it("gives precedence to the OMIT_HISTORY environment variable", func() {
var (
environmentValue = true
Expand All @@ -272,6 +303,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
Expect(subject.Config.OmitHistory).To(Equal(environmentValue))
})
Expand All @@ -296,8 +328,33 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(environmentValue))
})
it("gives precedence to the ROLE environment variable", func() {
var (
environmentValue = "env-role"
configValue = "conf-role"
)

Expect(os.Setenv(envPrefix+"ROLE", environmentValue)).To(Succeed())

mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
mockConfigStore.EXPECT().Read().Return(types.Config{Role: configValue}, nil).Times(1)

subject := configmanager.New(mockConfigStore).WithEnvironment()

Expect(subject.Config.Name).To(Equal(defaultName))
Expect(subject.Config.Model).To(Equal(defaultModel))
Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens))
Expect(subject.Config.URL).To(Equal(defaultURL))
Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath))
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Thread).To(Equal(defaultThread))
Expect(subject.Config.Role).To(Equal(environmentValue))
})
it("gives precedence to the API_KEY environment variable", func() {
var (
environmentKey = "environment-api-key"
Expand All @@ -319,6 +376,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(environmentKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the MODEL environment variable", func() {
Expand All @@ -342,6 +400,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the MAX_TOKENS environment variable", func() {
Expand All @@ -365,6 +424,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the URL environment variable", func() {
Expand All @@ -388,6 +448,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the COMPLETIONS_PATH environment variable", func() {
Expand All @@ -411,6 +472,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
it("gives precedence to the MODELS_PATH environment variable", func() {
Expand All @@ -434,6 +496,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.ModelsPath).To(Equal(envModelsPath))
Expect(subject.Config.APIKey).To(Equal(defaultApiKey))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
})
Expand All @@ -455,6 +518,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(result).To(ContainSubstring(defaultModelsPath))
Expect(result).To(ContainSubstring(fmt.Sprintf("%d", defaultMaxTokens)))
Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory))
Expect(subject.Config.Role).To(Equal(defaultRole))
Expect(subject.Config.Thread).To(Equal(defaultThread))
})
})
Expand Down Expand Up @@ -499,4 +563,5 @@ func cleanEnv(envPrefix string) {
Expect(os.Unsetenv(envPrefix + "MODELS_PATH")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "OMIT_HISTORY")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "THREAD")).To(Succeed())
Expect(os.Unsetenv(envPrefix + "ROLE")).To(Succeed())
}
2 changes: 1 addition & 1 deletion integration/contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func testContract(t *testing.T, when spec.G, it spec.S) {
body := types.CompletionsRequest{
Messages: []types.Message{{
Role: client.SystemRole,
Content: client.AssistantContent,
Content: defaults.Role,
}},
Model: defaults.Model,
Stream: false,
Expand Down
1 change: 1 addition & 0 deletions types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ type Config struct {
APIKey string `yaml:"api_key"`
Model string `yaml:"model"`
MaxTokens int `yaml:"max_tokens"`
Role string `yaml:"role"`
Thread string `yaml:"thread"`
OmitHistory bool `yaml:"omit_history"`
URL string `yaml:"url"`
Expand Down

0 comments on commit 1c073b1

Please sign in to comment.