Skip to content

Commit

Permalink
feat(ui): improve model combobox menu UI/UX by extracting the provide…
Browse files Browse the repository at this point in the history
…rs into standalone subgroups (#681)

* feat: move models into sub-menus

* feat: improve chat model combobox menu UI

---------

Co-authored-by: GenericMale <[email protected]>
Co-authored-by: Carl-Robert Linnupuu <[email protected]>
  • Loading branch information
3 people authored Sep 11, 2024
1 parent c399e00 commit 2d30472
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTModel;
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings;
import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings;
import ee.carlrobert.codegpt.settings.service.google.GoogleSettings;
import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings;
import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings;
import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings;
import ee.carlrobert.llm.client.google.models.GoogleModel;
import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -93,65 +95,67 @@ private AnAction[] getCodeGPTModelActions(Project project, Presentation presenta
if (availableProviders.contains(CODEGPT)) {
actionGroup.addSeparator("CodeGPT");
actionGroup.addAll(getCodeGPTModelActions(project, presentation));
actionGroup.addSeparator();
}
actionGroup.addSeparator("Cloud Providers");
if (availableProviders.contains(OPENAI)) {
actionGroup.addSeparator("OpenAI");
var openaiGroup = DefaultActionGroup.createPopupGroup(() -> "OpenAI");
openaiGroup.getTemplatePresentation().setIcon(Icons.OpenAI);
List.of(
OpenAIChatCompletionModel.GPT_4_O,
OpenAIChatCompletionModel.GPT_4_O_MINI,
OpenAIChatCompletionModel.GPT_4_VISION_PREVIEW,
OpenAIChatCompletionModel.GPT_4_0125_128k)
.forEach(model -> actionGroup.add(createOpenAIModelAction(model, presentation)));
.forEach(model -> openaiGroup.add(createOpenAIModelAction(model, presentation)));
actionGroup.add(openaiGroup);
}
if (availableProviders.contains(CUSTOM_OPENAI)) {
actionGroup.addSeparator("Custom OpenAI");
actionGroup.add(createModelAction(
CUSTOM_OPENAI,
ApplicationManager.getApplication().getService(CustomServiceSettings.class)
"Custom: " + ApplicationManager.getApplication().getService(CustomServiceSettings.class)
.getState()
.getTemplate()
.getProviderName(),
Icons.OpenAI,
presentation));
}
if (availableProviders.contains(ANTHROPIC)) {
actionGroup.addSeparator("Anthropic");
actionGroup.add(createModelAction(
ANTHROPIC,
"Anthropic (Claude)",
Icons.Anthropic,
presentation));
}
if (availableProviders.contains(AZURE)) {
actionGroup.addSeparator("Azure");
actionGroup.add(
createModelAction(AZURE, "Azure OpenAI", Icons.Azure, presentation));
}
if (availableProviders.contains(GOOGLE)) {
actionGroup.addSeparator("Google");
actionGroup.add(createModelAction(
GOOGLE,
"Google (Gemini)",
Icons.Google,
presentation));
var googleGroup = DefaultActionGroup.createPopupGroup(() -> "Google (Gemini)");
googleGroup.getTemplatePresentation().setIcon(Icons.Google);
Arrays.stream(GoogleModel.values())
.forEach(model ->
googleGroup.add(createGoogleModelAction(model, presentation)));
actionGroup.add(googleGroup);
}
if (availableProviders.contains(LLAMA_CPP)) {
actionGroup.addSeparator("LLaMA C/C++");
actionGroup.addSeparator("Local Providers");
actionGroup.add(createModelAction(
LLAMA_CPP,
getLlamaCppPresentationText(),
Icons.Llama,
presentation));
}
if (availableProviders.contains(OLLAMA)) {
actionGroup.addSeparator("Ollama");
createOllamaModelAction("Default model", presentation);
var ollamaGroup = DefaultActionGroup.createPopupGroup(() -> "Ollama");
ollamaGroup.getTemplatePresentation().setIcon(Icons.Ollama);
ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.getAvailableModels()
.forEach(model ->
actionGroup.add(createOllamaModelAction(model, presentation)));
ollamaGroup.add(createOllamaModelAction(model, presentation)));
actionGroup.add(ollamaGroup);
}

return actionGroup;
Expand Down Expand Up @@ -235,11 +239,20 @@ private String getSelectedHuggingFace() {
huggingFaceModel.getQuantization());
}

private AnAction createModelAction(
ServiceType serviceType,
String label,
Icon icon,
Presentation comboBoxPresentation) {
return createModelAction(serviceType, label, icon, comboBoxPresentation, null);
}

