Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch stashing thread context back to only around system index calls #859

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811))
* Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827))
* Set gradle dependency scope for common-utils to testFixturesImplementation ([#844](https://github.com/opensearch-project/k-NN/pull/844))
* Add support of .opensearch-knn-model as system index to transport actions ([#847](https://github.com/opensearch-project/k-NN/pull/847))
* Add github action for secure integ tests ([#836](https://github.com/opensearch-project/k-NN/pull/836))
* Add client setting to ignore warning exceptions ([#850](https://github.com/opensearch-project/k-NN/pull/850))
### Documentation
Expand Down

This file was deleted.

144 changes: 85 additions & 59 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import lombok.SneakyThrows;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand Down Expand Up @@ -42,19 +43,19 @@
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.common.ThreadContextHelper;
import org.opensearch.knn.common.exception.DeleteModelWhenInTrainStateException;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;

import java.io.IOException;
import java.net.URL;
Expand All @@ -64,6 +65,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;

import static java.util.Objects.isNull;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
Expand Down Expand Up @@ -216,14 +218,21 @@ public void create(ActionListener<CreateIndexResponse> actionListener) throws IO
if (isCreated()) {
return;
}
CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
client.admin().indices().create(request, actionListener);
runWithStashedThreadContext(() -> {
CreateIndexRequest request;
try {
request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping())
.settings(
Settings.builder()
.put("index.hidden", true)
.put("index.number_of_shards", this.numberOfShards)
.put("index.number_of_replicas", this.numberOfReplicas)
);
} catch (IOException e) {
throw new RuntimeException(e);
}
client.admin().indices().create(request, actionListener);
});
}

@Override
Expand Down Expand Up @@ -293,8 +302,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model);
}

IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME);

final IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME);
indexRequestBuilder.setId(model.getModelID());
indexRequestBuilder.setSource(parameters);

Expand All @@ -304,8 +312,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
// After metadata update finishes, remove item from every node's cache if necessary. If no model id is
// passed then nothing needs to be removed from the cache
ActionListener<IndexResponse> onMetaListener;
onMetaListener = ActionListener.wrap(
indexResponse -> client.execute(
onMetaListener = ActionListener.wrap(indexResponse -> {
client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(model.getModelID()),
ActionListener.wrap(removeModelFromCacheResponse -> {
Expand All @@ -318,9 +326,8 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do

listener.onFailure(new RuntimeException(failureMessage));
}, listener::onFailure)
),
listener::onFailure
);
);
}, listener::onFailure);

// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
Expand All @@ -331,14 +338,18 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
}

// Create the model index if it does not already exist
Runnable indexModelRunnable = () -> indexRequestBuilder.execute(onIndexListener);
if (!isCreated()) {
create(
ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(onIndexListener), onIndexListener::onFailure)
ActionListener.wrap(
createIndexResponse -> ModelDao.runWithStashedThreadContext(indexModelRunnable),
onIndexListener::onFailure
)
);
return;
}

indexRequestBuilder.execute(onIndexListener);
ModelDao.runWithStashedThreadContext(indexModelRunnable);
}

private ActionListener<IndexResponse> getUpdateModelMetadataListener(
Expand All @@ -357,13 +368,14 @@ private ActionListener<IndexResponse> getUpdateModelMetadataListener(
);
}

@SneakyThrows
@Override
public Model get(String modelId) throws ExecutionException, InterruptedException {
public Model get(String modelId) {
/*
GET /<model_index>/<modelId>?_local
*/
try {
return ThreadContextHelper.runWithStashedThreadContext(client, () -> {
return ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");
GetResponse getResponse;
Expand All @@ -378,16 +390,7 @@ public Model get(String modelId) throws ExecutionException, InterruptedException
} catch (RuntimeException runtimeException) {
// we need to use RuntimeException as container for real exception to keep signature
// of runWithStashedThreadContext generic
Throwable throwable = runtimeException.getCause();
if (throwable != null) {
if (throwable instanceof InterruptedException) {
throw (InterruptedException) throwable;
}
if (throwable instanceof ExecutionException) {
throw (ExecutionException) throwable;
}
}
throw runtimeException;
throw runtimeException.getCause();
}
}

Expand All @@ -402,20 +405,22 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
/*
GET /<model_index>/<modelId>?_local
*/
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));
ModelDao.runWithStashedThreadContext(() -> {
GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId)
.setPreference("_local");

getRequestBuilder.execute(ActionListener.wrap(response -> {
if (response.isSourceEmpty()) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
final Map<String, Object> responseMap = response.getSourceAsMap();
Model model = Model.getModelFromSourceMap(responseMap);
actionListener.onResponse(new GetModelResponse(model));

}, actionListener::onFailure));
}, actionListener::onFailure));
});
}

