From ae4d7c3fb248f3da7975c3e72bcb529599c94cfb Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 19 Mar 2024 21:25:35 +0000 Subject: [PATCH] Suggestions --- .../TransportGetTrainedModelsStatsAction.java | 208 ++++++++++-------- 1 file changed, 116 insertions(+), 92 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index 416f9080f5a39..1f0506864de4a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestParameters; @@ -17,8 +18,8 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.action.support.TransportAction; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.service.ClusterService; @@ -27,7 +28,6 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.metrics.CounterMetric; import org.elasticsearch.common.util.Maps; -import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.query.QueryBuilder; @@ -76,7 +76,7 @@ import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor.pipelineIdsByResource; -public class TransportGetTrainedModelsStatsAction extends HandledTransportAction< +public class TransportGetTrainedModelsStatsAction extends TransportAction< GetTrainedModelsStatsAction.Request, GetTrainedModelsStatsAction.Response> { @@ -96,13 +96,7 @@ public TransportGetTrainedModelsStatsAction( TrainedModelProvider trainedModelProvider, Client client ) { - super( - GetTrainedModelsStatsAction.NAME, - transportService, - actionFilters, - GetTrainedModelsStatsAction.Request::new, - EsExecutors.DIRECT_EXECUTOR_SERVICE - ); + super(GetTrainedModelsStatsAction.NAME, actionFilters, transportService.getTaskManager()); this.client = client; this.clusterService = clusterService; this.trainedModelProvider = trainedModelProvider; @@ -114,6 +108,15 @@ protected void doExecute( Task task, GetTrainedModelsStatsAction.Request request, ActionListener listener + ) { + // workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can + executor.execute(ActionRunnable.wrap(listener, l -> doExecuteForked(task, request, l))); + } + + protected void doExecuteForked( + Task task, + GetTrainedModelsStatsAction.Request request, + ActionListener listener ) { final TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId()); final ModelAliasMetadata modelAliasMetadata = ModelAliasMetadata.fromState(clusterService.state()); @@ -122,90 +125,111 @@ protected void doExecute( GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder(); - SubscribableListener.newForked(l -> { - // When the request resource is a deployment find the - // model used in that deployment for the model stats - String idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata); - l.onResponse(idExpression); - }).>>>andThen(executor, null, (l, idExpression) -> { - logger.debug("Expanded models/deployment Ids request [{}]", idExpression); - - // the request id may contain deployment ids - // It is not an error if these don't match a model id but - // they need to be included in case the deployment id is also - // a model id. Hence, the `matchedDeploymentIds` parameter - trainedModelProvider.expandIds( - idExpression, - request.isAllowNoResources(), - request.getPageParams(), - Collections.emptySet(), - modelAliasMetadata, - parentTaskId, - matchedDeploymentIds, - l - ); - }).andThen((l, tuple) -> { - responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()); - executeAsyncWithOrigin( - client, - ML_ORIGIN, - TransportNodesStatsAction.TYPE, - nodeStatsRequest(clusterService.state(), parentTaskId), - l - ); - }).>andThen(executor, null, (l, nodesStatsResponse) -> { - // find all pipelines whether using the model id, - // alias or deployment id. - Set allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases() - .entrySet() - .stream() - .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey()))) - .collect(Collectors.toSet()); - allPossiblePipelineReferences.addAll(matchedDeploymentIds); - - Map> pipelineIdsByResource = pipelineIdsByResource(clusterService.state(), allPossiblePipelineReferences); - Map modelIdIngestStats = inferenceIngestStatsByModelId( - nodesStatsResponse, - modelAliasMetadata, - pipelineIdsByResource - ); - responseBuilder.setIngestStatsByModelId(modelIdIngestStats); - trainedModelProvider.getInferenceStats( - responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]), - parentTaskId, - l - ); - }).andThen(executor, null, (l, inferenceStats) -> { - // inference stats are per model and are only - // persisted for boosted tree models - responseBuilder.setInferenceStatsByModelId( - inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity())) - ); - getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, l); - }).>andThen(executor, null, (l, deploymentStats) -> { - // deployment stats for each matching deployment - // not necessarily for all models - responseBuilder.setDeploymentStatsByDeploymentId( - deploymentStats.getStats() - .results() + SubscribableListener + + .>>>newForked(l -> { + // When the request resource is a deployment find the model used in that deployment for the model stats + final var idExpression = addModelsUsedInMatchingDeployments(request.getResourceId(), assignmentMetadata); + + logger.debug("Expanded models/deployment Ids request [{}]", idExpression); + + // the request id may contain deployment ids + // It is not an error if these don't match a model id but + // they need to be included in case the deployment id is also + // a model id. Hence, the `matchedDeploymentIds` parameter + trainedModelProvider.expandIds( + idExpression, + request.isAllowNoResources(), + request.getPageParams(), + Collections.emptySet(), + modelAliasMetadata, + parentTaskId, + matchedDeploymentIds, + l + ); + }) + .andThenAccept(tuple -> responseBuilder.setExpandedModelIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1())) + + .andThen( + (l, ignored) -> executeAsyncWithOrigin( + client, + ML_ORIGIN, + TransportNodesStatsAction.TYPE, + nodeStatsRequest(clusterService.state(), parentTaskId), + l + ) + ) + .>andThen(executor, null, (l, nodesStatsResponse) -> { + // find all pipelines whether using the model id, alias or deployment id. + Set allPossiblePipelineReferences = responseBuilder.getExpandedModelIdsWithAliases() + .entrySet() .stream() - .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity())) - ); + .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey()))) + .collect(Collectors.toSet()); + allPossiblePipelineReferences.addAll(matchedDeploymentIds); - int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum(); - modelSizeStats( - responseBuilder.getExpandedModelIdsWithAliases(), - request.isAllowNoResources(), - parentTaskId, - l, - numberOfAllocations - ); - }).andThen((l, modelSizeStatsByModelId) -> { - responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId); - l.onResponse( - responseBuilder.build(modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata)) - ); - }).addListener(listener, executor, null); + Map> pipelineIdsByResource = pipelineIdsByResource( + clusterService.state(), + allPossiblePipelineReferences + ); + Map modelIdIngestStats = inferenceIngestStatsByModelId( + nodesStatsResponse, + modelAliasMetadata, + pipelineIdsByResource + ); + responseBuilder.setIngestStatsByModelId(modelIdIngestStats); + trainedModelProvider.getInferenceStats( + responseBuilder.getExpandedModelIdsWithAliases().keySet().toArray(new String[0]), + parentTaskId, + l + ); + }) + .andThenAccept( + // inference stats are per model and are only persisted for boosted tree models + inferenceStats -> responseBuilder.setInferenceStatsByModelId( + inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity())) + ) + ) + + .andThen( + executor, + null, + (l, ignored) -> getDeploymentStats(client, request.getResourceId(), parentTaskId, assignmentMetadata, l) + ) + .andThenApply(deploymentStats -> { + // deployment stats for each matching deployment not necessarily for all models + responseBuilder.setDeploymentStatsByDeploymentId( + deploymentStats.getStats() + .results() + .stream() + .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity())) + ); + return deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum(); + }) + + .>andThen( + executor, + null, + (l, numberOfAllocations) -> modelSizeStats( + responseBuilder.getExpandedModelIdsWithAliases(), + request.isAllowNoResources(), + parentTaskId, + l.map(modelSizeStatsByModelId -> { + responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId); + return null; + }), + numberOfAllocations + ) + ) + .andThenAccept(responseBuilder::setModelSizeStatsByModelId) + + .andThenApply( + ignored -> responseBuilder.build( + modelToDeployments(responseBuilder.getExpandedModelIdsWithAliases().keySet(), assignmentMetadata) + ) + ) + + .addListener(listener, executor, null); } static String addModelsUsedInMatchingDeployments(String idExpression, TrainedModelAssignmentMetadata assignmentMetadata) {