private AnAction createModelAction(
ServiceType serviceType,
String label,
Icon icon,
Presentation comboBoxPresentation) {
Presentation comboBoxPresentation,
Runnable onModelChanged) {
return new DumbAwareAction(label, "", icon) {
@Override
public void update(@NotNull AnActionEvent event) {
Expand All @@ -249,7 +262,10 @@ public void update(@NotNull AnActionEvent event) {

@Override
public void actionPerformed(@NotNull AnActionEvent e) {
handleModelChange(serviceType, label, label, icon, comboBoxPresentation);
if (onModelChanged != null) {
onModelChanged.run();
}
handleModelChange(serviceType, label, icon, comboBoxPresentation);
}

@Override
Expand All @@ -262,7 +278,6 @@ public void actionPerformed(@NotNull AnActionEvent e) {
private void handleModelChange(
ServiceType serviceType,
String label,
String code,
Icon icon,
Presentation comboBoxPresentation) {
GeneralSettings.getCurrentState().setSelectedService(serviceType);
Expand All @@ -272,93 +287,34 @@ private void handleModelChange(
}

private AnAction createCodeGPTModelAction(CodeGPTModel model, Presentation comboBoxPresentation) {
return new DumbAwareAction(model.getName(), "", model.getIcon()) {
@Override
public void update(@NotNull AnActionEvent event) {
var presentation = event.getPresentation();
presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText()));
}

@Override
public void actionPerformed(@NotNull AnActionEvent e) {
ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class)
.getState()
.getChatCompletionSettings()
.setModel(model.getCode());
handleModelChange(
CODEGPT,
model.getName(),
model.getCode(),
model.getIcon(),
comboBoxPresentation);
}

@Override
public @NotNull ActionUpdateThread getActionUpdateThread() {
return ActionUpdateThread.BGT;
}
};
return createModelAction(CODEGPT, model.getName(), model.getIcon(), comboBoxPresentation,
() -> ApplicationManager.getApplication()
.getService(CodeGPTServiceSettings.class)
.getState()
.getChatCompletionSettings()
.setModel(model.getCode()));
}

private AnAction createOllamaModelAction(
String model,
Presentation comboBoxPresentation
) {
return new DumbAwareAction(model, "", Icons.Ollama) {
@Override
public void update(@NotNull AnActionEvent event) {
var presentation = event.getPresentation();
presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText()));
}

@Override
public void actionPerformed(@NotNull AnActionEvent e) {
ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.setModel(model);
handleModelChange(
OLLAMA,
model,
model,
Icons.Ollama,
comboBoxPresentation);
}

@Override
public @NotNull ActionUpdateThread getActionUpdateThread() {
return ActionUpdateThread.BGT;
}
};
private AnAction createOllamaModelAction(String model, Presentation comboBoxPresentation) {
return createModelAction(OLLAMA, model, Icons.Ollama, comboBoxPresentation,
() -> ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.setModel(model));
}

private AnAction createOpenAIModelAction(
OpenAIChatCompletionModel model,
Presentation comboBoxPresentation) {
createModelAction(OPENAI, model.getDescription(), Icons.OpenAI,
comboBoxPresentation);
return new DumbAwareAction(model.getDescription(), "", Icons.OpenAI) {
@Override
public void update(@NotNull AnActionEvent event) {
var presentation = event.getPresentation();
presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText()));
}

@Override
public void actionPerformed(@NotNull AnActionEvent e) {
OpenAISettings.getCurrentState().setModel(model.getCode());
handleModelChange(
OPENAI,
model.getDescription(),
model.getCode(),
Icons.OpenAI,
comboBoxPresentation);
}
OpenAIChatCompletionModel model,
Presentation comboBoxPresentation) {
return createModelAction(OPENAI, model.getDescription(), Icons.OpenAI, comboBoxPresentation,
() -> OpenAISettings.getCurrentState().setModel(model.getCode()));
}

@Override
public @NotNull ActionUpdateThread getActionUpdateThread() {
return ActionUpdateThread.BGT;
}
};
private AnAction createGoogleModelAction(GoogleModel model, Presentation comboBoxPresentation) {
return createModelAction(GOOGLE, model.getDescription(), Icons.Google, comboBoxPresentation,
() -> ApplicationManager.getApplication()
.getService(GoogleSettings.class)
.getState()
.setModel(model.getCode()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,20 @@ import javax.swing.Icon

object CodeGPTAvailableModels {

val DEFAULT_CHAT_MODEL = CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL)
val DEFAULT_CODE_MODEL = CodeGPTModel("GPT-3.5 Turbo Instruct", "gpt-3.5-turbo-instruct", Icons.OpenAI, INDIVIDUAL)

@JvmStatic
fun getToolWindowModels(pricingPlan: PricingPlan?): List<CodeGPTModel> {
val anonymousModels = listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS)
)
if (pricingPlan == null) {
return anonymousModels
}
return when (pricingPlan) {
ANONYMOUS -> anonymousModels
null, ANONYMOUS -> listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS)
)

FREE -> listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
Expand All @@ -32,12 +31,19 @@ object CodeGPTAvailableModels {
CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE),
)

