Skip to content

Commit

Permalink
Download 10 split files for DBRX
Browse files Browse the repository at this point in the history
  • Loading branch information
reneleonhardt committed Apr 21, 2024
1 parent d3a610c commit c87c1b1
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import java.net.MalformedURLException;
import java.net.URL;
import java.util.List;
import java.util.stream.IntStream;

public enum HuggingFaceModel {

Expand Down Expand Up @@ -89,21 +91,26 @@ public String getCode() {
return name();
}

public String getFileName() {
public List<String> getFileNames() {
if ("TheBloke".equals(user)) {
return modelName.toLowerCase().replace("-gguf", format(".Q%d_K_M.gguf", quantization));
return List.of(modelName.toLowerCase()
.replace("-gguf", format(".Q%d_K_M.gguf", quantization)));
}
// TODO: Download all 10 files ;(
return modelName.toLowerCase().replace("-gguf", "-00001-of-00010.gguf");
if ("phymbert".equals(user)) {
return IntStream.range(1, 11).mapToObj(i -> modelName
.replace("-gguf", "-000%02d-of-00010.gguf".formatted(i))).toList();
}
return List.of(modelName);
}

public URL getFileURL() {
try {
return new URL(
"https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), getFileName()));
} catch (MalformedURLException ex) {
throw new RuntimeException(ex);
}
public List<URL> getFileURLs() {
return getFileNames().stream().map(file -> {
try {
return new URL("https://huggingface.co/%s/%s/resolve/main/%s".formatted(user, getDirectory(), file));
} catch (MalformedURLException ex) {
throw new RuntimeException(ex);
}
}).toList();
}

public URL getHuggingFaceURL() {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ public InfillPromptTemplate getInfillPromptTemplate() {
public String getActualModelPath() {
return isUseCustomLlamaModel()
? getCustomLlamaModelPath()
: CodeGPTPlugin.getLlamaModelsPath() + File.separator + getSelectedModel().getFileName();
: CodeGPTPlugin.getLlamaModelsPath() + File.separator
+ getSelectedModel().getFileNames().get(0);
}

private JPanel createFormPanelCards() {
Expand Down Expand Up @@ -394,8 +395,9 @@ private TextFieldWithBrowseButton createBrowsableCustomModelTextField(boolean en
}

private boolean isModelExists(HuggingFaceModel model) {
return FileUtil.exists(
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName());
return model.getFileNames().stream().allMatch(filename ->
FileUtil.exists(CodeGPTPlugin.getLlamaModelsPath() + File.separator + filename)
);
}

private AnActionLink createCancelDownloadLink(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ private boolean validateSelectedModel() {

private boolean isModelExists(HuggingFaceModel model) {
return FileUtil.exists(
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileName());
CodeGPTPlugin.getLlamaModelsPath() + File.separator + model.getFileNames());
}

private void enableForm(JButton serverButton, ServerProgressPanel progressPanel) {
Expand Down
40 changes: 0 additions & 40 deletions src/main/java/ee/carlrobert/codegpt/util/DownloadingUtil.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package ee.carlrobert.codegpt.settings.service.llama.form

import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.progress.ProgressIndicator
import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.progress.Task
import com.intellij.openapi.project.Project
import ee.carlrobert.codegpt.CodeGPTBundle
import ee.carlrobert.codegpt.completions.HuggingFaceModel
import ee.carlrobert.codegpt.util.DownloadingUtil
import ee.carlrobert.codegpt.util.file.FileUtil.copyFileWithProgress
import java.io.IOException
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import java.util.function.Consumer
import javax.swing.DefaultComboBoxModel

class DownloadModelAction(
private val onDownload: Consumer<ProgressIndicator>,
private val onDownloaded: Runnable,
private val onFailed: Consumer<Exception>,
private val onUpdateProgress: Consumer<String>,
private val comboBoxModel: DefaultComboBoxModel<HuggingFaceModel>
) : AnAction() {

override fun actionPerformed(e: AnActionEvent) {
ProgressManager.getInstance().run(DownloadBackgroundTask(e.project))
}

internal inner class DownloadBackgroundTask(project: Project?) : Task.Backgroundable(
project,
CodeGPTBundle.get("settingsConfigurable.service.llama.progress.downloadingModel.title"),
true
) {
override fun run(indicator: ProgressIndicator) {
val model = comboBoxModel.selectedItem as HuggingFaceModel
val urls = model.fileURLs
val numberOfFiles = urls.size
var errorOccured = false
for (i in 1..numberOfFiles + 1) {
if (errorOccured || indicator.isCanceled) {
break
}
val executorService = Executors.newSingleThreadScheduledExecutor()
var progressUpdateScheduler: ScheduledFuture<*>? = null
val url = urls[i - 1]

try {
onDownload.accept(indicator)

indicator.isIndeterminate = false
indicator.text = String.format(
CodeGPTBundle.get(
"settingsConfigurable.service.llama.progress.downloadingModelIndicator.text"
),
model.fileNames[i - 1]
)

val fileSize = url.openConnection().contentLengthLong
val bytesRead = longArrayOf(0)
val startTime = System.currentTimeMillis()

progressUpdateScheduler = executorService.scheduleAtFixedRate(
{
onUpdateProgress.accept(
DownloadingUtil.getFormattedDownloadProgress(
i,
numberOfFiles,
startTime,
fileSize,
bytesRead[0]
)
)
},
0, 1, TimeUnit.SECONDS
)
copyFileWithProgress(model.fileNames[i - 1], url, bytesRead, fileSize, indicator)
} catch (ex: IOException) {
LOG.error("Unable to download", ex, url.toString())
onFailed.accept(ex)
errorOccured = true
} finally {
progressUpdateScheduler?.cancel(true)
executorService.shutdown()
}
}
}

override fun onSuccess() {
onDownloaded.run()
}
}

companion object {
private val LOG = Logger.getInstance(DownloadModelAction::class.java)
}
}
40 changes: 40 additions & 0 deletions src/main/kotlin/ee/carlrobert/codegpt/util/DownloadingUtil.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package ee.carlrobert.codegpt.util

import ee.carlrobert.codegpt.util.file.FileUtil.convertFileSize

object DownloadingUtil {
private const val BYTES_IN_MB = 1024 * 1024

fun getFormattedDownloadProgress(
fileNumber: Int, fileCount: Int, startTime: Long,
fileSize: Long, bytesRead: Long
): String {
val timeElapsed = System.currentTimeMillis() - startTime

val speed = (bytesRead.toDouble() / timeElapsed) * 1000 / BYTES_IN_MB
val percent = bytesRead.toDouble() / fileSize * 100
val downloadedMB = bytesRead.toDouble() / BYTES_IN_MB
val totalMB = fileSize.toDouble() / BYTES_IN_MB
val remainingMB = totalMB - downloadedMB

return String.format(
"File %d/%d: %s of %s (%.2f%%), Speed: %.2f MB/sec, Time left: %s",
fileNumber,
fileCount,
convertFileSize(downloadedMB.toLong() * BYTES_IN_MB),
convertFileSize(totalMB.toLong() * BYTES_IN_MB),
percent,
speed,
getTimeLeftFormattedString(speed, remainingMB)
)
}

private fun getTimeLeftFormattedString(speed: Double, remainingMB: Double): String {
val timeLeftSec = if (speed > 0) remainingMB / speed else 0.0
val hours = (timeLeftSec / 3600).toLong()
val minutes = ((timeLeftSec % 3600) / 60).toLong()
val seconds = (timeLeftSec % 60).toLong()

return String.format("%02d:%02d:%02d", hours, minutes, seconds)
}
}

0 comments on commit c87c1b1

Please sign in to comment.