Skip to content

Commit

Permalink
Pass trust_remote_code arg to djl-convert (#2315)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Aug 14, 2024
1 parent 3173776 commit db7bf29
Showing 1 changed file with 36 additions and 24 deletions.
60 changes: 36 additions & 24 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -157,10 +158,14 @@ static void convertOnnxModel(ModelInfo<?, ?> info) throws IOException {
modelId = repo.toString();
}
String optimization = info.prop.getProperty("option.optimization");
info.resolvedModelUrl = convertOnnx(modelId, optimization).toUri().toURL().toString();
boolean trustRemoteCode = "true".equals(info.prop.getProperty("option.trust_remote_code"));
info.resolvedModelUrl =
convertOnnx(modelId, optimization, trustRemoteCode).toUri().toURL().toString();
}

private static Path convertOnnx(String modelId, String optimization) throws IOException {
private static Path convertOnnx(String modelId, String optimization, boolean trustRemoteCode)
throws IOException {
logger.info("Converting model to onnx artifacts");
String hash = Utils.hash(modelId);
String download = Utils.getenv("SERVING_DOWNLOAD_DIR", null);
Path parent = download == null ? Utils.getCacheDir() : Paths.get(download);
Expand All @@ -177,19 +182,22 @@ private static Path convertOnnx(String modelId, String optimization) throws IOEx
throw new IllegalArgumentException("Unsupported optimization level: " + optimization);
}

String[] cmd = {
"djl-convert",
"--output-dir",
repoDir.toAbsolutePath().toString(),
"--output-format",
"OnnxRuntime",
"-m",
modelId,
"--optimize",
optimization,
"--device",
hasCuda ? "cuda" : "cpu"
};
List<String> cmd = new ArrayList<>();
cmd.add("djl-convert");
cmd.add("--output-dir");
cmd.add(repoDir.toAbsolutePath().toString());
cmd.add("--output-format");
cmd.add("OnnxRuntime");
cmd.add("-m");
cmd.add(modelId);
cmd.add("--optimize");
cmd.add(optimization);
cmd.add("--device");
cmd.add(hasCuda ? "cuda" : "cpu");
if (trustRemoteCode) {
cmd.add("--trust-remote-code");
}

boolean success = false;
try {
logger.info("Converting model to onnx artifacts: {}", (Object) cmd);
Expand Down Expand Up @@ -227,6 +235,7 @@ static boolean needConvertRust(ModelInfo<?, ?> info) {

static void convertRustModel(ModelInfo<?, ?> info) throws IOException {
String modelId = info.prop.getProperty("option.model_id");
boolean trustRemoteCode = "true".equals(info.prop.getProperty("option.trust_remote_code"));
if (modelId == null) {
logger.info("model_id not defined, skip rust model conversion.");
return;
Expand All @@ -242,15 +251,18 @@ static void convertRustModel(ModelInfo<?, ?> info) throws IOException {
return;
}

String[] cmd = {
"djl-convert",
"--output-dir",
repoDir.toAbsolutePath().toString(),
"--output-format",
"Rust",
"-m",
modelId
};
List<String> cmd = new ArrayList<>();
cmd.add("djl-convert");
cmd.add("--output-dir");
cmd.add(repoDir.toAbsolutePath().toString());
cmd.add("--output-format");
cmd.add("Rust");
cmd.add("-m");
cmd.add(modelId);
if (trustRemoteCode) {
cmd.add("--trust-remote-code");
}

boolean success = false;
try {
logger.info("Converting model to rust artifacts: {}", (Object) cmd);
Expand Down

0 comments on commit db7bf29

Please sign in to comment.