diff --git a/README.md b/README.md index 254c171..5f3aa5e 100644 --- a/README.md +++ b/README.md @@ -162,17 +162,18 @@ values, the `config.yaml` file, and environment variables, in that respective or Configuration variables: -| Variable | Description | Default | -|--------------------|-----------------------------------------------------------------------------------|--------------------------| -| `name` | The prefix for environment variable overrides. | 'openai' | -| `api_key` | Your OpenAI API key. | (none for security) | -| `model` | The GPT model used by the application. | 'gpt-3.5-turbo' | -| `max_tokens` | The maximum number of tokens that can be used in a single API call. | 4096 | -| `thread` | The name of the current chat thread. Each unique thread name has its own context. | 'default' | -| `omit_history` | If true, the chat history will not be used to provide context for the GPT model. | false | -| `url` | The base URL for the OpenAI API. | 'https://api.openai.com' | -| `completions_path` | The API endpoint for completions. | '/v1/chat/completions' | -| `models_path` | The API endpoint for accessing model information. | '/v1/models' | +| Variable | Description | Default | +|--------------------|-----------------------------------------------------------------------------------|--------------------------------| +| `name` | The prefix for environment variable overrides. | 'openai' | +| `api_key` | Your OpenAI API key. | (none for security) | +| `model` | The GPT model used by the application. | 'gpt-3.5-turbo' | +| `max_tokens` | The maximum number of tokens that can be used in a single API call. | 4096 | +| `role` | The system role | 'You are a helpful assistant.' | +| `thread` | The name of the current chat thread. Each unique thread name has its own context. | 'default' | +| `omit_history` | If true, the chat history will not be used to provide context for the GPT model. | false | +| `url` | The base URL for the OpenAI API. | 'https://api.openai.com' | +| `completions_path` | The API endpoint for completions. | '/v1/chat/completions' | +| `models_path` | The API endpoint for accessing model information. | '/v1/models' | The defaults can be overridden by providing your own values in the user configuration file, named `.chatgpt-cli/config.yaml`, located in your home directory. diff --git a/client/client.go b/client/client.go index 15d7d3b..db96dc2 100644 --- a/client/client.go +++ b/client/client.go @@ -14,7 +14,6 @@ import ( ) const ( - AssistantContent = "You are a helpful assistant." AssistantRole = "assistant" ErrEmptyResponse = "empty response" MaxTokenBufferPercentage = 20 @@ -180,10 +179,11 @@ func (c *Client) initHistory() { if len(c.History) == 0 { c.History = []types.Message{{ - Role: SystemRole, - Content: AssistantContent, + Role: SystemRole, }} } + + c.History[0].Content = c.Config.Role } func (c *Client) addQuery(query string) { diff --git a/client/client_test.go b/client/client_test.go index 7f635a2..9d1eaf0 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -29,6 +29,7 @@ const ( defaultCompletionsPath = "/default/completions" defaultModelsPath = "/default/models" defaultThread = "default-thread" + defaultRole = "You are a great default-role" envApiKey = "api-key" ) @@ -209,7 +210,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) { history := []types.Message{ { Role: client.SystemRole, - Content: client.AssistantContent, + Content: defaultRole, }, { Role: client.UserRole, @@ -252,7 +253,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) { history := []types.Message{ { Role: client.SystemRole, - Content: client.AssistantContent, + Content: defaultRole, }, { Role: client.UserRole, @@ -352,7 +353,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) { history := []types.Message{ { Role: client.SystemRole, - Content: client.AssistantContent, + Content: defaultRole, }, { Role: client.UserRole, @@ -433,7 +434,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) { systemMessage := subject.History[0] Expect(systemMessage.Role).To(Equal(client.SystemRole)) - Expect(systemMessage.Content).To(Equal("You are a helpful assistant.")) + Expect(systemMessage.Content).To(Equal(defaultRole)) contextMessage := subject.History[1] Expect(contextMessage.Role).To(Equal(client.UserRole)) @@ -458,7 +459,7 @@ func createMessages(history []types.Message, query string) []types.Message { if len(history) == 0 { messages = append(messages, types.Message{ Role: client.SystemRole, - Content: client.AssistantContent, + Content: defaultRole, }) } else { messages = history @@ -486,6 +487,7 @@ func newClientFactory(mc *MockCaller, mcs *MockConfigStore, mhs *MockHistoryStor URL: defaultURL, CompletionsPath: defaultCompletionsPath, ModelsPath: defaultModelsPath, + Role: defaultRole, Thread: defaultThread, }).Times(1) diff --git a/config/store.go b/config/store.go index 1e229d0..81091ed 100644 --- a/config/store.go +++ b/config/store.go @@ -15,6 +15,7 @@ const ( openAIURL = "https://api.openai.com" openAICompletionsPath = "/v1/chat/completions" openAIModelsPath = "/v1/models" + openAIRole = "You are a helpful assistant." openAIThread = "default" ) @@ -51,6 +52,7 @@ func (f *FileIO) ReadDefaults() types.Config { return types.Config{ Name: openAIName, Model: openAIModel, + Role: openAIRole, MaxTokens: openAIModelMaxTokens, URL: openAIURL, CompletionsPath: openAICompletionsPath, diff --git a/configmanager/configmanager_test.go b/configmanager/configmanager_test.go index 1e82d55..e360098 100644 --- a/configmanager/configmanager_test.go +++ b/configmanager/configmanager_test.go @@ -27,6 +27,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { defaultName = "default-name" defaultURL = "default-url" defaultModel = "default-model" + defaultRole = "default-role" defaultApiKey = "default-api-key" defaultThread = "default-thread" defaultCompletionsPath = "default-completions-path" @@ -55,6 +56,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { CompletionsPath: defaultCompletionsPath, ModelsPath: defaultModelsPath, OmitHistory: defaultOmitHistory, + Role: defaultRole, Thread: defaultThread, } @@ -87,6 +89,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.Name).To(Equal(defaultName)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the user provided model", func() { @@ -105,6 +108,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.Name).To(Equal(defaultName)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the user provided name", func() { @@ -123,6 +127,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the user provided max-tokens", func() { @@ -141,6 +146,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) @@ -160,6 +166,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the user provided completions-path", func() { @@ -178,6 +185,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the user provided models-path", func() { @@ -196,6 +204,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(modelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the user provided api-key", func() { @@ -214,6 +223,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(apiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the user provided omit-history", func() { @@ -231,6 +241,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) Expect(subject.Config.OmitHistory).To(Equal(omitHistory)) }) @@ -250,8 +261,28 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(userThread)) }) + it("gives precedence to the user provided role", func() { + userRole := "user-role" + + mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) + mockConfigStore.EXPECT().Read().Return(types.Config{Role: userRole}, nil).Times(1) + + subject := configmanager.New(mockConfigStore).WithEnvironment() + + Expect(subject.Config.Name).To(Equal(defaultName)) + Expect(subject.Config.Model).To(Equal(defaultModel)) + Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens)) + Expect(subject.Config.URL).To(Equal(defaultURL)) + Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) + Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) + Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) + Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Thread).To(Equal(defaultThread)) + Expect(subject.Config.Role).To(Equal(userRole)) + }) it("gives precedence to the OMIT_HISTORY environment variable", func() { var ( environmentValue = true @@ -272,6 +303,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) Expect(subject.Config.OmitHistory).To(Equal(environmentValue)) }) @@ -296,8 +328,33 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(environmentValue)) }) + it("gives precedence to the ROLE environment variable", func() { + var ( + environmentValue = "env-role" + configValue = "conf-role" + ) + + Expect(os.Setenv(envPrefix+"ROLE", environmentValue)).To(Succeed()) + + mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) + mockConfigStore.EXPECT().Read().Return(types.Config{Role: configValue}, nil).Times(1) + + subject := configmanager.New(mockConfigStore).WithEnvironment() + + Expect(subject.Config.Name).To(Equal(defaultName)) + Expect(subject.Config.Model).To(Equal(defaultModel)) + Expect(subject.Config.MaxTokens).To(Equal(defaultMaxTokens)) + Expect(subject.Config.URL).To(Equal(defaultURL)) + Expect(subject.Config.CompletionsPath).To(Equal(defaultCompletionsPath)) + Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) + Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) + Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Thread).To(Equal(defaultThread)) + Expect(subject.Config.Role).To(Equal(environmentValue)) + }) it("gives precedence to the API_KEY environment variable", func() { var ( environmentKey = "environment-api-key" @@ -319,6 +376,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(environmentKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the MODEL environment variable", func() { @@ -342,6 +400,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the MAX_TOKENS environment variable", func() { @@ -365,6 +424,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the URL environment variable", func() { @@ -388,6 +448,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the COMPLETIONS_PATH environment variable", func() { @@ -411,6 +472,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(defaultModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) it("gives precedence to the MODELS_PATH environment variable", func() { @@ -434,6 +496,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.ModelsPath).To(Equal(envModelsPath)) Expect(subject.Config.APIKey).To(Equal(defaultApiKey)) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) }) @@ -455,6 +518,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(result).To(ContainSubstring(defaultModelsPath)) Expect(result).To(ContainSubstring(fmt.Sprintf("%d", defaultMaxTokens))) Expect(subject.Config.OmitHistory).To(Equal(defaultOmitHistory)) + Expect(subject.Config.Role).To(Equal(defaultRole)) Expect(subject.Config.Thread).To(Equal(defaultThread)) }) }) @@ -499,4 +563,5 @@ func cleanEnv(envPrefix string) { Expect(os.Unsetenv(envPrefix + "MODELS_PATH")).To(Succeed()) Expect(os.Unsetenv(envPrefix + "OMIT_HISTORY")).To(Succeed()) Expect(os.Unsetenv(envPrefix + "THREAD")).To(Succeed()) + Expect(os.Unsetenv(envPrefix + "ROLE")).To(Succeed()) } diff --git a/integration/contract_test.go b/integration/contract_test.go index 37a70e7..6384f93 100644 --- a/integration/contract_test.go +++ b/integration/contract_test.go @@ -41,7 +41,7 @@ func testContract(t *testing.T, when spec.G, it spec.S) { body := types.CompletionsRequest{ Messages: []types.Message{{ Role: client.SystemRole, - Content: client.AssistantContent, + Content: defaults.Role, }}, Model: defaults.Model, Stream: false, diff --git a/types/config.go b/types/config.go index 75e78e2..7fd38a8 100644 --- a/types/config.go +++ b/types/config.go @@ -5,6 +5,7 @@ type Config struct { APIKey string `yaml:"api_key"` Model string `yaml:"model"` MaxTokens int `yaml:"max_tokens"` + Role string `yaml:"role"` Thread string `yaml:"thread"` OmitHistory bool `yaml:"omit_history"` URL string `yaml:"url"`