Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model Download Buttons #719

Merged
merged 26 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
757e753
This creates the component which will populate the Download Tab with …
anfee1 Mar 2, 2021
3563314
Making a place for the download buttons.
anfee1 Mar 2, 2021
3dc1bf2
Adding the Model Download Handler allowing the backend to feed the li…
anfee1 Mar 2, 2021
89f48c4
Getting rid of some of the test code.
anfee1 Mar 2, 2021
a491d98
Improve Block usability (#712)
stu1130 Mar 2, 2021
5ca07f4
Removing unnecessary logging messages.
anfee1 Mar 3, 2021
eb8d51d
block factory init commit (#697)
lanking520 Mar 3, 2021
cb352ad
[DOCS] Fixing TrainingListener documentation (#718)
aksrajvanshi Mar 3, 2021
48cf663
Fix DJL serving flaky test for mac (#721)
frankfliu Mar 4, 2021
28a32ff
Fixing all of the nits.
anfee1 Mar 4, 2021
c9d28c8
Getting rid of unnecessary methods.
anfee1 Mar 4, 2021
a059417
update onnxruntime along with String tensor (#724)
lanking520 Mar 5, 2021
347eb07
Add profiler doc (#722)
stu1130 Mar 5, 2021
a363db7
Resolving some comments.
anfee1 Mar 5, 2021
9d55e6e
Using better criteria in case multiple models have the same name.
anfee1 Mar 5, 2021
b4a1cc0
Fixing the java doc.
anfee1 Mar 5, 2021
a66e168
Configure verbose of mxnet extra libraries (#728)
zachgk Mar 8, 2021
5aa09a2
Added a TODO for using the artifact repo to get the base uri.
anfee1 Mar 8, 2021
f881e4d
paddlepaddle CN notebook (#730)
lanking520 Mar 8, 2021
a6a2232
add EI documentation (#733)
lanking520 Mar 9, 2021
a90129e
allow pytorch stream model loading (#729)
lanking520 Mar 9, 2021
c6aebe0
add NDList decode from inputStream (#734)
lanking520 Mar 9, 2021
8342d44
Remove memory scope and improve memory management (#695)
zachgk Mar 9, 2021
43e5891
Remove erroneous random forest application (#726)
zachgk Mar 9, 2021
2158e99
Minor fixes on duplicated code (#736)
lanking520 Mar 9, 2021
f29daf8
Trying to rebase to fix PR.
anfee1 Mar 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.serving.central;

import ai.djl.serving.central.handler.HttpStaticFileServerHandler;
import ai.djl.serving.central.handler.ModelDownloadHandler;
import ai.djl.serving.central.handler.ModelMetaDataHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
Expand Down Expand Up @@ -54,6 +55,7 @@ public void initChannel(SocketChannel ch) {
pipeline.addLast(new HttpServerCodec());
pipeline.addLast(new HttpObjectAggregator(65536));
pipeline.addLast(new ChunkedWriteHandler());
pipeline.addLast(new ModelDownloadHandler());
pipeline.addLast(new ModelMetaDataHandler());
pipeline.addLast(new HttpStaticFileServerHandler());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.central.handler;

import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.central.http.BadRequestException;
import ai.djl.serving.central.responseencoder.HttpRequestResponse;
import ai.djl.serving.central.utils.ModelUri;
import ai.djl.serving.central.utils.NettyUtils;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;
import java.io.IOException;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;

/**
* A handler to handle download requests from the ModelView.
*
* @author [email protected]
*/
public class ModelDownloadHandler extends SimpleChannelInboundHandler<FullHttpRequest> {

HttpRequestResponse jsonResponse;

/** Constructs a ModelDownloadHandler. */
public ModelDownloadHandler() {
jsonResponse = new HttpRequestResponse();
}

/**
* Handles the deployment request by forwarding the request to the serving-instance.
*
* @param ctx the context
* @param request the full request
*/
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request)
throws IOException, ModelNotFoundException {
QueryStringDecoder decoder = new QueryStringDecoder(request.uri());
String modelName = NettyUtils.getParameter(decoder, "modelName", null);
String modelGroupId = NettyUtils.getParameter(decoder, "groupId", null);
String modelArtifactId = NettyUtils.getParameter(decoder, "artifactId", null);
CompletableFuture.supplyAsync(
() -> {
try {
if (modelName != null) {
return ModelUri.uriFinder(
modelArtifactId, modelGroupId, modelName);
} else {
throw new BadRequestException("modelName is mandatory.");
}

} catch (IOException | ModelNotFoundException ex) {
throw new IllegalArgumentException(ex.getMessage(), ex);
}
})
.exceptionally((ex) -> Collections.emptyMap())
anfee1 marked this conversation as resolved.
Show resolved Hide resolved
.thenAccept(uriMap -> jsonResponse.sendAsJson(ctx, request, uriMap));
}

/** {@inheritDoc} */
@Override
public boolean acceptInboundMessage(Object msg) {
FullHttpRequest request = (FullHttpRequest) msg;

String uri = request.uri();
return uri.startsWith("/serving/models?");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.central.http;

/** Thrown when a bad HTTP request is received. */
public class BadRequestException extends IllegalArgumentException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BadRequest is a little vague. How do you plan to use this exception? Is it for something like the request could not be understood because it was not in the proper format "malformed"? Could the request be understood, but what it was requesting did not fully constitute a valid request "invalid"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well it was to say if somehow there is an error where a request did not have all of the parameters that are expected in the request.

Copy link
Contributor

@ebamberg ebamberg Mar 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"serving" uses exactly the same exception for the same purpose, I guess. Please correct me if I am wrong. So this would be consistent with "serving", especially when we move central as a plugin into "serving".
+1 from me

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, if serving has the same one then we should keep it to prioritize consistency. This refactor could be revisited later once serving and central are merged


// Serialization version marker; declared private because it is only meaningful to this class.
private static final long serialVersionUID = 1L;

/**
* Constructs a {@code BadRequestException} with the specified detail message.
*
* @param message the detail message (which is saved for later retrieval by the {@link
* #getMessage()} method)
*/
public BadRequestException(String message) {
super(message);
}

/**
* Constructs a {@code BadRequestException} with the specified detail message and a root cause.
*
* @param message the detail message (which is saved for later retrieval by the {@link
* #getMessage()} method)
* @param cause the root cause, passed on to the superclass constructor
*/
public BadRequestException(String message, Throwable cause) {
super(message, cause);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
/** Contains HTTP-related exception types, such as {@code BadRequestException}. */
package ai.djl.serving.central.http;
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.central.responseencoder;

import ai.djl.modality.Classifications;
import ai.djl.modality.Classifications.ClassificationsSerializer;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.Metadata;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.CharsetUtil;
import java.lang.reflect.Modifier;

/**
 * Serializes response entities to JSON and writes them to the client.
 *
 * @author [email protected]
 */
public class HttpRequestResponse {

    /** Content type header value set on every response produced by this class. */
    private static final String JSON_CONTENT_TYPE = "application/json; charset=UTF-8";

    /**
     * Emits a {@code Double} as an integer literal when converting its {@code long} value back to
     * a {@code Double} yields an equal value; otherwise emits the double unchanged.
     */
    private static final JsonSerializer<Double> DOUBLE_SERIALIZER =
            (value, type, context) -> {
                long whole = value.longValue();
                boolean roundTrips = value.equals(Double.valueOf(String.valueOf(whole)));
                return roundTrips ? new JsonPrimitive(whole) : new JsonPrimitive(value);
            };

    private static final Gson GSON_WITH_TRANSIENT_FIELDS =
            new GsonBuilder()
                    .setDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
                    .setPrettyPrinting()
                    .excludeFieldsWithModifiers(Modifier.STATIC)
                    .registerTypeAdapter(Classifications.class, new ClassificationsSerializer())
                    .registerTypeAdapter(DetectedObjects.class, new ClassificationsSerializer())
                    .registerTypeAdapter(Metadata.class, new MetaDataSerializer())
                    .registerTypeAdapter(Double.class, DOUBLE_SERIALIZER)
                    .create();

    /**
     * Serializes the entity to JSON and sends it to the client with status 200.
     *
     * <p>The connection is kept open only if the request asked for keep-alive.
     *
     * @param ctx channel context
     * @param request full request
     * @param entity the response
     */
    public void sendAsJson(ChannelHandlerContext ctx, FullHttpRequest request, Object entity) {
        String json = GSON_WITH_TRANSIENT_FIELDS.toJson(entity);
        ByteBuf body = ctx.alloc().buffer(json.length());
        body.writeCharSequence(json, CharsetUtil.UTF_8);

        FullHttpResponse response = newJsonResponse(body);
        sendAndCleanupConnection(ctx, response, HttpUtil.isKeepAlive(request));
    }

    /**
     * Sends the content of a buffer to the client with status 200 and closes the connection
     * afterwards.
     *
     * @param ctx channel context
     * @param buffer response buffer
     */
    public void sendByteBuffer(ChannelHandlerContext ctx, ByteBuf buffer) {
        sendAndCleanupConnection(ctx, newJsonResponse(buffer), false);
    }

    /**
     * Wraps the given content in an HTTP/1.1 200 response carrying the JSON content type.
     *
     * @param content the response body
     * @return the response
     */
    private static FullHttpResponse newJsonResponse(ByteBuf content) {
        FullHttpResponse response =
                new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content);
        response.headers().set(HttpHeaderNames.CONTENT_TYPE, JSON_CONTENT_TYPE);
        return response;
    }

    /**
     * Sets the content length and writes the response. If Keep-Alive is disabled, the response
     * additionally carries a "Connection: close" header and the connection is closed once the
     * response has been sent.
     *
     * @param ctx context
     * @param response full response
     * @param keepAlive is alive or not
     */
    private void sendAndCleanupConnection(
            ChannelHandlerContext ctx, FullHttpResponse response, boolean keepAlive) {
        HttpUtil.setContentLength(response, response.content().readableBytes());
        if (keepAlive) {
            ctx.writeAndFlush(response);
            return;
        }

        // The connection is closed right after the write, so tell the client up front.
        response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
        ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.central.utils;

import ai.djl.Application;
import ai.djl.repository.Artifact;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import java.io.IOException;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** A utility class to find the download URIs of a model's files when given a model name. */
public final class ModelUri {

    // TODO: Use the artifact repository to create base URI
    private static final URI BASE = URI.create("https://mlrepo.djl.ai/");

    private ModelUri() {}

    /**
     * Takes in a model name, artifactId, and groupId to return a Map of download URIs.
     *
     * <p>The map is keyed by the file key of each artifact file. Relative file URIs are resolved
     * against the artifact's repository URI, which in turn is resolved against the base
     * repository URI; absolute file URIs are returned unchanged. If several matching artifacts
     * contain the same file key, the last one processed wins.
     *
     * @param artifactId is the artifactId of the model
     * @param groupId is the groupId of the model
     * @param name is the name of the model
     * @return a map of download URIs
     * @throws IOException if the uri could not be found
     * @throws ModelNotFoundException if Model can not be found
     */
    public static Map<String, URI> uriFinder(String artifactId, String groupId, String name)
            throws IOException, ModelNotFoundException {
        Criteria<?, ?> criteria =
                Criteria.builder()
                        .optModelName(name)
                        .optGroupId(groupId)
                        .optArtifactId(artifactId)
                        .build();
        Map<Application, List<Artifact>> models = ModelZoo.listModels(criteria);
        Map<String, URI> uris = new ConcurrentHashMap<>();
        for (List<Artifact> artifacts : models.values()) {
            for (Artifact artifact : artifacts) {
                for (Map.Entry<String, Artifact.Item> entry : artifact.getFiles().entrySet()) {
                    uris.put(entry.getKey(), toDownloadUri(artifact, entry.getValue()));
                }
            }
        }
        return uris;
    }

    /** Returns the file's URI, resolving it against the artifact's repository when relative. */
    private static URI toDownloadUri(Artifact artifact, Artifact.Item item) {
        URI fileUri = URI.create(item.getUri());
        if (fileUri.isAbsolute()) {
            return fileUri;
        }
        // The repository URI is only needed, and only looked up, for relative file URIs.
        URI repositoryUri = artifact.getMetadata().getRepositoryUri();
        return BASE.resolve(repositoryUri).resolve(fileUri);
    }
}
Loading