/**
Expand All @@ -426,7 +431,7 @@ public void get(String modelId, ActionListener<GetModelResponse> actionListener)
*/
@Override
public void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
ThreadContextHelper.runWithStashedThreadContext(client, () -> {
ModelDao.runWithStashedThreadContext(() -> {
request.indices(MODEL_INDEX_NAME);
client.search(request, actionListener);
});
Expand Down Expand Up @@ -528,15 +533,12 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
);

// Setup delete model request
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
);
clearModelMetadataStep.whenComplete(acknowledgedResponse -> {
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder);
}, listener::onFailure);

deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
Expand Down Expand Up @@ -591,10 +593,12 @@ private void deleteModelFromIndex(
StepListener<DeleteResponse> deleteModelFromIndexStep,
DeleteRequestBuilder deleteRequestBuilder
) {
deleteRequestBuilder.execute(
ActionListener.wrap(
deleteModelFromIndexStep::onResponse,
exception -> removeModelIdFromGraveyardOnFailure(modelId, exception, deleteModelFromIndexStep)
ModelDao.runWithStashedThreadContext(
() -> deleteRequestBuilder.execute(
ActionListener.wrap(
deleteModelFromIndexStep::onResponse,
exception -> removeModelIdFromGraveyardOnFailure(modelId, exception, deleteModelFromIndexStep)
)
)
);
}
Expand Down Expand Up @@ -676,4 +680,26 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache
return stringBuilder.toString();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing
*/
private static void runWithStashedThreadContext(Runnable function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
function.run();
}
}

/**
* Set the thread context to default, this is needed to allow actions on model system index
* when security plugin is enabled
* @param function supplier function that needs to be executed after thread context has been stashed, return object
*/
private static <T> T runWithStashedThreadContext(Supplier<T> function) {
try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) {
return function.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.common.ThreadContextHelper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -26,23 +24,19 @@
public class DeleteModelTransportAction extends HandledTransportAction<DeleteModelRequest, DeleteModelResponse> {

private final ModelDao modelDao;
private final Client client;

@Inject
public DeleteModelTransportAction(TransportService transportService, ActionFilters filters, Client client) {
public DeleteModelTransportAction(TransportService transportService, ActionFilters filters) {
super(DeleteModelAction.NAME, transportService, filters, DeleteModelRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
this.client = client;
}

@Override
protected void doExecute(Task task, DeleteModelRequest request, ActionListener<DeleteModelResponse> listener) {
ThreadContextHelper.runWithStashedThreadContext(client, () -> {
String modelID = request.getModelID();
modelDao.delete(modelID, ActionListener.wrap(listener::onResponse, e -> {
log.error(e);
listener.onFailure(e);
}));
});
String modelID = request.getModelID();
modelDao.delete(modelID, ActionListener.wrap(listener::onResponse, e -> {
log.error(e);
listener.onFailure(e);
}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.common.ThreadContextHelper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -29,20 +27,17 @@ public class GetModelTransportAction extends HandledTransportAction<GetModelRequ
private static final Logger LOG = LogManager.getLogger(GetModelTransportAction.class);
private ModelDao modelDao;

private final Client client;

@Inject
public GetModelTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) {
public GetModelTransportAction(TransportService transportService, ActionFilters actionFilters) {
super(GetModelAction.NAME, transportService, actionFilters, GetModelRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
this.client = client;
}

@Override
protected void doExecute(Task task, GetModelRequest request, ActionListener<GetModelResponse> actionListener) {
ThreadContextHelper.runWithStashedThreadContext(client, () -> {
String modelID = request.getModelID();
modelDao.get(modelID, actionListener);
});
String modelID = request.getModelID();

modelDao.get(modelID, actionListener);

}
}
Loading