else -> BASE_CHAT_MODELS
INDIVIDUAL -> listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL),
CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL),
)
}
}

@JvmStatic
val BASE_CHAT_MODELS: List<CodeGPTModel> = listOf(
val ALL_CHAT_MODELS: List<CodeGPTModel> = listOf(
CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL),
CodeGPTModel("GPT-4o mini", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS),
CodeGPTModel("Claude 3 Opus", "claude-3-opus", Icons.Anthropic, INDIVIDUAL),
Expand All @@ -46,10 +52,6 @@ object CodeGPTAvailableModels {
CodeGPTModel("Llama 3 (70B)", "llama-3-70b", Icons.Meta, FREE),
CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL),
CodeGPTModel("DBRX", "dbrx", Icons.Databricks, INDIVIDUAL),
)

@JvmStatic
val ALL_CHAT_MODELS: List<CodeGPTModel> = BASE_CHAT_MODELS + listOf(
CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS),
CodeGPTModel("Code Llama (70B)", "codellama:chat", Icons.Meta, FREE),
CodeGPTModel("Mixtral (8x22B)", "mixtral-8x22b", Icons.CodeGPTModel, FREE),
Expand All @@ -58,8 +60,8 @@ object CodeGPTAvailableModels {
)

@JvmStatic
val CODE_MODELS: List<CodeGPTModel> = listOf(
CodeGPTModel("GPT-3.5 Turbo Instruct", "gpt-3.5-turbo-instruct", Icons.OpenAI, INDIVIDUAL),
val ALL_CODE_MODELS: List<CodeGPTModel> = listOf(
DEFAULT_CODE_MODEL,
CodeGPTModel("StarCoder (16B)", "starcoder-16b", Icons.CodeGPTModel, FREE),
CodeGPTModel("StarCoder (7B) - FREE", "starcoder-7b", Icons.CodeGPTModel, FREE),
CodeGPTModel("WizardCoder Python (34B)", "wizardcoder-python", Icons.CodeGPTModel, FREE),
Expand All @@ -68,7 +70,7 @@ object CodeGPTAvailableModels {

@JvmStatic
fun findByCode(code: String?): CodeGPTModel? {
return ALL_CHAT_MODELS.union(CODE_MODELS).firstOrNull { it.code == code }
return ALL_CHAT_MODELS.union(ALL_CODE_MODELS).firstOrNull { it.code == code }
}
}

Expand All @@ -77,4 +79,4 @@ data class CodeGPTModel(
val code: String,
val icon: Icon,
val individual: PricingPlan
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class CodeGPTServiceForm {
ComboBox(ListComboBoxModel(CodeGPTAvailableModels.ALL_CHAT_MODELS)).apply {
val chatModel = service<CodeGPTServiceSettings>().state.chatCompletionSettings.model
selectedItem = CodeGPTAvailableModels.findByCode(chatModel)
?: CodeGPTAvailableModels.BASE_CHAT_MODELS[0]
?: CodeGPTAvailableModels.DEFAULT_CHAT_MODEL
renderer = CustomComboBoxRenderer()
}

Expand All @@ -38,10 +38,10 @@ class CodeGPTServiceForm {
)

private val codeCompletionModelComboBox =
ComboBox(ListComboBoxModel(CodeGPTAvailableModels.CODE_MODELS)).apply {
ComboBox(ListComboBoxModel(CodeGPTAvailableModels.ALL_CODE_MODELS)).apply {
val codeModel = service<CodeGPTServiceSettings>().state.codeCompletionSettings.model
selectedItem = CodeGPTAvailableModels.findByCode(codeModel)
?: CodeGPTAvailableModels.CODE_MODELS[0]
?: CodeGPTAvailableModels.DEFAULT_CODE_MODEL
renderer = CustomComboBoxRenderer()
}

Expand Down

0 comments on commit 2d30472

Please sign in to comment.