From d016da12c41810549676e388be2c3c1e04fc9f9a Mon Sep 17 00:00:00 2001 From: Gautam Kumar Date: Mon, 1 Mar 2021 14:06:58 -0800 Subject: [PATCH] Minor fix in SageMaker SDK --- .../java/ai/djl/aws/sagemaker/SageMaker.java | 24 ++++++++++++++++--- .../ai/djl/aws/sagemaker/SageMakerTest.java | 1 + 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/extensions/aws-ai/src/main/java/ai/djl/aws/sagemaker/SageMaker.java b/extensions/aws-ai/src/main/java/ai/djl/aws/sagemaker/SageMaker.java index 47481895788..bc2ee00a139 100644 --- a/extensions/aws-ai/src/main/java/ai/djl/aws/sagemaker/SageMaker.java +++ b/extensions/aws-ai/src/main/java/ai/djl/aws/sagemaker/SageMaker.java @@ -105,7 +105,11 @@ private SageMaker(Builder builder) { s3 = builder.s3; iam = builder.iam; model = builder.model; - modelName = model.getName(); + if (builder.modelName != null) { + modelName = builder.modelName; + } else { + modelName = model.getName(); + } bucketName = builder.bucketName; bucketPath = builder.bucketPath; executionRole = builder.executionRole; @@ -455,6 +459,7 @@ public static final class Builder { String containerImage; String endpointConfigName; String endpointName; + String modelName; String instanceType = "ml.m4.xlarge"; int instanceCount = 1; SageMakerClient sageMaker; @@ -555,6 +560,19 @@ public Builder optEndpointName(String endpointName) { return this; } + /** + * Sets the optional model name to create. + * + *

If {@code modelName} is not set, model name will be used as model name. + * + * @param modelName the model name to create + * @return the builder + */ + public Builder optModelName(String modelName) { + this.modelName = modelName; + return this; + } + /** * Sets the optional instance type to launch the endpoint. * @@ -658,10 +676,10 @@ public SageMaker build() { bucketPath = bucketPath.substring(1); } if (endpointConfigName == null) { - endpointConfigName = model.getName(); + endpointConfigName = modelName == null ? model.getName() : modelName; } if (endpointName == null) { - endpointName = model.getName(); + endpointName = modelName == null ? model.getName() : modelName; } return new SageMaker(this); diff --git a/extensions/aws-ai/src/test/java/ai/djl/aws/sagemaker/SageMakerTest.java b/extensions/aws-ai/src/test/java/ai/djl/aws/sagemaker/SageMakerTest.java index b135a285119..e28426f2df0 100644 --- a/extensions/aws-ai/src/test/java/ai/djl/aws/sagemaker/SageMakerTest.java +++ b/extensions/aws-ai/src/test/java/ai/djl/aws/sagemaker/SageMakerTest.java @@ -54,6 +54,7 @@ public void testDeployModel() throws IOException, ModelException { SageMaker.builder() .setModel(model) .optBucketName("djl-sm-test") + .optModelName("resnet") .optContainerImage("125045733377.dkr.ecr.us-east-1.amazonaws.com/djl") .optExecutionRole( "arn:aws:iam::125045733377:role/service-role/DJLSageMaker-ExecutionRole-20210213T1027050")