diff --git a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java index da1a093c..e1e96e37 100644 --- a/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java +++ b/src/main/java/ee/carlrobert/codegpt/toolwindow/chat/ui/textarea/ModelComboBoxAction.java @@ -28,9 +28,11 @@ import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTModel; import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTServiceSettings; import ee.carlrobert.codegpt.settings.service.custom.CustomServiceSettings; +import ee.carlrobert.codegpt.settings.service.google.GoogleSettings; import ee.carlrobert.codegpt.settings.service.llama.LlamaSettings; import ee.carlrobert.codegpt.settings.service.ollama.OllamaSettings; import ee.carlrobert.codegpt.settings.service.openai.OpenAISettings; +import ee.carlrobert.llm.client.google.models.GoogleModel; import ee.carlrobert.llm.client.openai.completion.OpenAIChatCompletionModel; import java.util.Arrays; import java.util.List; @@ -96,12 +98,15 @@ private AnAction[] getCodeGPTModelActions(Project project, Presentation presenta } if (availableProviders.contains(OPENAI)) { actionGroup.addSeparator("OpenAI"); + var openaiGroup = DefaultActionGroup.createPopupGroup(() -> "OpenAI"); + openaiGroup.getTemplatePresentation().setIcon(Icons.OpenAI); List.of( OpenAIChatCompletionModel.GPT_4_O, OpenAIChatCompletionModel.GPT_4_O_MINI, OpenAIChatCompletionModel.GPT_4_VISION_PREVIEW, OpenAIChatCompletionModel.GPT_4_0125_128k) - .forEach(model -> actionGroup.add(createOpenAIModelAction(model, presentation))); + .forEach(model -> openaiGroup.add(createOpenAIModelAction(model, presentation))); + actionGroup.add(openaiGroup); } if (availableProviders.contains(CUSTOM_OPENAI)) { actionGroup.addSeparator("Custom OpenAI"); @@ -129,11 +134,12 @@ private AnAction[] getCodeGPTModelActions(Project project, Presentation presenta } if (availableProviders.contains(GOOGLE)) { actionGroup.addSeparator("Google"); - actionGroup.add(createModelAction( - GOOGLE, - "Google (Gemini)", - Icons.Google, - presentation)); + var googleGroup = DefaultActionGroup.createPopupGroup(() -> "Google (Gemini)"); + googleGroup.getTemplatePresentation().setIcon(Icons.Google); + Arrays.stream(GoogleModel.values()) + .forEach(model -> + googleGroup.add(createGoogleModelAction(model, presentation))); + actionGroup.add(googleGroup); } if (availableProviders.contains(LLAMA_CPP)) { actionGroup.addSeparator("LLaMA C/C++"); @@ -145,13 +151,15 @@ private AnAction[] getCodeGPTModelActions(Project project, Presentation presenta } if (availableProviders.contains(OLLAMA)) { actionGroup.addSeparator("Ollama"); - createOllamaModelAction("Default model", presentation); + var ollamaGroup = DefaultActionGroup.createPopupGroup(() -> "Ollama"); + ollamaGroup.getTemplatePresentation().setIcon(Icons.Ollama); ApplicationManager.getApplication() .getService(OllamaSettings.class) .getState() .getAvailableModels() .forEach(model -> - actionGroup.add(createOllamaModelAction(model, presentation))); + ollamaGroup.add(createOllamaModelAction(model, presentation))); + actionGroup.add(ollamaGroup); } return actionGroup; @@ -235,11 +243,20 @@ private String getSelectedHuggingFace() { huggingFaceModel.getQuantization()); } + private AnAction createModelAction( + ServiceType serviceType, + String label, + Icon icon, + Presentation comboBoxPresentation) { + return createModelAction(serviceType, label, icon, comboBoxPresentation, null); + } + private AnAction createModelAction( ServiceType serviceType, String label, Icon icon, - Presentation comboBoxPresentation) { + Presentation comboBoxPresentation, + Runnable onModelChanged) { return new DumbAwareAction(label, "", icon) { @Override public void update(@NotNull AnActionEvent event) { @@ -249,7 +266,10 @@ public void update(@NotNull AnActionEvent event) { @Override public void actionPerformed(@NotNull AnActionEvent e) { - handleModelChange(serviceType, label, label, icon, comboBoxPresentation); + if (onModelChanged != null) { + onModelChanged.run(); + } + handleModelChange(serviceType, label, icon, comboBoxPresentation); } @Override @@ -262,7 +282,6 @@ public void actionPerformed(@NotNull AnActionEvent e) { private void handleModelChange( ServiceType serviceType, String label, - String code, Icon icon, Presentation comboBoxPresentation) { GeneralSettings.getCurrentState().setSelectedService(serviceType); @@ -272,93 +291,34 @@ private void handleModelChange( } private AnAction createCodeGPTModelAction(CodeGPTModel model, Presentation comboBoxPresentation) { - return new DumbAwareAction(model.getName(), "", model.getIcon()) { - @Override - public void update(@NotNull AnActionEvent event) { - var presentation = event.getPresentation(); - presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText())); - } - - @Override - public void actionPerformed(@NotNull AnActionEvent e) { - ApplicationManager.getApplication().getService(CodeGPTServiceSettings.class) - .getState() - .getChatCompletionSettings() - .setModel(model.getCode()); - handleModelChange( - CODEGPT, - model.getName(), - model.getCode(), - model.getIcon(), - comboBoxPresentation); - } - - @Override - public @NotNull ActionUpdateThread getActionUpdateThread() { - return ActionUpdateThread.BGT; - } - }; + return createModelAction(CODEGPT, model.getName(), model.getIcon(), comboBoxPresentation, + () -> ApplicationManager.getApplication() + .getService(CodeGPTServiceSettings.class) + .getState() + .getChatCompletionSettings() + .setModel(model.getCode())); } - private AnAction createOllamaModelAction( - String model, - Presentation comboBoxPresentation - ) { - return new DumbAwareAction(model, "", Icons.Ollama) { - @Override - public void update(@NotNull AnActionEvent event) { - var presentation = event.getPresentation(); - presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText())); - } - - @Override - public void actionPerformed(@NotNull AnActionEvent e) { - ApplicationManager.getApplication() - .getService(OllamaSettings.class) - .getState() - .setModel(model); - handleModelChange( - OLLAMA, - model, - model, - Icons.Ollama, - comboBoxPresentation); - } - - @Override - public @NotNull ActionUpdateThread getActionUpdateThread() { - return ActionUpdateThread.BGT; - } - }; + private AnAction createOllamaModelAction(String model, Presentation comboBoxPresentation) { + return createModelAction(OLLAMA, model, Icons.Ollama, comboBoxPresentation, + () -> ApplicationManager.getApplication() + .getService(OllamaSettings.class) + .getState() + .setModel(model)); } private AnAction createOpenAIModelAction( - OpenAIChatCompletionModel model, - Presentation comboBoxPresentation) { - createModelAction(OPENAI, model.getDescription(), Icons.OpenAI, - comboBoxPresentation); - return new DumbAwareAction(model.getDescription(), "", Icons.OpenAI) { - @Override - public void update(@NotNull AnActionEvent event) { - var presentation = event.getPresentation(); - presentation.setEnabled(!presentation.getText().equals(comboBoxPresentation.getText())); - } - - @Override - public void actionPerformed(@NotNull AnActionEvent e) { - OpenAISettings.getCurrentState().setModel(model.getCode()); - handleModelChange( - OPENAI, - model.getDescription(), - model.getCode(), - Icons.OpenAI, - comboBoxPresentation); - } + OpenAIChatCompletionModel model, + Presentation comboBoxPresentation) { + return createModelAction(OPENAI, model.getDescription(), Icons.OpenAI, comboBoxPresentation, + () -> OpenAISettings.getCurrentState().setModel(model.getCode())); + } - @Override - public @NotNull ActionUpdateThread getActionUpdateThread() { - return ActionUpdateThread.BGT; - } - }; + private AnAction createGoogleModelAction(GoogleModel model, Presentation comboBoxPresentation) { + return createModelAction(GOOGLE, model.getDescription(), Icons.Google, comboBoxPresentation, + () -> ApplicationManager.getApplication() + .getService(GoogleSettings.class) + .getState() + .setModel(model.getCode())); } } diff --git a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt index afab7bbe..195a5d3d 100644 --- a/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt +++ b/src/main/kotlin/ee/carlrobert/codegpt/settings/service/codegpt/CodeGPTAvailableModels.kt @@ -9,19 +9,15 @@ object CodeGPTAvailableModels { @JvmStatic fun getToolWindowModels(pricingPlan: PricingPlan?): List { - val anonymousModels = listOf( - CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), - CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), - CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL), - CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), - CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), - CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS) - ) - if (pricingPlan == null) { - return anonymousModels - } return when (pricingPlan) { - ANONYMOUS -> anonymousModels + null, ANONYMOUS -> listOf( + CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), + CodeGPTModel("Claude 3.5 Sonnet", "claude-3.5-sonnet", Icons.Anthropic, INDIVIDUAL), + CodeGPTModel("Llama 3.1 (405B)", "llama-3.1-405b", Icons.Meta, INDIVIDUAL), + CodeGPTModel("DeepSeek Coder V2", "deepseek-coder-v2", Icons.DeepSeek, INDIVIDUAL), + CodeGPTModel("GPT-4o mini - FREE", "gpt-4o-mini", Icons.OpenAI, ANONYMOUS), + CodeGPTModel("Llama 3 (8B) - FREE", "llama-3-8b", Icons.Meta, ANONYMOUS) + ) FREE -> listOf( CodeGPTModel("GPT-4o", "gpt-4o", Icons.OpenAI, INDIVIDUAL), @@ -77,4 +73,4 @@ data class CodeGPTModel( val code: String, val icon: Icon, val individual: PricingPlan -) \ No newline at end of file +)