Add support for vision models. (#189)
* Add support for image attachments.

* Fix RHS file upload issue

* Deal with deprecation later

* mod tidy
crspeller authored May 28, 2024
1 parent 36f079a commit 967cf16
Showing 13 changed files with 202 additions and 39 deletions.
1 change: 0 additions & 1 deletion .golangci.yml
@@ -24,7 +24,6 @@ linters:
- goimports
- gosec
- gosimple
- govet
- ineffassign
- misspell
- nakedret
2 changes: 1 addition & 1 deletion go.mod
@@ -12,7 +12,7 @@ require (
github.com/jmoiron/sqlx v1.3.5
github.com/mattermost/mattermost/server/public v0.0.8
github.com/r3labs/sse/v2 v2.10.0
github.com/sashabaranov/go-openai v1.14.1
github.com/sashabaranov/go-openai v1.24.0
github.com/stretchr/testify v1.8.4
)

4 changes: 2 additions & 2 deletions go.sum
@@ -210,8 +210,8 @@ github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktE
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/sashabaranov/go-openai v1.14.1 h1:jqfkdj8XHnBF84oi2aNtT8Ktp3EJ0MfuVjvcMkfI0LA=
github.com/sashabaranov/go-openai v1.14.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.24.0 h1:4H4Pg8Bl2RH/YSnU8DYumZbuHnnkfioor/dtNlB20D4=
github.com/sashabaranov/go-openai v1.24.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM=
1 change: 1 addition & 0 deletions server/ai/configuration.go
@@ -19,6 +19,7 @@ type BotConfig struct {
DisplayName string `json:"displayName"`
CustomInstructions string `json:"customInstructions"`
Service ServiceConfig `json:"service"`
EnableVision bool `json:"enableVision"`
}

func (c *BotConfig) IsValid() bool {
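
The new enableVision flag travels with the rest of the per-bot configuration JSON. A minimal decoding sketch, assuming a trimmed mirror struct for illustration (the real BotConfig in server/ai/configuration.go carries more fields than this hunk shows):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// botConfig mirrors only the fields visible in the hunk above; it is an
// illustrative stand-in, not the plugin's actual BotConfig type.
type botConfig struct {
	DisplayName        string `json:"displayName"`
	CustomInstructions string `json:"customInstructions"`
	EnableVision       bool   `json:"enableVision"`
}

func main() {
	raw := []byte(`{"displayName":"AI Copilot","enableVision":true}`)

	var cfg botConfig
	if err := json.Unmarshal(raw, &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.EnableVision) // true
}
```
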
30 changes: 10 additions & 20 deletions server/ai/conversation.go
@@ -3,6 +3,7 @@ package ai
import (
"encoding/json"
"fmt"
"io"
"slices"
"strings"
"time"
@@ -20,9 +21,16 @@ const (
PostRoleSystem
)

type File struct {
MimeType string
Size int64
Reader io.Reader
}

type Post struct {
Role PostRole
Message string
Files []File
}

type ConversationContext struct {
@@ -99,11 +107,8 @@ type BotConversation struct {
Context ConversationContext
}

func (b *BotConversation) AddUserPost(post *model.Post) {
b.Posts = append(b.Posts, Post{
Role: PostRoleUser,
Message: FormatPostBody(post),
})
func (b *BotConversation) AddPost(post Post) {
b.Posts = append(b.Posts, post)
}

func (b *BotConversation) AppendConversation(conversation BotConversation) {
@@ -181,21 +186,6 @@ func GetPostRole(botID string, post *model.Post) PostRole {
return PostRoleUser
}

func ThreadToBotConversation(botID string, posts []*model.Post) BotConversation {
result := BotConversation{
Posts: make([]Post, 0, len(posts)),
}

for _, post := range posts {
result.Posts = append(result.Posts, Post{
Role: GetPostRole(botID, post),
Message: FormatPostBody(post),
})
}

return result
}

func FormatPostBody(post *model.Post) string {
attachments := post.Attachments()
if len(attachments) > 0 {
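
With the new File type and the Files field on Post, callers build posts (optionally carrying image attachments) and append them with AddPost; the removed AddUserPost and ThreadToBotConversation helpers are replaced by conversion logic in server/post_processing.go below. A usage sketch — the import path is assumed from the repository layout and should match the module path in go.mod:

```go
package main

import (
	"os"

	// Assumed import path for the plugin's ai package.
	"github.com/mattermost/mattermost-plugin-ai/server/ai"
)

func main() {
	img, err := os.Open("diagram.png") // illustrative file
	if err != nil {
		panic(err)
	}
	defer img.Close()

	info, err := img.Stat()
	if err != nil {
		panic(err)
	}

	// Build a user post with one attached image and add it to a conversation.
	var conversation ai.BotConversation
	conversation.AddPost(ai.Post{
		Role:    ai.PostRoleUser,
		Message: "What does this diagram show?",
		Files: []ai.File{{
			MimeType: "image/png",
			Size:     info.Size(),
			Reader:   img,
		}},
	})
}
```
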
1 change: 1 addition & 0 deletions server/ai/interfaces.go
@@ -10,6 +10,7 @@ import (
type LLMConfig struct {
Model string
MaxGeneratedTokens int
EnableVision bool
}

type LanguageModelOption func(*LLMConfig)
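
EnableVision on LLMConfig fits the functional-option pattern already defined in this file (LanguageModelOption). A self-contained sketch with local mirror types; WithVisionEnabled is a hypothetical option, not part of this commit:

```go
package main

import "fmt"

// Local mirrors of the types in server/ai/interfaces.go, trimmed for the sketch.
type LLMConfig struct {
	Model              string
	MaxGeneratedTokens int
	EnableVision       bool
}

type LanguageModelOption func(*LLMConfig)

// WithVisionEnabled is a hypothetical option (not in this diff) that flips
// the new EnableVision flag, following the functional-option pattern above.
func WithVisionEnabled() LanguageModelOption {
	return func(cfg *LLMConfig) {
		cfg.EnableVision = true
	}
}

func main() {
	var cfg LLMConfig
	for _, opt := range []LanguageModelOption{WithVisionEnabled()} {
		opt(&cfg)
	}
	fmt.Println(cfg.EnableVision) // true
}
```
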
58 changes: 53 additions & 5 deletions server/ai/openai/openai.go
@@ -32,6 +32,8 @@ const StreamingTimeoutDefault = 10 * time.Second

const MaxFunctionCalls = 10

const OpenAIMaxImageSize = 20 * 1024 * 1024 // 20 MB

var ErrStreamingTimeout = errors.New("timeout streaming")

func NewCompatible(llmService ai.ServiceConfig) *OpenAI {
@@ -82,7 +84,7 @@ func New(llmService ai.ServiceConfig) *OpenAI {

func modifyCompletionRequestWithConversation(request openaiClient.ChatCompletionRequest, conversation ai.BotConversation) openaiClient.ChatCompletionRequest {
request.Messages = postsToChatCompletionMessages(conversation.Posts)
request.Functions = toolsToFunctionDefinitions(conversation.Tools.GetTools())
request.Functions = toolsToFunctionDefinitions(conversation.Tools.GetTools()) //nolint:all
return request
}

@@ -116,10 +118,55 @@ func postsToChatCompletionMessages(posts []ai.Post) []openaiClient.ChatCompletio
} else if post.Role == ai.PostRoleSystem {
role = openaiClient.ChatMessageRoleSystem
}
result = append(result, openaiClient.ChatCompletionMessage{
Role: role,
Content: post.Message,
})
completionMessage := openaiClient.ChatCompletionMessage{
Role: role,
}

if len(post.Files) > 0 {
completionMessage.MultiContent = make([]openaiClient.ChatMessagePart, 0, len(post.Files)+1)
if post.Message != "" {
completionMessage.MultiContent = append(completionMessage.MultiContent, openaiClient.ChatMessagePart{
Type: openaiClient.ChatMessagePartTypeText,
Text: post.Message,
})
}
for _, file := range post.Files {
if file.MimeType != "image/png" &&
file.MimeType != "image/jpeg" &&
file.MimeType != "image/gif" &&
file.MimeType != "image/webp" {
completionMessage.MultiContent = append(completionMessage.MultiContent, openaiClient.ChatMessagePart{
Type: openaiClient.ChatMessagePartTypeText,
Text: "User submitted image was not a supported format. Tell the user this.",
})
continue
}
if file.Size > OpenAIMaxImageSize {
completionMessage.MultiContent = append(completionMessage.MultiContent, openaiClient.ChatMessagePart{
Type: openaiClient.ChatMessagePartTypeText,
Text: "User submitted a image larger than 20MB. Tell the user this.",
})
continue
}
fileBytes, err := io.ReadAll(file.Reader)
if err != nil {
continue
}
imageEncoded := base64.StdEncoding.EncodeToString(fileBytes)
encodedString := fmt.Sprintf("data:"+file.MimeType+";base64,%s", imageEncoded)
completionMessage.MultiContent = append(completionMessage.MultiContent, openaiClient.ChatMessagePart{
Type: openaiClient.ChatMessagePartTypeImageURL,
ImageURL: &openaiClient.ChatMessageImageURL{
URL: encodedString,
Detail: openaiClient.ImageURLDetailAuto,
},
})
}
} else {
completionMessage.Content = post.Message
}

result = append(result, completionMessage)
}

return result
@@ -281,6 +328,7 @@ func (s *OpenAI) createConfig(opts []ai.LanguageModelOption) ai.LLMConfig {
for _, opt := range opts {
opt(&cfg)
}

return cfg
}

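
The hunk above inlines the image handling; the core transformation is reading the attachment bytes, base64-encoding them, and wrapping them in a data: URL message part for MultiContent. A standalone sketch of that step against go-openai v1.24.0 (the helper name and file path are illustrative; size and MIME checks from the hunk are omitted):

```go
package main

import (
	"encoding/base64"
	"fmt"
	"io"
	"os"

	openaiClient "github.com/sashabaranov/go-openai"
)

// imagePart converts raw image bytes into the MultiContent part shape used
// in postsToChatCompletionMessages above: a base64 data URL with auto detail.
func imagePart(mimeType string, r io.Reader) (openaiClient.ChatMessagePart, error) {
	data, err := io.ReadAll(r)
	if err != nil {
		return openaiClient.ChatMessagePart{}, err
	}
	url := fmt.Sprintf("data:%s;base64,%s", mimeType, base64.StdEncoding.EncodeToString(data))
	return openaiClient.ChatMessagePart{
		Type: openaiClient.ChatMessagePartTypeImageURL,
		ImageURL: &openaiClient.ChatMessageImageURL{
			URL:    url,
			Detail: openaiClient.ImageURLDetailAuto,
		},
	}, nil
}

func main() {
	f, err := os.Open("photo.jpg") // illustrative path
	if err != nil {
		panic(err)
	}
	defer f.Close()

	part, err := imagePart("image/jpeg", f)
	if err != nil {
		panic(err)
	}
	fmt.Println(part.Type, len(part.ImageURL.URL))
}
```
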
42 changes: 42 additions & 0 deletions server/post_processing.go
@@ -340,3 +340,45 @@ type WorkerResult struct {
return nil
}*/

func (p *Plugin) PostToAIPost(bot *Bot, post *model.Post) ai.Post {
var files []ai.File
if bot.cfg.EnableVision {
files = make([]ai.File, 0, len(post.FileIds))
for _, fileID := range post.FileIds {
fileInfo, err := p.pluginAPI.File.GetInfo(fileID)
if err != nil {
p.API.LogError("Error getting file info", "error", err)
continue
}
file, err := p.pluginAPI.File.Get(fileID)
if err != nil {
p.API.LogError("Error getting file", "error", err)
continue
}
files = append(files, ai.File{
Reader: file,
MimeType: fileInfo.MimeType,
Size: fileInfo.Size,
})
}
}

return ai.Post{
Role: ai.GetPostRole(bot.mmBot.UserId, post),
Message: ai.FormatPostBody(post),
Files: files,
}
}

func (p *Plugin) ThreadToBotConversation(bot *Bot, posts []*model.Post) ai.BotConversation {
result := ai.BotConversation{
Posts: make([]ai.Post, 0, len(posts)),
}

for _, post := range posts {
result.Posts = append(result.Posts, p.PostToAIPost(bot, post))
}

return result
}
6 changes: 3 additions & 3 deletions server/service.go
@@ -51,7 +51,7 @@ func (p *Plugin) newConversation(bot *Bot, context ai.ConversationContext) error
if err != nil {
return err
}
conversation.AddUserPost(context.Post)
conversation.AddPost(p.PostToAIPost(bot, context.Post))

result, err := p.getLLM(bot.cfg.Service).ChatCompletion(conversation)
if err != nil {
@@ -131,7 +131,7 @@ func (p *Plugin) continueConversation(bot *Bot, threadData *ThreadData, context
if err != nil {
return nil, err
}
prompt.AppendConversation(ai.ThreadToBotConversation(bot.mmBot.UserId, threadData.Posts))
prompt.AppendConversation(p.ThreadToBotConversation(bot, threadData.Posts))

result, err = p.getLLM(bot.cfg.Service).ChatCompletion(prompt)
if err != nil {
@@ -154,7 +154,7 @@ func (p *Plugin) continueThreadConversation(bot *Bot, questionThreadData *Thread
if err != nil {
return nil, err
}
prompt.AppendConversation(ai.ThreadToBotConversation(bot.mmBot.UserId, questionThreadData.Posts))
prompt.AppendConversation(p.ThreadToBotConversation(bot, questionThreadData.Posts))

result, err := p.getLLM(bot.cfg.Service).ChatCompletion(prompt)
if err != nil {
37 changes: 33 additions & 4 deletions webapp/src/components/rhs/rhs_new_tab.tsx
@@ -1,4 +1,4 @@
import React from 'react';
import React, {useState} from 'react';
import styled from 'styled-components';

import {
@@ -7,6 +7,8 @@ import {
PlaylistCheckIcon,
} from '@mattermost/compass-icons/components';

import {useDispatch} from 'react-redux';

import RHSImage from '../assets/rhs_image';

import {createPost} from '@/client';
@@ -106,6 +108,8 @@ const addProsAndCons = () => {
};

const RHSNewTab = ({botChannelId, selectPost, setCurrentTab}: Props) => {
const dispatch = useDispatch();
const [draft, updateDraft] = useState<any>(null);
return (
<NewQuestion>
<RHSImage/>
@@ -122,12 +126,37 @@ const RHSNewTab = ({botChannelId, selectPost, setCurrentTab}: Props) => {
data-testid='rhs-new-tab-create-post'
channelId={botChannelId}
placeholder={'Ask AI Copilot anything...'}
rootId={'ai_copilot'}
onSubmit={async (p: any) => {
p.channel_id = botChannelId || '';
p.props = {};
const created = await createPost(p);
const post = {...p};
post.channel_id = botChannelId || '';
post.props = {};
post.uploadsInProgress = [];
post.file_ids = p.fileInfos.map((f: any) => f.id);
const created = await createPost(post);
selectPost(created.id);
setCurrentTab('thread');
dispatch({
type: 'SET_GLOBAL_ITEM',
data: {
name: 'comment_draft_ai_copilot',
value: {message: '', fileInfos: [], uploadsInProgress: []},
},
});
}}
draft={draft}
onUpdateCommentDraft={(newDraft: any) => {
updateDraft(newDraft);
const timestamp = new Date().getTime();
newDraft.updateAt = timestamp;
newDraft.createAt = newDraft.createAt || timestamp;
dispatch({
type: 'SET_GLOBAL_ITEM',
data: {
name: 'comment_draft_ai_copilot',
value: newDraft,
},
});
}}
/>
</CreatePostContainer>
12 changes: 10 additions & 2 deletions webapp/src/components/system_console/bot.tsx
@@ -8,7 +8,7 @@ import {DangerPill} from '../pill';

import {ButtonIcon} from '../assets/buttons';

import {ItemList, SelectionItem, SelectionItemOption, TextItem} from './item';
import {BooleanItem, ItemList, SelectionItem, SelectionItemOption, TextItem} from './item';
import AvatarItem from './avatar';

export type LLMService = {
@@ -29,6 +29,7 @@ export type LLMBotConfig = {
displayName: string
service: LLMService
customInstructions: string
enableVision: boolean
}

type Props = {
@@ -122,7 +123,14 @@ const Bot = (props: Props) => {
value={props.bot.customInstructions}
onChange={(e) => props.onChange({...props.bot, customInstructions: e.target.value})}
/>

{ (props.bot.service.type === 'openai' || props.bot.service.type === 'openaicompatible') && (
<BooleanItem
label='Enable Vision'
value={props.bot.enableVision}
onChange={(to: boolean) => props.onChange({...props.bot, enableVision: to})}
helpText='Enable Vision to allow the bot to process images. Requires a compatible model.'
/>
)}
</ItemList>
</ItemListContainer>
)}
4 changes: 3 additions & 1 deletion webapp/src/components/system_console/bots.tsx
@@ -26,6 +26,7 @@ const defaultNewBot = {
tokenLimit: 0,
streamingTimeoutSeconds: 0,
},
enableVision: false,
};

export const firstNewBot = {
@@ -44,7 +45,8 @@ const Bots = (props: Props) => {
const multiLLMLicensed = useIsMultiLLMLicensed();
const licenceAddDisabled = !multiLLMLicensed && props.bots.length > 0;

const addNewBot = () => {
const addNewBot = (e: React.MouseEvent<HTMLButtonElement>) => {
e.preventDefault();
const id = Math.random().toString(36).substring(2, 22);
if (props.bots.length === 0) {
// Suggest the '@ai' and 'AI Copilot' name for the first bot