From 91e4c9000e6f2a7d8efd20980e6f6f6c0f22b3aa Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 23 Sep 2024 10:38:02 +0800 Subject: [PATCH] Filter out remote model auto redeployment Signed-off-by: zane-neo --- .../autoredeploy/MLModelAutoReDeployer.java | 19 ++++++++----- .../MLModelAutoReDeployerTests.java | 28 +++++++++++++++++++ .../ml/autoredeploy/RemoteModelResult.json | 20 +++++++++++++ 3 files changed, 60 insertions(+), 7 deletions(-) create mode 100644 plugin/src/test/resources/org/opensearch/ml/autoredeploy/RemoteModelResult.json diff --git a/plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java b/plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java index aa6bb743c9..0dd6d95680 100644 --- a/plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java +++ b/plugin/src/main/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployer.java @@ -12,6 +12,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; @@ -30,6 +31,7 @@ import org.opensearch.core.common.Strings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; @@ -257,16 +259,19 @@ private void queryRunningModels(ActionListener listener) { private void triggerModelRedeploy(ModelAutoRedeployArrangement modelAutoRedeployArrangement) { String modelId = modelAutoRedeployArrangement.getSearchResponse().getId(); List addedNodes = modelAutoRedeployArrangement.getAddedNodes(); - List planningWorkerNodes = (List) modelAutoRedeployArrangement - .getSearchResponse() - .getSourceAsMap() + Map sourceAsMap = modelAutoRedeployArrangement.getSearchResponse().getSourceAsMap(); + String functionName = (String) Optional.ofNullable(sourceAsMap.get(MLModel.FUNCTION_NAME_FIELD)) + .orElse(sourceAsMap.get(MLModel.ALGORITHM_FIELD)); + if (FunctionName.REMOTE == FunctionName.from(functionName)) { + log.info("Skipping redeploying remote model {} as remote model deployment can be done at prediction time.", modelId); + return; + } + List planningWorkerNodes = (List) sourceAsMap .get(MLModel.PLANNING_WORKER_NODES_FIELD); - Integer autoRedeployRetryTimes = (Integer) modelAutoRedeployArrangement - .getSearchResponse() - .getSourceAsMap() + Integer autoRedeployRetryTimes = (Integer) sourceAsMap .get(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD); Boolean deployToAllNodes = (Boolean) Optional - .ofNullable(modelAutoRedeployArrangement.getSearchResponse().getSourceAsMap().get(MLModel.DEPLOY_TO_ALL_NODES_FIELD)) + .ofNullable(sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD)) .orElse(false); // calculate node ids. String[] nodeIds = null; diff --git a/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java b/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java index 5f25b96f59..b32a0593d8 100644 --- a/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java @@ -609,6 +609,34 @@ public void test_redeployAModel_with_needRedeployArray_isEmpty() { mlModelAutoReDeployer.redeployAModel(); } + public void test_buildAutoReloadArrangement_skippingRemoteModel_success() throws Exception { + Settings settings = Settings + .builder() + .put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), true) + .put(ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES.getKey(), 3) + .put(ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE.getKey(), true) + .put(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.getKey(), false) + .build(); + + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.localNode()).thenReturn(localNode); + when(clusterService.getClusterSettings()).thenReturn(getClusterSettings(settings)); + mockClusterDataNodes(clusterService); + + mlModelAutoReDeployer = spy( + new MLModelAutoReDeployer(clusterService, client, settings, mlModelManager, searchRequestBuilderFactory) + ); + + SearchResponse searchResponse = buildDeployToAllNodesTrueSearchResponse("RemoteModelResult.json"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(searchResponse); + return null; + }).when(searchRequestBuilder).execute(isA(ActionListener.class)); + mlModelAutoReDeployer.buildAutoReloadArrangement(addedNodes, clusterManagerNodeId); + verify(client, never()).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), any(ActionListener.class)); + } + private SearchResponse buildDeployToAllNodesTrueSearchResponse(String file) throws Exception { MLModel mlModel = buildModelWithJsonFile(file); return createResponseWithModel(mlModel); diff --git a/plugin/src/test/resources/org/opensearch/ml/autoredeploy/RemoteModelResult.json b/plugin/src/test/resources/org/opensearch/ml/autoredeploy/RemoteModelResult.json new file mode 100644 index 0000000000..fe7103fdee --- /dev/null +++ b/plugin/src/test/resources/org/opensearch/ml/autoredeploy/RemoteModelResult.json @@ -0,0 +1,20 @@ +{ + "last_deployed_time": 1722954415807, + "model_version": "619", + "created_time": 1722954415642, + "deploy_to_all_nodes": true, + "is_hidden": false, + "description": "This is a test model", + "model_state": "DEPLOYED", + "planning_worker_node_count": 1, + "auto_redeploy_retry_times": 0, + "last_updated_time": 1723691017054, + "name": "my sagemaker model", + "connector_id": "z3kVKJEBAfFjoGUT_Ui7", + "current_worker_node_count": 0, + "model_group_id": "MiJPJ5EBQM-QzppeWrTJ", + "planning_worker_nodes": [ + "DecGG5pDQYaqelLMLcIV9Q" + ], + "algorithm": "REMOTE" +}