
Prevent having to pass posts to services (#207)
crspeller authored Jun 14, 2024
1 parent 9ed0bb7 commit f0ef416
Showing 5 changed files with 65 additions and 51 deletions.
server/ai/anthropic/anthropic.go (16 changes: 7 additions & 9 deletions)
@@ -13,19 +13,17 @@ type Anthropic struct {
 	client         *Client
 	defaultModel   string
 	tokenLimit     int
-	metricsService metrics.Metrics
-	name           string
+	metricsService metrics.LLMetrics
 }
 
-func New(botConfig ai.BotConfig, metricsService metrics.Metrics) *Anthropic {
-	client := NewClient(botConfig.Service.APIKey)
+func New(llmService ai.ServiceConfig, metricsService metrics.LLMetrics) *Anthropic {
+	client := NewClient(llmService.APIKey)
 
 	return &Anthropic{
 		client:         client,
-		defaultModel:   botConfig.Service.DefaultModel,
-		tokenLimit:     botConfig.Service.TokenLimit,
+		defaultModel:   llmService.DefaultModel,
+		tokenLimit:     llmService.TokenLimit,
 		metricsService: metricsService,
-		name:           botConfig.Name,
 	}
 }

@@ -84,7 +82,7 @@ func (a *Anthropic) createCompletionRequest(conversation ai.BotConversation, opt
 }
 
 func (a *Anthropic) ChatCompletion(conversation ai.BotConversation, opts ...ai.LanguageModelOption) (*ai.TextStreamResult, error) {
-	a.metricsService.IncrementLLMRequests(a.name)
+	a.metricsService.IncrementLLMRequests()
 
 	request := a.createCompletionRequest(conversation, opts)
 	request.Stream = true

@@ -97,7 +95,7 @@ func (a *Anthropic) ChatCompletion(conversation ai.BotConversation, opts ...ai.L
 }
 
 func (a *Anthropic) ChatCompletionNoStream(conversation ai.BotConversation, opts ...ai.LanguageModelOption) (string, error) {
-	a.metricsService.IncrementLLMRequests(a.name)
+	a.metricsService.IncrementLLMRequests()
 
 	request := a.createCompletionRequest(conversation, opts)
 	request.Stream = false
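Note on the pattern, which repeats in the other providers below: constructors now take only the ai.ServiceConfig plus an already-scoped metrics.LLMetrics instead of the whole ai.BotConfig. A minimal sketch of the new call shape; the field names come from this diff, but the values and bot name are placeholders:

	// Sketch only: values below are illustrative, not from this repo.
	svc := ai.ServiceConfig{
		APIKey:       "sk-placeholder",
		DefaultModel: "claude-3-opus",
		TokenLimit:   100000,
	}
	// Metrics are bound to the bot name once, up front.
	llm := anthropic.New(svc, metricsService.GetMetricsForAIService("my-bot"))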
server/ai/asksage/asksage.go (26 changes: 12 additions & 14 deletions)
@@ -8,25 +8,23 @@ import (
 )
 
 type AskSage struct {
-	client         *Client
-	defaultModel   string
-	maxTokens      int
-	metricsService metrics.Metrics
-	name           string
+	client       *Client
+	defaultModel string
+	maxTokens    int
+	metric       metrics.LLMetrics
 }
 
-func New(botConfig ai.BotConfig, metricsService metrics.Metrics) *AskSage {
+func New(llmService ai.ServiceConfig, metric metrics.LLMetrics) *AskSage {
 	client := NewClient("")
 	client.Login(GetTokenParams{
-		Email:    botConfig.Service.Username,
-		Password: botConfig.Service.Password,
+		Email:    llmService.Username,
+		Password: llmService.Password,
 	})
 	return &AskSage{
-		client:         client,
-		defaultModel:   botConfig.Service.DefaultModel,
-		maxTokens:      botConfig.Service.TokenLimit,
-		metricsService: metricsService,
-		name:           botConfig.Name,
+		client:       client,
+		defaultModel: llmService.DefaultModel,
+		maxTokens:    llmService.TokenLimit,
+		metric:       metric,
 	}
 }

@@ -80,7 +78,7 @@ func (s *AskSage) ChatCompletion(conversation ai.BotConversation, opts ...ai.Lan
 }
 
 func (s *AskSage) ChatCompletionNoStream(conversation ai.BotConversation, opts ...ai.LanguageModelOption) (string, error) {
-	s.metricsService.IncrementLLMRequests(s.name)
+	s.metric.IncrementLLMRequests()
 
 	params := s.queryParamsFromConfig(s.createConfig(opts))
 	params.Message = conversationToMessagesList(conversation)
server/ai/openai/openai.go (35 changes: 16 additions & 19 deletions)
@@ -27,8 +27,7 @@ type OpenAI struct {
 	defaultModel     string
 	tokenLimit       int
 	streamingTimeout time.Duration
-	metricsService   metrics.Metrics
-	name             string
+	metricsService   metrics.LLMetrics
 }
 
 const StreamingTimeoutDefault = 10 * time.Second

@@ -39,10 +38,10 @@ const OpenAIMaxImageSize = 20 * 1024 * 1024 // 20 MB
 
 var ErrStreamingTimeout = errors.New("timeout streaming")
 
-func NewCompatible(botConfig ai.BotConfig, metricsService metrics.Metrics) *OpenAI {
-	apiKey := botConfig.Service.APIKey
-	endpointURL := strings.TrimSuffix(botConfig.Service.APIURL, "/")
-	defaultModel := botConfig.Service.DefaultModel
+func NewCompatible(llmService ai.ServiceConfig, metricsService metrics.LLMetrics) *OpenAI {
+	apiKey := llmService.APIKey
+	endpointURL := strings.TrimSuffix(llmService.APIURL, "/")
+	defaultModel := llmService.DefaultModel
 	config := openaiClient.DefaultConfig(apiKey)
 	config.BaseURL = endpointURL

@@ -53,39 +52,37 @@ func NewCompatible(botConfig ai.BotConfig, metricsService metrics.Metrics) *Open
 	}
 
 	streamingTimeout := StreamingTimeoutDefault
-	if botConfig.Service.StreamingTimeoutSeconds > 0 {
-		streamingTimeout = time.Duration(botConfig.Service.StreamingTimeoutSeconds) * time.Second
+	if llmService.StreamingTimeoutSeconds > 0 {
+		streamingTimeout = time.Duration(llmService.StreamingTimeoutSeconds) * time.Second
 	}
 	return &OpenAI{
 		client:           openaiClient.NewClientWithConfig(config),
 		defaultModel:     defaultModel,
-		tokenLimit:       botConfig.Service.TokenLimit,
+		tokenLimit:       llmService.TokenLimit,
 		streamingTimeout: streamingTimeout,
 		metricsService:   metricsService,
-		name:             botConfig.Name,
 	}
 }
 
-func New(botConfig ai.BotConfig, metricsService metrics.Metrics) *OpenAI {
-	defaultModel := botConfig.Service.DefaultModel
+func New(llmService ai.ServiceConfig, metricsService metrics.LLMetrics) *OpenAI {
+	defaultModel := llmService.DefaultModel
 	if defaultModel == "" {
 		defaultModel = openaiClient.GPT3Dot5Turbo
 	}
-	config := openaiClient.DefaultConfig(botConfig.Service.APIKey)
-	config.OrgID = botConfig.Service.OrgID
+	config := openaiClient.DefaultConfig(llmService.APIKey)
+	config.OrgID = llmService.OrgID
 
 	streamingTimeout := StreamingTimeoutDefault
-	if botConfig.Service.StreamingTimeoutSeconds > 0 {
-		streamingTimeout = time.Duration(botConfig.Service.StreamingTimeoutSeconds) * time.Second
+	if llmService.StreamingTimeoutSeconds > 0 {
+		streamingTimeout = time.Duration(llmService.StreamingTimeoutSeconds) * time.Second
 	}
 
 	return &OpenAI{
 		client:           openaiClient.NewClientWithConfig(config),
 		defaultModel:     defaultModel,
-		tokenLimit:       botConfig.Service.TokenLimit,
+		tokenLimit:       llmService.TokenLimit,
 		streamingTimeout: streamingTimeout,
 		metricsService:   metricsService,
-		name:             botConfig.Name,
 	}
 }

@@ -351,7 +348,7 @@ func (s *OpenAI) completionRequestFromConfig(cfg ai.LLMConfig) openaiClient.Chat
 }
 
 func (s *OpenAI) ChatCompletion(conversation ai.BotConversation, opts ...ai.LanguageModelOption) (*ai.TextStreamResult, error) {
-	s.metricsService.IncrementLLMRequests(s.name)
+	s.metricsService.IncrementLLMRequests()
 
 	request := s.completionRequestFromConfig(s.createConfig(opts))
 	request = modifyCompletionRequestWithConversation(request, conversation)
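One detail both OpenAI constructors share: StreamingTimeoutSeconds is a bare number, so it is converted to a time.Duration and scaled by time.Second; the cast alone would yield nanoseconds. A standalone sketch of the same fallback pattern, assuming an integer seconds field:

package main

import (
	"fmt"
	"time"
)

// streamingTimeout mirrors the fallback used by New and NewCompatible above.
func streamingTimeout(seconds int) time.Duration {
	if seconds > 0 {
		return time.Duration(seconds) * time.Second // scale, not just cast
	}
	return 10 * time.Second // same value as StreamingTimeoutDefault
}

func main() {
	fmt.Println(streamingTimeout(0), streamingTimeout(30)) // 10s 30s
}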
server/metrics/metrics.go (24 changes: 21 additions & 3 deletions)
@@ -25,7 +25,7 @@ type Metrics interface {
 	IncrementHTTPRequests()
 	IncrementHTTPErrors()
 
-	IncrementLLMRequests(llmName string)
+	GetMetricsForAIService(llmName string) *llmMetrics
 }
 
 type InstanceInfo struct {

@@ -152,8 +152,26 @@ func (m *metrics) IncrementHTTPErrors() {
 	}
 }
 
-func (m *metrics) IncrementLLMRequests(llmName string) {
+func (m *metrics) GetMetricsForAIService(llmName string) *llmMetrics {
+	if m == nil {
+		return nil
+	}
+
+	return &llmMetrics{
+		llmRequestsTotal: m.llmRequestsTotal.MustCurryWith(prometheus.Labels{"llm_name": llmName}),
+	}
+}
+
+type LLMetrics interface {
+	IncrementLLMRequests()
+}
+
+type llmMetrics struct {
+	llmRequestsTotal *prometheus.CounterVec
+}
+
+func (m *llmMetrics) IncrementLLMRequests() {
 	if m != nil {
-		m.llmRequestsTotal.With(prometheus.Labels{"llm_name": llmName}).Inc()
+		m.llmRequestsTotal.With(prometheus.Labels{}).Inc()
 	}
 }
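The key move here is MustCurryWith: GetMetricsForAIService binds the llm_name label once, so each service holds a counter vector already scoped to its bot and IncrementLLMRequests needs no arguments. A self-contained sketch of the same client_golang pattern, with illustrative metric and label values:

package main

import "github.com/prometheus/client_golang/prometheus"

func main() {
	// A counter vector with one label, like the plugin's llmRequestsTotal.
	requests := prometheus.NewCounterVec(prometheus.CounterOpts{
		Name: "llm_requests_total",
		Help: "Total LLM requests.",
	}, []string{"llm_name"})
	prometheus.MustRegister(requests)

	// Currying fixes the label value up front...
	curried := requests.MustCurryWith(prometheus.Labels{"llm_name": "my-bot"})

	// ...so later increments supply no labels at all.
	curried.With(prometheus.Labels{}).Inc() // llm_requests_total{llm_name="my-bot"}
}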
server/plugin.go (15 changes: 9 additions & 6 deletions)
@@ -128,16 +128,18 @@ func (p *Plugin) OnActivate() error {
 }
 
 func (p *Plugin) getLLM(llmBotConfig ai.BotConfig) ai.LanguageModel {
+	metrics := p.metricsService.GetMetricsForAIService(llmBotConfig.Name)
+
 	var llm ai.LanguageModel
 	switch llmBotConfig.Service.Type {
 	case "openai":
-		llm = openai.New(llmBotConfig, p.metricsService)
+		llm = openai.New(llmBotConfig.Service, metrics)
 	case "openaicompatible":
-		llm = openai.NewCompatible(llmBotConfig, p.metricsService)
+		llm = openai.NewCompatible(llmBotConfig.Service, metrics)
 	case "anthropic":
-		llm = anthropic.New(llmBotConfig, p.metricsService)
+		llm = anthropic.New(llmBotConfig.Service, metrics)
 	case "asksage":
-		llm = asksage.New(llmBotConfig, p.metricsService)
+		llm = asksage.New(llmBotConfig.Service, metrics)
 	}
 
 	cfg := p.getConfiguration()

@@ -159,11 +161,12 @@ func (p *Plugin) getTranscribe() ai.Transcriber {
 			break
 		}
 	}
+	metrics := p.metricsService.GetMetricsForAIService(botConfig.Name)
 	switch botConfig.Service.Type {
 	case "openai":
-		return openai.New(botConfig, p.metricsService)
+		return openai.New(botConfig.Service, metrics)
 	case "openaicompatible":
-		return openai.NewCompatible(botConfig, p.metricsService)
+		return openai.NewCompatible(botConfig.Service, metrics)
 	}
 	return nil
 }
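A side benefit of narrowing the dependency to the small LLMetrics interface: callers that do not care about Prometheus can pass a stand-in. A hypothetical no-op implementation, not part of this commit, just to show the seam:

// noopLLMetrics satisfies metrics.LLMetrics without recording anything.
// Hypothetical test helper; not part of this commit.
type noopLLMetrics struct{}

func (noopLLMetrics) IncrementLLMRequests() {}

// Usage sketch: llm := anthropic.New(botConfig.Service, noopLLMetrics{})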
