Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Nov 7, 2024
1 parent be7bbbd commit 1d29e0a
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 67 deletions.
4 changes: 2 additions & 2 deletions engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,8 @@ def register_adapter(inputs: Input):
return Output().error(f"register_adapter_error", message=str(e))

logging.info(
f"Registered adapter {adapter_name} from {adapter_path} successfully")
return Output(message=f"Adapter {adapter_name} registered")
f"Registered adapter {adapter_alias} from {adapter_path} successfully")
return Output(message=f"Adapter {adapter_alias} registered")


def update_adapter(inputs: Input):
Expand Down
3 changes: 2 additions & 1 deletion engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ def _fetch_adapters_from_input(input_map: dict, input_item: Input):
def _validate_adapters(adapters_per_item, adapter_registry):
for adapter_name, adapter_alias in adapters_per_item:
if adapter_name and adapter_name not in adapter_registry:
raise ValueError(f"Adapter {adapter_alias} is not registered")
raise ValueError(
f"Adapter {adapter_alias or adapter_name} is not registered")


def parse_lmi_default_request_rolling_batch(payload):
Expand Down
4 changes: 2 additions & 2 deletions engines/python/src/test/resources/adaptecho/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def update_adapter(inputs: Input):
def unregister_adapter(inputs: Input):
    """Remove the adapter named by the ``name`` property of ``inputs``.

    Returns an error ``Output`` when that name is not present in the
    module-level ``adapters`` container; otherwise deletes the entry and
    returns a success ``Output``.
    """
    global adapters
    name = inputs.get_property("name")
    # Check membership first so an unknown name yields an error response
    # instead of a KeyError from the ``del`` below.
    if name not in adapters:
        return Output().error("error",
                              message=f"Adapter {name} not registered.")
    del adapters[name]
    return Output().add("Successfully unregistered adapter")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ private void handleListAdapters(
WorkerPool<Input, Output> wp =
ModelManager.getInstance().getWorkLoadManager().getWorkerPoolById(modelName);
if (wp == null) {
throw new BadRequestException(404, "The model " + modelName + " was not found");
throw new BadRequestException(
HttpResponseStatus.NOT_FOUND.code(),
"The model " + modelName + " was not found");
}
ModelInfo<Input, Output> modelInfo = getModelInfo(wp);
boolean enableLora =
Expand Down Expand Up @@ -180,7 +182,9 @@ private void handleRegisterAdapter(
WorkLoadManager wlm = ModelManager.getInstance().getWorkLoadManager();
WorkerPool<Input, Output> wp = wlm.getWorkerPoolById(modelName);
if (wp == null) {
throw new BadRequestException(404, "The model " + modelName + " was not found");
throw new BadRequestException(
HttpResponseStatus.NOT_FOUND.code(),
"The model " + modelName + " was not found");
}
ModelInfo<Input, Output> modelInfo = getModelInfo(wp);
boolean enableLora =
Expand All @@ -206,10 +210,11 @@ private void handleRegisterAdapter(
.whenCompleteAsync(
(o, t) -> {
if (o != null) {
if (o.getCode() < 300) {
modelInfo.registerAdapter(adapter);
if (o.getCode() >= 300) {
throw new BadRequestException(o.getCode(), o.getMessage());
}
handleOutput(o, ctx);
modelInfo.registerAdapter(adapter);
sendOutput(o, ctx);
}
})
.exceptionally(
Expand All @@ -228,7 +233,9 @@ private void handleUpdateAdapter(
WorkLoadManager wlm = ModelManager.getInstance().getWorkLoadManager();
WorkerPool<Input, Output> wp = wlm.getWorkerPoolById(modelName);
if (wp == null) {
throw new BadRequestException(404, "The model " + modelName + " was not found");
throw new BadRequestException(
HttpResponseStatus.NOT_FOUND.code(),
"The model " + modelName + " was not found");
}
ModelInfo<Input, Output> modelInfo = getModelInfo(wp);
boolean enableLora =
Expand All @@ -242,47 +249,47 @@ private void handleUpdateAdapter(

if (adapter == null) {
throw new BadRequestException(
404,
HttpResponseStatus.NOT_FOUND.code(),
"The adapter "
+ (adapterAlias == null ? adapterName : adapterAlias)
+ " was not found");
}

Map<String, String> options = new ConcurrentHashMap<>();
Map<String, String> options = new ConcurrentHashMap<>(adapter.getOptions());
for (Map.Entry<String, List<String>> entry : decoder.parameters().entrySet()) {
if (entry.getValue().size() == 1) {
options.put(entry.getKey(), entry.getValue().get(0));
}
}
String src = options.get("src");
boolean pin = Boolean.parseBoolean(options.getOrDefault("pin", "false"));

Adapter<Input, Output> newAdapter =
Adapter.newInstance(
modelInfo,
adapter.getName(),
adapterName,
adapter.getAlias(),
adapter.getSrc(),
adapter.isPin(),
adapter.getOptions());
newAdapter.setAlias(adapterAlias);
if (src != null) {
newAdapter.setSrc(src);
options);
if (adapterAlias != null) {
newAdapter.setAlias(adapterAlias);
}
if (options.containsKey("src")) {
newAdapter.setSrc(options.get("src"));
}
if (adapter.isPin() != pin) {
newAdapter.setPin(pin);
if (options.containsKey("pin")) {
newAdapter.setPin(Boolean.parseBoolean(options.get("pin")));
}
newAdapter.getOptions().putAll(options);

newAdapter
.update(wlm)
.whenCompleteAsync(
(o, t) -> {
if (o != null) {
if (o.getCode() < 300) {
modelInfo.updateAdapter(newAdapter);
if (o.getCode() >= 300) {
throw new BadRequestException(o.getCode(), o.getMessage());
}
handleOutput(o, ctx);
modelInfo.updateAdapter(newAdapter);
sendOutput(o, ctx);
}
})
.exceptionally(
Expand All @@ -297,7 +304,9 @@ private void handleDescribeAdapter(
WorkerPool<Input, Output> wp =
ModelManager.getInstance().getWorkLoadManager().getWorkerPoolById(modelName);
if (wp == null) {
throw new BadRequestException(404, "The model " + modelName + " was not found");
throw new BadRequestException(
HttpResponseStatus.NOT_FOUND.code(),
"The model " + modelName + " was not found");
}
ModelInfo<Input, Output> modelInfo = getModelInfo(wp);
boolean enableLora =
Expand All @@ -311,7 +320,7 @@ private void handleDescribeAdapter(

if (adapter == null) {
throw new BadRequestException(
404,
HttpResponseStatus.NOT_FOUND.code(),
"The adapter "
+ (adapterAlias == null ? adapterName : adapterAlias)
+ " was not found");
Expand All @@ -326,7 +335,9 @@ private void handleUnregisterAdapter(
WorkLoadManager wlm = ModelManager.getInstance().getWorkLoadManager();
WorkerPool<Input, Output> wp = wlm.getWorkerPoolById(modelName);
if (wp == null) {
throw new BadRequestException(404, "The model " + modelName + " was not found");
throw new BadRequestException(
HttpResponseStatus.NOT_FOUND.code(),
"The model " + modelName + " was not found");
}
ModelInfo<Input, Output> modelInfo = getModelInfo(wp);
boolean enableLora =
Expand All @@ -340,20 +351,25 @@ private void handleUnregisterAdapter(

if (adapter == null) {
throw new BadRequestException(
404,
HttpResponseStatus.NOT_FOUND.code(),
"The adapter "
+ (adapterAlias == null ? adapterName : adapterAlias)
+ " was not found");
}

if (adapterAlias != null) {
adapter.setAlias(adapterAlias);
}

Adapter.unregister(adapter, modelInfo, wlm)
.whenCompleteAsync(
(o, t) -> {
if (o != null) {
if (o.getCode() < 300) {
modelInfo.unregisterAdapter(adapter.getName(), adapterAlias);
if (o.getCode() >= 300) {
throw new BadRequestException(o.getCode(), o.getMessage());
}
handleOutput(o, ctx);
modelInfo.unregisterAdapter(adapterName);
sendOutput(o, ctx);
}
})
.exceptionally(
Expand All @@ -371,18 +387,15 @@ private ModelInfo<Input, Output> getModelInfo(WorkerPool<Input, Output> wp) {
return (ModelInfo<Input, Output>) wp.getWpc();
}

private void handleOutput(Output output, ChannelHandlerContext ctx) {
private void sendOutput(Output output, ChannelHandlerContext ctx) {
if (ctx == null) {
return;
}

int code = output.getCode();
if (code >= 300) {
throw new BadRequestException(output.getCode(), output.getMessage());
}

NettyUtils.sendJsonResponse(
ctx, new StatusResponse(output.getMessage()), HttpResponseStatus.valueOf(code));
ctx,
new StatusResponse(output.getMessage()),
HttpResponseStatus.valueOf(output.getCode()));
}

private void onException(Throwable t, ChannelHandlerContext ctx) {
Expand Down
50 changes: 27 additions & 23 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@

import ai.djl.engine.Engine;
import ai.djl.modality.Classifications.Classification;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.serving.http.*;
import ai.djl.serving.http.DescribeAdapterResponse;
import ai.djl.serving.http.DescribeWorkflowResponse;
import ai.djl.serving.http.DescribeWorkflowResponse.Model;
import ai.djl.serving.http.ErrorResponse;
import ai.djl.serving.http.ServerStartupException;
import ai.djl.serving.http.StatusResponse;
import ai.djl.serving.http.list.ListAdaptersResponse;
import ai.djl.serving.http.list.ListModelsResponse;
import ai.djl.serving.http.list.ListWorkflowsResponse;
Expand All @@ -35,9 +37,6 @@
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.Connector;
import ai.djl.serving.util.ModelStore;
import ai.djl.serving.wlm.Adapter;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkerPool;
import ai.djl.serving.wlm.util.EventManager;
import ai.djl.serving.wlm.util.ModelServerListenerAdapter;
import ai.djl.util.JsonUtils;
Expand Down Expand Up @@ -1036,7 +1035,6 @@ private void testRegisterAdapterModelNotFound() throws InterruptedException {
request(channel, HttpMethod.POST, url);
channel.closeFuture().sync();
channel.close().sync();

assertHttpCode(HttpResponseStatus.NOT_FOUND.code());
}

Expand All @@ -1055,10 +1053,15 @@ private void testRegisterAdapterHandlerError() throws InterruptedException {
assertHttpCode(HttpResponseStatus.FAILED_DEPENDENCY.code());

// Assert adapters not added
WorkerPool<Input, Output> wp =
ModelManager.getInstance().getWorkLoadManager().getWorkerPoolById(modelName);
ModelInfo<Input, Output> modelInfo = (ModelInfo<Input, Output>) wp.getWpc();
assertNull(modelInfo.getAdapter(adapterName));
channel = connect(Connector.ConnectorType.MANAGEMENT);
assertNotNull(channel);

url = strModelPrefix + "/adapters";
request(channel, HttpMethod.GET, url);
assertHttpOk();

ListAdaptersResponse resp = JsonUtils.GSON.fromJson(result, ListAdaptersResponse.class);
assertFalse(resp.getAdapters().stream().anyMatch(a -> "adaptable2".equals(a.getName())));
}

private void testUpdateAdapter(Channel channel, boolean modelPrefix)
Expand All @@ -1084,7 +1087,6 @@ private void testUpdateAdapterModelNotFound() throws InterruptedException {
request(channel, HttpMethod.POST, url);
channel.closeFuture().sync();
channel.close().sync();

assertHttpCode(HttpResponseStatus.NOT_FOUND.code());
}

Expand All @@ -1098,7 +1100,6 @@ private void testUpdateAdapterNotFound() throws InterruptedException {
request(channel, HttpMethod.POST, url);
channel.closeFuture().sync();
channel.close().sync();

assertHttpCode(HttpResponseStatus.NOT_FOUND.code());
}

Expand All @@ -1117,11 +1118,17 @@ private void testUpdateAdapterHandlerError() throws InterruptedException {
assertHttpCode(HttpResponseStatus.FAILED_DEPENDENCY.code());

// Assert adapters not updated
WorkerPool<Input, Output> wp =
ModelManager.getInstance().getWorkLoadManager().getWorkerPoolById(modelName);
ModelInfo<Input, Output> modelInfo = (ModelInfo<Input, Output>) wp.getWpc();
Adapter<Input, Output> adapter = modelInfo.getAdapter(adapterName);
assertEquals("src", adapter.getSrc());
channel = connect(Connector.ConnectorType.MANAGEMENT);
assertNotNull(channel);

url = strModelPrefix + "/adapters/" + adapterName;
request(channel, HttpMethod.GET, url);
assertHttpOk();

DescribeAdapterResponse resp =
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
assertEquals(resp.getName(), adapterName);
assertEquals(resp.getSrc(), "src");
}

private void testAdapterMissing() throws InterruptedException {
Expand Down Expand Up @@ -1246,6 +1253,7 @@ private void testListAdapter(Channel channel, boolean modelPrefix) throws Interr
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
String url = strModelPrefix + "/adapters";
request(channel, HttpMethod.GET, url);
assertHttpOk();

ListAdaptersResponse resp = JsonUtils.GSON.fromJson(result, ListAdaptersResponse.class);
assertTrue(resp.getAdapters().stream().anyMatch(a -> "adaptable".equals(a.getName())));
Expand All @@ -1261,7 +1269,6 @@ private void testListAdapterModelNotFound() throws InterruptedException {
request(channel, HttpMethod.GET, url);
channel.closeFuture().sync();
channel.close().sync();

assertHttpCode(HttpResponseStatus.NOT_FOUND.code());
}

Expand All @@ -1271,6 +1278,7 @@ private void testDescribeAdapter(Channel channel, boolean modelPrefix)
String strModelPrefix = modelPrefix ? "/models/adaptecho" : "";
String url = strModelPrefix + "/adapters/adaptable";
request(channel, HttpMethod.GET, url);
assertHttpOk();

DescribeAdapterResponse resp =
JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class);
Expand All @@ -1288,7 +1296,6 @@ private void testDescribeAdapterModelNotFound() throws InterruptedException {
request(channel, HttpMethod.GET, url);
channel.closeFuture().sync();
channel.close().sync();

assertHttpCode(HttpResponseStatus.NOT_FOUND.code());
}

Expand All @@ -1302,7 +1309,6 @@ private void testDescribeAdapterNotFound() throws InterruptedException {
request(channel, HttpMethod.GET, url);
channel.closeFuture().sync();
channel.close().sync();

assertHttpCode(HttpResponseStatus.NOT_FOUND.code());
}

Expand Down Expand Up @@ -1340,7 +1346,6 @@ private void testUnregisterAdapterModelNotFound() throws InterruptedException {
request(channel, HttpMethod.DELETE, url);
channel.closeFuture().sync();
channel.close().sync();

assertHttpCode(HttpResponseStatus.NOT_FOUND.code());
}

Expand All @@ -1354,7 +1359,6 @@ private void testUnregisterAdapterNotFound() throws InterruptedException {
request(channel, HttpMethod.DELETE, url);
channel.closeFuture().sync();
channel.close().sync();

assertHttpCode(HttpResponseStatus.NOT_FOUND.code());
}

Expand Down
6 changes: 2 additions & 4 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -639,14 +639,12 @@ public void updateAdapter(Adapter<I, O> adapter) {
* @param name the adapter to remove
* @return the removed adapter
*/
public Adapter<I, O> unregisterAdapter(String name, String alias) {
public Adapter<I, O> unregisterAdapter(String name) {
synchronized (this) {
// TODO: Remove from current workers
if (!adapters.containsKey(name)) {
throw new NoSuchElementException(
"The adapter "
+ (alias == null ? name : alias)
+ " was not found and therefore cannot be unregistered");
"The adapter was not found and therefore cannot be unregistered");
}
return adapters.remove(name);
}
Expand Down

0 comments on commit 1d29e0a

Please sign in to comment.