Skip to content

Commit

Permalink
Ensure clean thread context in MasterService (elastic#114512)
Browse files Browse the repository at this point in the history
`ThreadContext#stashContext` doesn't guarantee to give a clean thread
context, but it's important we don't allow the callers' thread contexts
to leak into the cluster state update. This commit captures the desired
thread context at startup rather than using `stashContext` when forking
the processor.
  • Loading branch information
DaveCTurner authored Oct 11, 2024
1 parent 2c1e023 commit 1e2b200
Show file tree
Hide file tree
Showing 24 changed files with 99 additions and 24 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/114512.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 114512
summary: Ensure clean thread context in `MasterService`
area: Cluster Coordination
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public class MasterService extends AbstractLifecycleComponent {

protected final ThreadPool threadPool;
private final TaskManager taskManager;
private final ThreadContext.StoredContext clusterStateUpdateContext;

private volatile ExecutorService threadPoolExecutor;
private final AtomicInteger totalQueueSize = new AtomicInteger();
Expand All @@ -128,6 +129,7 @@ public MasterService(Settings settings, ClusterSettings clusterSettings, ThreadP

this.threadPool = threadPool;
this.taskManager = taskManager;
this.clusterStateUpdateContext = getClusterStateUpdateContext(threadPool.getThreadContext());

final var queuesByPriorityBuilder = new EnumMap<Priority, PerPriorityQueue>(Priority.class);
for (final var priority : Priority.values()) {
Expand All @@ -137,6 +139,15 @@ public MasterService(Settings settings, ClusterSettings clusterSettings, ThreadP
this.unbatchedExecutor = new UnbatchedExecutor();
}

private static ThreadContext.StoredContext getClusterStateUpdateContext(ThreadContext threadContext) {
try (var ignored = threadContext.newStoredContext()) {
// capture the context in which to run all cluster state updates here where we know it to be very clean
assert threadContext.isDefaultContext() : "must only create MasterService in a clean ThreadContext";
threadContext.markAsSystemContext();
return threadContext.newStoredContext();
}
}

private void setSlowTaskLoggingThreshold(TimeValue slowTaskLoggingThreshold) {
this.slowTaskLoggingThreshold = slowTaskLoggingThreshold;
}
Expand Down Expand Up @@ -1324,8 +1335,8 @@ private void forkQueueProcessor() {

assert totalQueueSize.get() > 0;
final var threadContext = threadPool.getThreadContext();
try (var ignored = threadContext.stashContext()) {
threadContext.markAsSystemContext();
try (var ignored = threadContext.newStoredContext()) {
clusterStateUpdateContext.restore();
threadPoolExecutor.execute(queuesProcessor);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,14 @@ static NodeConstruction prepareConstruction(
constructor.loadLoggingDataProviders();
TelemetryProvider telemetryProvider = constructor.createTelemetryProvider(settings);
ThreadPool threadPool = constructor.createThreadPool(settings, telemetryProvider.getMeterRegistry());
SettingsModule settingsModule = constructor.validateSettings(initialEnvironment.settings(), settings, threadPool);

final SettingsModule settingsModule;
try (var ignored = threadPool.getThreadContext().newStoredContext()) {
// If any deprecated settings are in use then we add warnings to the thread context response headers, but we're not
// computing a response here so these headers aren't relevant and eventually just get dropped after possibly leaking into
// places they shouldn't. Best to explicitly drop them now to protect against such leakage.
settingsModule = constructor.validateSettings(initialEnvironment.settings(), settings, threadPool);
}

SearchModule searchModule = constructor.createSearchModule(settingsModule.getSettings(), threadPool, telemetryProvider);
constructor.createClientAndRegistries(settingsModule.getSettings(), threadPool, searchModule);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexVersion;
Expand Down Expand Up @@ -63,6 +64,7 @@
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class MetadataRolloverServiceTests extends ESTestCase {

Expand Down Expand Up @@ -833,6 +835,7 @@ public void testRolloverClusterStateForDataStreamNoTemplate() throws Exception {
final TestTelemetryPlugin telemetryPlugin = new TestTelemetryPlugin();

ThreadPool testThreadPool = mock(ThreadPool.class);
when(testThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
MetadataRolloverService rolloverService = DataStreamTestHelper.getMetadataRolloverService(
dataStream,
testThreadPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.ingest.FakeProcessor;
import org.elasticsearch.ingest.IngestInfo;
Expand Down Expand Up @@ -81,6 +82,7 @@ public void setup() {
threadPool = mock(ThreadPool.class);
when(threadPool.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));

Client client = mock(Client.class);
ingestService = new IngestService(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.elasticsearch.indices.EmptySystemIndices;
import org.elasticsearch.injection.guice.ModuleTestCase;
import org.elasticsearch.plugins.ClusterPlugin;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.telemetry.TelemetryProvider;
import org.elasticsearch.test.gateway.TestGatewayAllocator;
import org.elasticsearch.threadpool.TestThreadPool;
Expand Down Expand Up @@ -88,8 +87,8 @@ public void setUp() throws Exception {
clusterService = new ClusterService(
Settings.EMPTY,
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
null,
(TaskManager) null
threadPool,
null
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.env.Environment;
import org.elasticsearch.health.node.selection.HealthNodeTaskExecutor;
Expand Down Expand Up @@ -77,6 +78,7 @@
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class MetadataIndexTemplateServiceTests extends ESSingleNodeTestCase {

Expand Down Expand Up @@ -2473,6 +2475,7 @@ public void testAddIndexTemplateWithDeprecatedComponentTemplate() throws Excepti

private static List<Throwable> putTemplate(NamedXContentRegistry xContentRegistry, PutRequest request) {
ThreadPool testThreadPool = mock(ThreadPool.class);
when(testThreadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ClusterService clusterService = ClusterServiceUtils.createClusterService(testThreadPool);
MetadataCreateIndexService createIndexService = new MetadataCreateIndexService(
Settings.EMPTY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -263,7 +264,15 @@ public void testThreadContext() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);

try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) {
final Map<String, String> expectedHeaders = Collections.singletonMap("test", "test");

final var expectedHeaders = new HashMap<String, String>();
expectedHeaders.put(randomIdentifier(), randomIdentifier());
for (final var copiedHeader : Task.HEADERS_TO_COPY) {
if (randomBoolean()) {
expectedHeaders.put(copiedHeader, randomIdentifier());
}
}

final Map<String, List<String>> expectedResponseHeaders = Collections.singletonMap(
"testResponse",
Collections.singletonList("testResponse")
Expand Down Expand Up @@ -1343,7 +1352,6 @@ public void testAcking() {
.build();
final var deterministicTaskQueue = new DeterministicTaskQueue();
final var threadPool = deterministicTaskQueue.getThreadPool();
threadPool.getThreadContext().markAsSystemContext();
try (
var masterService = createMasterService(
true,
Expand All @@ -1352,6 +1360,7 @@ public void testAcking() {
new StoppableExecutorServiceWrapper(threadPool.generic())
)
) {
threadPool.getThreadContext().markAsSystemContext();

final var responseHeaderName = "test-response-header";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsModule;
import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.env.Environment;
Expand Down Expand Up @@ -81,7 +82,6 @@
import org.elasticsearch.script.ScriptModule;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.index.IndexVersionUtils;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.transport.RemoteClusterAware;
Expand Down Expand Up @@ -432,8 +432,8 @@ private static class ServiceHolder implements Closeable {
ClusterService clusterService = new ClusterService(
Settings.EMPTY,
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
null,
(TaskManager) null
new DeterministicTaskQueue().getThreadPool(),
null
);

client = (Client) Proxy.newProxyInstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.datastreams.DataStreamsPlugin;
import org.elasticsearch.index.mapper.extras.MapperExtrasPlugin;
Expand Down Expand Up @@ -159,6 +160,7 @@ protected <T> T blockingCall(Consumer<ActionListener<T>> function) throws Except

protected static ThreadPool mockThreadPool() {
ThreadPool tp = mock(ThreadPool.class);
when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ExecutorService executor = mock(ExecutorService.class);
doAnswer(invocationOnMock -> {
((Runnable) invocationOnMock.getArguments()[0]).run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.IngestStats;
Expand Down Expand Up @@ -112,6 +113,7 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
public void setUpVariables() {
ThreadPool tp = mock(ThreadPool.class);
when(tp.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
client = mock(Client.class);
Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build();
ClusterSettings clusterSettings = new ClusterSettings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.PipelineConfiguration;
Expand Down Expand Up @@ -88,6 +89,7 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
public void setUpVariables() {
ThreadPool tp = mock(ThreadPool.class);
when(tp.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
client = mock(Client.class);
Settings settings = Settings.builder().put("node.name", "InferenceProcessorFactoryTests_node").build();
ClusterSettings clusterSettings = new ClusterSettings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
Expand Down Expand Up @@ -427,6 +428,7 @@ private static <Response> Answer<Response> withResponse(Response response) {

private ResultsPersisterService buildResultsPersisterService(OriginSettingClient client) {
ThreadPool tp = mock(ThreadPool.class);
when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ClusterSettings clusterSettings = new ClusterSettings(
Settings.EMPTY,
new HashSet<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.shard.ShardId;
Expand Down Expand Up @@ -90,6 +91,7 @@ public class OpenJobPersistentTasksExecutorTests extends ESTestCase {
public void setUpMocks() {
ThreadPool tp = mock(ThreadPool.class);
when(tp.generic()).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
Settings settings = Settings.builder().put("node.name", "OpenJobPersistentTasksExecutorTests").build();
ClusterSettings clusterSettings = new ClusterSettings(
settings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.cluster.routing.UnassignedInfo;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.datastreams.DataStreamsPlugin;
import org.elasticsearch.health.node.selection.HealthNode;
Expand Down Expand Up @@ -284,6 +285,7 @@ public void cleanup() throws Exception {

protected static ThreadPool mockThreadPool() {
ThreadPool tp = mock(ThreadPool.class);
when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ExecutorService executor = mock(ExecutorService.class);
doAnswer(invocationOnMock -> {
((Runnable) invocationOnMock.getArguments()[0]).run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.shard.ShardId;
Expand Down Expand Up @@ -428,6 +429,7 @@ private static <Response> Answer<Response> withFailure(Exception failure) {

public static ResultsPersisterService buildResultsPersisterService(OriginSettingClient client) {
ThreadPool tp = mock(ThreadPool.class);
when(tp.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ClusterSettings clusterSettings = new ClusterSettings(
Settings.EMPTY,
new HashSet<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ public void setup() throws Exception {
.put("path.home", createTempDir())
.build();
final ThreadContext threadContext = new ThreadContext(settings);
final var defaultContext = threadContext.newStoredContext();
final ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(threadContext);
AuthenticationTestHelper.builder()
Expand Down Expand Up @@ -174,7 +175,11 @@ public void setup() throws Exception {
when(securityIndex.isAvailable(SecurityIndexManager.Availability.SEARCH_SHARDS)).thenReturn(true);
when(securityIndex.defensiveCopy()).thenReturn(securityIndex);

final ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool);
final ClusterService clusterService;
try (var ignored = threadContext.newStoredContext()) {
defaultContext.restore();
clusterService = ClusterServiceUtils.createClusterService(threadPool);
}

final MockLicenseState licenseState = mock(MockLicenseState.class);
when(licenseState.isAllowed(Security.TOKEN_SERVICE_FEATURE)).thenReturn(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ public void setup() throws Exception {

this.threadPool = new TestThreadPool("saml test thread pool", settings);
final ThreadContext threadContext = threadPool.getThreadContext();
final var defaultContext = threadContext.newStoredContext();
AuthenticationTestHelper.builder()
.user(new User("kibana"))
.realmRef(new RealmRef("realm", "type", "node"))
Expand Down Expand Up @@ -278,7 +279,11 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
final MockLicenseState licenseState = mock(MockLicenseState.class);
when(licenseState.isAllowed(Security.TOKEN_SERVICE_FEATURE)).thenReturn(true);

final ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool);
final ClusterService clusterService;
try (var ignored = threadContext.newStoredContext()) {
defaultContext.restore();
clusterService = ClusterServiceUtils.createClusterService(threadPool);
}
final SecurityContext securityContext = new SecurityContext(settings, threadContext);
tokenService = new TokenService(
settings,
Expand Down
Loading

0 comments on commit 1e2b200

Please sign in to comment.