Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: feat: Support DBRX model in Llama #462

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/cpp/llama.cpp
Submodule llama.cpp updated 307 files
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import java.net.MalformedURLException;
import java.net.URL;
import java.util.List;
import java.util.stream.IntStream;

public enum HuggingFaceModel {

Expand Down Expand Up @@ -52,7 +54,14 @@ public enum HuggingFaceModel {
LLAMA_3_8B_Q8_0(8, 8, "Meta-Llama-3-8B-Instruct-Q8_0.gguf", "lmstudio-community"),
LLAMA_3_70B_IQ1(70, 1, "Meta-Llama-3-70B-Instruct-IQ1_M.gguf", "lmstudio-community"),
LLAMA_3_70B_IQ2_XS(70, 2, "Meta-Llama-3-70B-Instruct-IQ2_XS.gguf", "lmstudio-community"),
LLAMA_3_70B_Q4_K_M(70, 4, "Meta-Llama-3-70B-Instruct-Q4_K_M.gguf", "lmstudio-community");
LLAMA_3_70B_Q4_K_M(70, 4, "Meta-Llama-3-70B-Instruct-Q4_K_M.gguf", "lmstudio-community"),

DBRX_12B_Q3_K_M(12, 3, "dbrx-16x12b-instruct-q3_k_m-gguf", "phymbert"),
DBRX_12B_Q4_0(12, 4, "dbrx-16x12b-instruct-q4_0-gguf", "phymbert"),
DBRX_12B_Q6_K(12, 6, "dbrx-16x12b-instruct-q6_k-gguf", "phymbert"),
DBRX_12B_Q8_0(12, 8, "dbrx-16x12b-instruct-q8_0-gguf", "phymbert"),
DBRX_12B_Q3_S(12, 3, "dbrx-16x12b-instruct-iq3_s-gguf", "phymbert"),
DBRX_12B_Q3_XXS(12, 3, "dbrx-16x12b-instruct-iq3_xxs-gguf", "phymbert");

private final int parameterSize;
private final int quantization;
Expand Down Expand Up @@ -82,20 +91,26 @@ public String getCode() {
return name();
}

public String getFileName() {
public List<String> getFileNames() {
if ("TheBloke".equals(user)) {
return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization));
return List.of(modelName.toLowerCase()
.replace("-gguf", format(".Q%d_K_M.gguf", quantization)));
}
return modelName;
if ("phymbert".equals(user)) {
return IntStream.range(1, 11).mapToObj(i -> modelName
.replace("-gguf", "-000%02d-of-00010.gguf".formatted(i))).toList();
}
return List.of(modelName);
}

public URL getFileURL() {
try {
return new URL(
"https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), getFileName()));
} catch (MalformedURLException ex) {
throw new RuntimeException(ex);
}
public List<URL> getFileURLs() {
return getFileNames().stream().map(file -> {
try {
return new URL("https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), file));
} catch (MalformedURLException ex) {
throw new RuntimeException(ex);
}
}).toList();
}

public URL getHuggingFaceURL() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,21 @@ public enum LlamaModel {
HuggingFaceModel.LLAMA_3_8B_Q8_0,
HuggingFaceModel.LLAMA_3_70B_IQ1,
HuggingFaceModel.LLAMA_3_70B_IQ2_XS,
HuggingFaceModel.LLAMA_3_70B_Q4_K_M));
HuggingFaceModel.LLAMA_3_70B_Q4_K_M)),
// DBRX model family: CHAT_ML prompt template, offered in the six DBRX quantizations listed below.
DBRX(
    "DBRX",
    // Each fragment ends with a space so the concatenated sentences stay separated.
    "DBRX is a Mixture-of-Experts (MoE) model with 132B total parameters and 36B live parameters. "
        + "Generation speed is significantly faster than LLaMA2-70B, while at the same time "
        + "beating other open source models, such as, LLaMA2-70B, Mixtral, and Grok-1 on "
        + "language understanding, programming, math, and logic.",
    PromptTemplate.CHAT_ML,
    List.of(
        HuggingFaceModel.DBRX_12B_Q3_K_M,
        HuggingFaceModel.DBRX_12B_Q4_0,
        HuggingFaceModel.DBRX_12B_Q6_K,
        HuggingFaceModel.DBRX_12B_Q8_0,
        HuggingFaceModel.DBRX_12B_Q3_S,
        HuggingFaceModel.DBRX_12B_Q3_XXS));

private final String label;
private final String description;
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ public InfillPromptTemplate getInfillPromptTemplate() {
public String getActualModelPath() {
return isUseCustomLlamaModel()
? getCustomLlamaModelPath()
: CodeGPTPlugin.getLlamaModelsPath() + File.separator + getSelectedModel().getFileName();
: CodeGPTPlugin.getLlamaModelsPath() + File.separator
+ getSelectedModel().getFileNames().get(0);
}

private JPanel createFormPanelCards() {
Expand Down Expand Up @@ -394,8 +395,9 @@ private TextFieldWithBrowseButton createBrowsableCustomModelTextField(boolean en
}

private boolean isModelExists(HuggingFaceModel model) {
return FileUtil.exists(
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName());
return model.getFileNames().stream().allMatch(filename ->
FileUtil.exists(CodeGPTPlugin.getLlamaModelsPath() + File.separator + filename)
);
}

private AnActionLink createCancelDownloadLink(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ private boolean validateSelectedModel() {

private boolean isModelExists(HuggingFaceModel model) {
return FileUtil.exists(
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName());
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileNames());
}

private void enableForm(JButton serverButton, ServerProgressPanel progressPanel) {
Expand Down
40 changes: 0 additions & 40 deletions src/main/java/ee/carlrobert/codegpt/util/DownloadingUtil.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package ee.carlrobert.codegpt.settings.service.llama.form

import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.progress.ProgressIndicator
import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.progress.Task
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.completions.HuggingFaceModel
import ee.carlrobert.codegpt.util.DownloadingUtil
import ee.carlrobert.codegpt.util.file.FileUtil.copyFileWithProgress
import java.io.IOException
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import java.util.function.Consumer
import javax.swing.DefaultComboBoxModel

/**
 * Action that downloads every file of the model currently selected in [comboBoxModel].
 *
 * The download runs as a cancellable background task. Files are fetched one after another;
 * the first [IOException] stops the remaining downloads and is reported through [onFailed].
 *
 * @param onDownload invoked with the task's progress indicator before each file starts downloading
 * @param onDownloaded invoked from the task's `onSuccess` callback
 * @param onFailed invoked with the exception when a file cannot be downloaded
 * @param onUpdateProgress receives a formatted progress message once per second while a file downloads
 * @param comboBoxModel supplies the selected [HuggingFaceModel]
 */
class DownloadModelAction(
    private val onDownload: Consumer<ProgressIndicator>,
    private val onDownloaded: Runnable,
    private val onFailed: Consumer<Exception>,
    private val onUpdateProgress: Consumer<String>,
    private val comboBoxModel: DefaultComboBoxModel<HuggingFaceModel>
) : AnAction() {

    /** Starts the background download task for the event's project. */
    override fun actionPerformed(e: AnActionEvent) {
        ProgressManager.getInstance().run(DownloadBackgroundTask(e.project))
    }

    /** Background task that downloads the selected model's files sequentially. */
    internal inner class DownloadBackgroundTask(project: Project?) : Task.Backgroundable(
        project,
        CodeGPTBundle.get("settingsConfigurable.service.llama.progress.downloadingModel.title"),
        true
    ) {
        override fun run(indicator: ProgressIndicator) {
            val model = comboBoxModel.selectedItem as HuggingFaceModel
            val urls = model.fileURLs
            // Read once: the getter is invoked on every access otherwise.
            val fileNames = model.fileNames
            val numberOfFiles = urls.size
            var errorOccurred = false

            // `1..numberOfFiles + 1` parses as `1..(numberOfFiles + 1)` and would index one past
            // the end of `urls`; the 1-based file counter must stop at `numberOfFiles`.
            for (i in 1..numberOfFiles) {
                if (errorOccurred || indicator.isCanceled) {
                    break
                }
                // A fresh scheduler per file; it is cancelled and shut down in `finally`.
                val executorService = Executors.newSingleThreadScheduledExecutor()
                var progressUpdateScheduler: ScheduledFuture<*>? = null
                val url = urls[i - 1]
                val fileName = fileNames[i - 1]

                try {
                    onDownload.accept(indicator)

                    indicator.isIndeterminate = false
                    indicator.text = String.format(
                        CodeGPTBundle.get(
                            "settingsConfigurable.service.llama.progress.downloadingModelIndicator.text"
                        ),
                        fileName
                    )

                    val fileSize = url.openConnection().contentLengthLong
                    // Single-element array so the copy routine and the scheduled reporter
                    // can share a mutable byte counter.
                    val bytesRead = longArrayOf(0)
                    val startTime = System.currentTimeMillis()

                    progressUpdateScheduler = executorService.scheduleAtFixedRate(
                        {
                            onUpdateProgress.accept(
                                DownloadingUtil.getFormattedDownloadProgress(
                                    i,
                                    numberOfFiles,
                                    startTime,
                                    fileSize,
                                    bytesRead[0]
                                )
                            )
                        },
                        0, 1, TimeUnit.SECONDS
                    )
                    copyFileWithProgress(fileName, url, bytesRead, fileSize, indicator)
                } catch (ex: IOException) {
                    LOG.error("Unable to download", ex, url.toString())
                    onFailed.accept(ex)
                    errorOccurred = true
                } finally {
                    progressUpdateScheduler?.cancel(true)
                    executorService.shutdown()
                }
            }
        }

        // NOTE(review): IOExceptions are caught inside run(), so run() returns normally after a
        // failed download; confirm whether onDownloaded should also fire in that case.
        override fun onSuccess() {
            onDownloaded.run()
        }
    }

    companion object {
        private val LOG = Logger.getInstance(DownloadModelAction::class.java)
    }
}
Loading