From 97c1bf01ff511c4db74dc8a81045447b009bec29 Mon Sep 17 00:00:00 2001 From: Kiran Prakash Date: Wed, 7 Aug 2024 09:39:07 -0700 Subject: [PATCH] QueryGroup Resource Tracking framework and implementation (#13897) * initial code for the sandbox resource tracking and cancellation framework Signed-off-by: Kiran Prakash * Fix Failing Tests Signed-off-by: Kiran Prakash * spotless Apply Signed-off-by: Kiran Prakash * Update SandboxService.java Signed-off-by: Kiran Prakash * Update SandboxService.java Signed-off-by: Kiran Prakash * Update SandboxTask.java Signed-off-by: Kiran Prakash * Add java docs Signed-off-by: Kiran Prakash * spotless Signed-off-by: Kiran Prakash * javadocs Signed-off-by: Kiran Prakash * javadocs Signed-off-by: Kiran Prakash * java docs Signed-off-by: Kiran Prakash * Update AbstractTaskCancellation.java Signed-off-by: Kiran Prakash * Update SandboxModule.java Signed-off-by: Kiran Prakash * Some tests and stubs Signed-off-by: Kiran Prakash * spotless Signed-off-by: Kiran Prakash * :server:testingConventions Signed-off-by: Kiran Prakash * Update AbstractTaskCancellation.java Signed-off-by: Kiran Prakash * more tests Signed-off-by: Kiran Prakash * addressing comments Signed-off-by: Kiran Prakash * revert some accidentally pushed files Signed-off-by: Kiran Prakash * resolve flakiness Signed-off-by: Kiran Prakash * renaming sandbox to querygroup and adjusting code based on merged PRs Signed-off-by: Kiran Prakash * jvm to memory Signed-off-by: Kiran Prakash * missing java docs Signed-off-by: Kiran Prakash * spotless Signed-off-by: Kiran Prakash * Update CHANGELOG.md Signed-off-by: Kiran Prakash * pluck cancellation changes out of this PR Signed-off-by: Kiran Prakash * remove unused Signed-off-by: Kiran Prakash * remove cancellation related code and add more tests coverage Signed-off-by: Kiran Prakash * us only memory and not jvm Signed-off-by: Kiran Prakash * test conventions Signed-off-by: Kiran Prakash * Bring back enum Signed-off-by: Kiran Prakash * Update SearchBackpressureService.java Signed-off-by: Kiran Prakash * revert changes Signed-off-by: Kiran Prakash * revert changes Signed-off-by: Kiran Prakash * all required changes Signed-off-by: Kiran Prakash * Update CHANGELOG.md Signed-off-by: Kiran Prakash * cleanups Signed-off-by: Kiran Prakash * Delete QueryGroupService.java Signed-off-by: Kiran Prakash * cleanups Signed-off-by: Kiran Prakash * Update QueryGroupLevelResourceUsageViewTests.java Signed-off-by: Kiran Prakash * Update QueryGroupLevelResourceUsageViewTests.java Signed-off-by: Kiran Prakash * Update QueryGroupResourceUsageTrackerService.java Signed-off-by: Kiran Prakash * Update QueryGroupResourceUsageTrackerService.java Signed-off-by: Kiran Prakash * Update QueryGroupResourceUsageTrackerService.java Signed-off-by: Kiran Prakash * Update CHANGELOG.md Signed-off-by: Kiran Prakash * rebasing with latest main Signed-off-by: Kiran Prakash * remove experimental Signed-off-by: Kiran Prakash * remove queryGroupId Signed-off-by: Kiran Prakash * Update QueryGroupResourceUsageTrackerService.java Signed-off-by: Kiran Prakash * change code comments Signed-off-by: Kiran Prakash * remmove QueryGroupUsageTracker Signed-off-by: Kiran Prakash * Update QueryGroupResourceUsageTrackerService.java Signed-off-by: Kiran Prakash * Update QueryGroupResourceUsageTrackerService.java Signed-off-by: Kiran Prakash * remove QueryGroupTestHelpers Signed-off-by: Kiran Prakash * cleanups Signed-off-by: Kiran Prakash * remove queryGroupHelper Signed-off-by: Kiran Prakash * Update ResourceTypeTests.java Signed-off-by: Kiran Prakash * extend OpenSearchTestCase Signed-off-by: Kiran Prakash * pr comments Signed-off-by: Kiran Prakash * Update CHANGELOG.md Signed-off-by: Kiran Prakash * Update QueryGroupResourceUsageTrackerServiceTests.java Signed-off-by: Kiran Prakash * Update ResourceTypeTests.java Signed-off-by: Kiran Prakash * Update ResourceTypeTests.java Signed-off-by: Kiran Prakash * Update ResourceType.java Signed-off-by: Kiran Prakash * Update ResourceType.java Signed-off-by: Kiran Prakash --------- Signed-off-by: Kiran Prakash --- CHANGELOG.md | 1 + .../opensearch/cluster/metadata/Metadata.java | 2 +- .../org/opensearch/search/ResourceType.java | 21 ++- .../wlm/QueryGroupLevelResourceUsageView.java | 50 +++++++ ...QueryGroupResourceUsageTrackerService.java | 84 ++++++++++++ .../opensearch/wlm/tracker/package-info.java | 12 ++ .../opensearch/search/ResourceTypeTests.java | 52 ++++++++ ...QueryGroupLevelResourceUsageViewTests.java | 64 +++++++++ ...GroupResourceUsageTrackerServiceTests.java | 126 ++++++++++++++++++ 9 files changed, 408 insertions(+), 4 deletions(-) create mode 100644 server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java create mode 100644 server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java create mode 100644 server/src/main/java/org/opensearch/wlm/tracker/package-info.java create mode 100644 server/src/test/java/org/opensearch/search/ResourceTypeTests.java create mode 100644 server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java create mode 100644 server/src/test/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerServiceTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index be5e5598b09c2..3e83a5bf9b4cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add ThreadContextPermission for stashAndMergeHeaders and stashWithOrigin ([#15039](https://github.com/opensearch-project/OpenSearch/pull/15039)) - [Concurrent Segment Search] Support composite aggregations with scripting ([#15072](https://github.com/opensearch-project/OpenSearch/pull/15072)) - Add `rangeQuery` and `regexpQuery` for `constant_keyword` field type ([#14711](https://github.com/opensearch-project/OpenSearch/pull/14711)) +- [Workload Management] QueryGroup resource tracking framework changes ([#13897](https://github.com/opensearch-project/OpenSearch/pull/13897)) ### Dependencies - Bump `netty` from 4.1.111.Final to 4.1.112.Final ([#15081](https://github.com/opensearch-project/OpenSearch/pull/15081)) diff --git a/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java b/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java index 440b9e267cf0a..09bef2ddf9ee6 100644 --- a/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/opensearch/cluster/metadata/Metadata.java @@ -1391,7 +1391,7 @@ public Builder put(final QueryGroup queryGroup) { return queryGroups(existing); } - private Map getQueryGroups() { + public Map getQueryGroups() { return Optional.ofNullable(this.customs.get(QueryGroupMetadata.TYPE)) .map(o -> (QueryGroupMetadata) o) .map(QueryGroupMetadata::queryGroups) diff --git a/server/src/main/java/org/opensearch/search/ResourceType.java b/server/src/main/java/org/opensearch/search/ResourceType.java index fe5ce4dd2bb50..0cba2222a6e20 100644 --- a/server/src/main/java/org/opensearch/search/ResourceType.java +++ b/server/src/main/java/org/opensearch/search/ResourceType.java @@ -10,21 +10,26 @@ import org.opensearch.common.annotation.PublicApi; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.tasks.Task; import java.io.IOException; +import java.util.function.Function; /** * Enum to hold the resource type */ @PublicApi(since = "2.x") public enum ResourceType { - CPU("cpu"), - MEMORY("memory"); + CPU("cpu", task -> task.getTotalResourceUtilization(ResourceStats.CPU)), + MEMORY("memory", task -> task.getTotalResourceUtilization(ResourceStats.MEMORY)); private final String name; + private final Function getResourceUsage; - ResourceType(String name) { + ResourceType(String name, Function getResourceUsage) { this.name = name; + this.getResourceUsage = getResourceUsage; } /** @@ -48,4 +53,14 @@ public static void writeTo(StreamOutput out, ResourceType resourceType) throws I public String getName() { return name; } + + /** + * Gets the resource usage for a given resource type and task. + * + * @param task the task for which to calculate resource usage + * @return the resource usage + */ + public long getResourceUsage(Task task) { + return getResourceUsage.apply(task); + } } diff --git a/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java b/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java new file mode 100644 index 0000000000000..2fd743dc3f83f --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/QueryGroupLevelResourceUsageView.java @@ -0,0 +1,50 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; + +import java.util.List; +import java.util.Map; + +/** + * Represents the point in time view of resource usage of a QueryGroup and + * has a 1:1 relation with a QueryGroup. + * This class holds the resource usage data and the list of active tasks. + */ +public class QueryGroupLevelResourceUsageView { + // resourceUsage holds the resource usage data for a QueryGroup at a point in time + private final Map resourceUsage; + // activeTasks holds the list of active tasks for a QueryGroup at a point in time + private final List activeTasks; + + public QueryGroupLevelResourceUsageView(Map resourceUsage, List activeTasks) { + this.resourceUsage = resourceUsage; + this.activeTasks = activeTasks; + } + + /** + * Returns the resource usage data. + * + * @return The map of resource usage data + */ + public Map getResourceUsageData() { + return resourceUsage; + } + + /** + * Returns the list of active tasks. + * + * @return The list of active tasks + */ + public List getActiveTasks() { + return activeTasks; + } +} diff --git a/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java b/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java new file mode 100644 index 0000000000000..bfbf5d8a452d1 --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerService.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.tracker; + +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; +import org.opensearch.wlm.QueryGroupTask; + +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * This class tracks resource usage per QueryGroup + */ +public class QueryGroupResourceUsageTrackerService { + + public static final EnumSet TRACKED_RESOURCES = EnumSet.allOf(ResourceType.class); + private final TaskResourceTrackingService taskResourceTrackingService; + + /** + * QueryGroupResourceTrackerService constructor + * + * @param taskResourceTrackingService Service that helps track resource usage of tasks running on a node. + */ + public QueryGroupResourceUsageTrackerService(TaskResourceTrackingService taskResourceTrackingService) { + this.taskResourceTrackingService = taskResourceTrackingService; + } + + /** + * Constructs a map of QueryGroupLevelResourceUsageView instances for each QueryGroup. + * + * @return Map of QueryGroup views + */ + public Map constructQueryGroupLevelUsageViews() { + final Map> tasksByQueryGroup = getTasksGroupedByQueryGroup(); + final Map queryGroupViews = new HashMap<>(); + + // Iterate over each QueryGroup entry + for (Map.Entry> queryGroupEntry : tasksByQueryGroup.entrySet()) { + // Compute the QueryGroup usage + final EnumMap queryGroupUsage = new EnumMap<>(ResourceType.class); + for (ResourceType resourceType : TRACKED_RESOURCES) { + long queryGroupResourceUsage = 0; + for (Task task : queryGroupEntry.getValue()) { + queryGroupResourceUsage += resourceType.getResourceUsage(task); + } + queryGroupUsage.put(resourceType, queryGroupResourceUsage); + } + + // Add to the QueryGroup View + queryGroupViews.put( + queryGroupEntry.getKey(), + new QueryGroupLevelResourceUsageView(queryGroupUsage, queryGroupEntry.getValue()) + ); + } + return queryGroupViews; + } + + /** + * Groups tasks by their associated QueryGroup. + * + * @return Map of tasks grouped by QueryGroup + */ + private Map> getTasksGroupedByQueryGroup() { + return taskResourceTrackingService.getResourceAwareTasks() + .values() + .stream() + .filter(QueryGroupTask.class::isInstance) + .map(QueryGroupTask.class::cast) + .collect(Collectors.groupingBy(QueryGroupTask::getQueryGroupId, Collectors.mapping(task -> (Task) task, Collectors.toList()))); + } +} diff --git a/server/src/main/java/org/opensearch/wlm/tracker/package-info.java b/server/src/main/java/org/opensearch/wlm/tracker/package-info.java new file mode 100644 index 0000000000000..86efc99355d3d --- /dev/null +++ b/server/src/main/java/org/opensearch/wlm/tracker/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * QueryGroup resource tracking artifacts + */ +package org.opensearch.wlm.tracker; diff --git a/server/src/test/java/org/opensearch/search/ResourceTypeTests.java b/server/src/test/java/org/opensearch/search/ResourceTypeTests.java new file mode 100644 index 0000000000000..78827b8b1bdad --- /dev/null +++ b/server/src/test/java/org/opensearch/search/ResourceTypeTests.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search; + +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.test.OpenSearchTestCase; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ResourceTypeTests extends OpenSearchTestCase { + + public void testFromName() { + assertSame(ResourceType.CPU, ResourceType.fromName("cpu")); + assertThrows(IllegalArgumentException.class, () -> { ResourceType.fromName("CPU"); }); + assertThrows(IllegalArgumentException.class, () -> { ResourceType.fromName("Cpu"); }); + + assertSame(ResourceType.MEMORY, ResourceType.fromName("memory")); + assertThrows(IllegalArgumentException.class, () -> { ResourceType.fromName("Memory"); }); + assertThrows(IllegalArgumentException.class, () -> { ResourceType.fromName("MEMORY"); }); + assertThrows(IllegalArgumentException.class, () -> { ResourceType.fromName("JVM"); }); + assertThrows(IllegalArgumentException.class, () -> { ResourceType.fromName("Heap"); }); + assertThrows(IllegalArgumentException.class, () -> { ResourceType.fromName("Disk"); }); + } + + public void testGetName() { + assertEquals("cpu", ResourceType.CPU.getName()); + assertEquals("memory", ResourceType.MEMORY.getName()); + } + + public void testGetResourceUsage() { + SearchShardTask mockTask = createMockTask(SearchShardTask.class, 100, 200); + assertEquals(100, ResourceType.CPU.getResourceUsage(mockTask)); + assertEquals(200, ResourceType.MEMORY.getResourceUsage(mockTask)); + } + + private T createMockTask(Class type, long cpuUsage, long heapUsage) { + T task = mock(type); + when(task.getTotalResourceUtilization(ResourceStats.CPU)).thenReturn(cpuUsage); + when(task.getTotalResourceUtilization(ResourceStats.MEMORY)).thenReturn(heapUsage); + return task; + } +} diff --git a/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java b/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java new file mode 100644 index 0000000000000..7f6419505fec2 --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/QueryGroupLevelResourceUsageViewTests.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm; + +import org.opensearch.action.search.SearchAction; +import org.opensearch.core.tasks.TaskId; +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class QueryGroupLevelResourceUsageViewTests extends OpenSearchTestCase { + Map resourceUsage; + List activeTasks; + + public void setUp() throws Exception { + super.setUp(); + resourceUsage = Map.of(ResourceType.fromName("memory"), 34L, ResourceType.fromName("cpu"), 12L); + activeTasks = List.of(getRandomTask(4321)); + } + + public void testGetResourceUsageData() { + QueryGroupLevelResourceUsageView queryGroupLevelResourceUsageView = new QueryGroupLevelResourceUsageView( + resourceUsage, + activeTasks + ); + Map resourceUsageData = queryGroupLevelResourceUsageView.getResourceUsageData(); + assertTrue(assertResourceUsageData(resourceUsageData)); + } + + public void testGetActiveTasks() { + QueryGroupLevelResourceUsageView queryGroupLevelResourceUsageView = new QueryGroupLevelResourceUsageView( + resourceUsage, + activeTasks + ); + List activeTasks = queryGroupLevelResourceUsageView.getActiveTasks(); + assertEquals(1, activeTasks.size()); + assertEquals(4321, activeTasks.get(0).getId()); + } + + private boolean assertResourceUsageData(Map resourceUsageData) { + return resourceUsageData.get(ResourceType.fromName("memory")) == 34L && resourceUsageData.get(ResourceType.fromName("cpu")) == 12L; + } + + private Task getRandomTask(long id) { + return new Task( + id, + "transport", + SearchAction.NAME, + "test description", + new TaskId(randomLong() + ":" + randomLong()), + Collections.emptyMap() + ); + } +} diff --git a/server/src/test/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerServiceTests.java b/server/src/test/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerServiceTests.java new file mode 100644 index 0000000000000..967119583c25f --- /dev/null +++ b/server/src/test/java/org/opensearch/wlm/tracker/QueryGroupResourceUsageTrackerServiceTests.java @@ -0,0 +1,126 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.wlm.tracker; + +import org.opensearch.action.search.SearchShardTask; +import org.opensearch.action.search.SearchTask; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.tasks.resourcetracker.ResourceStats; +import org.opensearch.search.ResourceType; +import org.opensearch.tasks.CancellableTask; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskResourceTrackingService; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.wlm.QueryGroupLevelResourceUsageView; +import org.opensearch.wlm.QueryGroupTask; +import org.junit.After; +import org.junit.Before; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.opensearch.wlm.QueryGroupTask.QUERY_GROUP_ID_HEADER; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class QueryGroupResourceUsageTrackerServiceTests extends OpenSearchTestCase { + TestThreadPool threadPool; + TaskResourceTrackingService mockTaskResourceTrackingService; + QueryGroupResourceUsageTrackerService queryGroupResourceUsageTrackerService; + + @Before + public void setup() { + threadPool = new TestThreadPool(getTestName()); + mockTaskResourceTrackingService = mock(TaskResourceTrackingService.class); + queryGroupResourceUsageTrackerService = new QueryGroupResourceUsageTrackerService(mockTaskResourceTrackingService); + } + + @After + public void cleanup() { + ThreadPool.terminate(threadPool, 5, TimeUnit.SECONDS); + } + + public void testConstructQueryGroupLevelViews_CreatesQueryGroupLevelUsageView_WhenTasksArePresent() { + List queryGroupIds = List.of("queryGroup1", "queryGroup2", "queryGroup3"); + + Map activeSearchShardTasks = createActiveSearchShardTasks(queryGroupIds); + when(mockTaskResourceTrackingService.getResourceAwareTasks()).thenReturn(activeSearchShardTasks); + Map stringQueryGroupLevelResourceUsageViewMap = queryGroupResourceUsageTrackerService + .constructQueryGroupLevelUsageViews(); + + for (String queryGroupId : queryGroupIds) { + assertEquals( + 400, + (long) stringQueryGroupLevelResourceUsageViewMap.get(queryGroupId).getResourceUsageData().get(ResourceType.MEMORY) + ); + assertEquals(2, stringQueryGroupLevelResourceUsageViewMap.get(queryGroupId).getActiveTasks().size()); + } + } + + public void testConstructQueryGroupLevelViews_CreatesQueryGroupLevelUsageView_WhenTasksAreNotPresent() { + Map stringQueryGroupLevelResourceUsageViewMap = queryGroupResourceUsageTrackerService + .constructQueryGroupLevelUsageViews(); + assertTrue(stringQueryGroupLevelResourceUsageViewMap.isEmpty()); + } + + public void testConstructQueryGroupLevelUsageViews_WithTasksHavingDifferentResourceUsage() { + Map activeSearchShardTasks = new HashMap<>(); + activeSearchShardTasks.put(1L, createMockTask(SearchShardTask.class, 100, 200, "queryGroup1")); + activeSearchShardTasks.put(2L, createMockTask(SearchShardTask.class, 200, 400, "queryGroup1")); + when(mockTaskResourceTrackingService.getResourceAwareTasks()).thenReturn(activeSearchShardTasks); + + Map queryGroupViews = queryGroupResourceUsageTrackerService + .constructQueryGroupLevelUsageViews(); + + assertEquals(600, (long) queryGroupViews.get("queryGroup1").getResourceUsageData().get(ResourceType.MEMORY)); + assertEquals(2, queryGroupViews.get("queryGroup1").getActiveTasks().size()); + } + + private Map createActiveSearchShardTasks(List queryGroupIds) { + Map activeSearchShardTasks = new HashMap<>(); + long task_id = 0; + for (String queryGroupId : queryGroupIds) { + for (int i = 0; i < 2; i++) { + activeSearchShardTasks.put(++task_id, createMockTask(SearchShardTask.class, 100, 200, queryGroupId)); + } + } + return activeSearchShardTasks; + } + + private T createMockTask(Class type, long cpuUsage, long heapUsage, String queryGroupId) { + T task = mock(type); + if (task instanceof SearchTask || task instanceof SearchShardTask) { + // Stash the current thread context to ensure that any existing context is preserved and restored after setting the query group + // ID. + try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { + threadPool.getThreadContext().putHeader(QUERY_GROUP_ID_HEADER, queryGroupId); + ((QueryGroupTask) task).setQueryGroupId(threadPool.getThreadContext()); + } + } + when(task.getTotalResourceUtilization(ResourceStats.CPU)).thenReturn(cpuUsage); + when(task.getTotalResourceUtilization(ResourceStats.MEMORY)).thenReturn(heapUsage); + when(task.getStartTimeNanos()).thenReturn((long) 0); + + AtomicBoolean isCancelled = new AtomicBoolean(false); + doAnswer(invocation -> { + isCancelled.set(true); + return null; + }).when(task).cancel(anyString()); + doAnswer(invocation -> isCancelled.get()).when(task).isCancelled(); + + return task; + } +}