Skip to content

Commit

Permalink
Make ThreadContext.markAsSystemContext package-private
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Perkins <[email protected]>
  • Loading branch information
cwperks committed Jul 26, 2024
1 parent 59302a3 commit bc19108
Show file tree
Hide file tree
Showing 15 changed files with 67 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.core.Assertions;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
Expand Down Expand Up @@ -142,6 +143,7 @@ public abstract class TransportReplicationAction<
public static final String REPLICA_ACTION_SUFFIX = "[r]";

protected final ThreadPool threadPool;
protected final InternalThreadContextWrapper tcWrapper;
protected final TransportService transportService;
protected final ClusterService clusterService;
protected final ShardStateAction shardStateAction;
Expand Down Expand Up @@ -239,6 +241,7 @@ protected TransportReplicationAction(
) {
super(actionName, actionFilters, transportService.getTaskManager());
this.threadPool = threadPool;
this.tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext());
this.transportService = transportService;
this.clusterService = clusterService;
this.indicesService = indicesService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -104,6 +105,7 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements

private final ClusterSettings clusterSettings;
protected final ThreadPool threadPool;
protected final InternalThreadContextWrapper tcWrapper;

private volatile TimeValue slowTaskLoggingThreshold;

Expand Down Expand Up @@ -139,6 +141,7 @@ public ClusterApplierService(
) {
this.clusterSettings = clusterSettings;
this.threadPool = threadPool;
this.tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext());
this.state = new AtomicReference<>();
this.nodeName = nodeName;

Expand Down Expand Up @@ -396,7 +399,7 @@ private void submitStateUpdateTask(
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
final UpdateTask updateTask = new UpdateTask(
config.priority(),
source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.CountDown;
import org.opensearch.common.util.concurrent.FutureUtils;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -134,6 +135,7 @@ public class MasterService extends AbstractLifecycleComponent {
private volatile TimeValue slowTaskLoggingThreshold;

protected final ThreadPool threadPool;
protected final InternalThreadContextWrapper tcWrapper;

private volatile PrioritizedOpenSearchThreadPoolExecutor threadPoolExecutor;
private volatile Batcher taskBatcher;
Expand Down Expand Up @@ -169,6 +171,7 @@ public MasterService(
);
this.stateStats = new ClusterStateStats();
this.threadPool = threadPool;
this.tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext());
this.clusterManagerMetrics = clusterManagerMetrics;
}

Expand Down Expand Up @@ -1022,7 +1025,7 @@ public <T> void submitStateUpdateTasks(
final ThreadContext threadContext = threadPool.getThreadContext();
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(true);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();

List<Batcher.UpdateTask> safeTasks = tasks.entrySet()
.stream()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.util.concurrent;

import java.util.Objects;

/**
* Wrapper around the ThreadContext to expose methods to the core repo without
* exposing them to plugins
*
* @opensearch.internal
*/
public class InternalThreadContextWrapper {
private final ThreadContext threadContext;

private InternalThreadContextWrapper(final ThreadContext threadContext) {
this.threadContext = threadContext;
}

public static InternalThreadContextWrapper from(ThreadContext threadContext) {
return new InternalThreadContextWrapper(threadContext);
}

public void markAsSystemContext() {
Objects.requireNonNull(threadContext, "threadContext cannot be null");
threadContext.markAsSystemContext();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ boolean isDefaultContext() {
* Marks this thread context as an internal system context. This signals that actions in this context are issued
* by the system itself rather than by a user action.
*/
public void markAsSystemContext() {
void markAsSystemContext() {
threadLocal.set(threadLocal.get().setSystemContext(propagators));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public GlobalCheckpointSyncAction(
public void updateGlobalCheckpointForShard(final ShardId shardId) {
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
execute(new Request(shardId), ActionListener.wrap(r -> {}, e -> {
if (ExceptionsHelper.unwrap(e, AlreadyClosedException.class, IndexShardClosedException.class) == null) {
logger.info(new ParameterizedMessage("{} global checkpoint sync failed", shardId), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ final void backgroundSync(ShardId shardId, String primaryAllocationId, long prim
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
final Request request = new Request(shardId, retentionLeases);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "retention_lease_background_sync", request);
transportService.sendChildRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ final void sync(
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
final Request request = new Request(shardId, retentionLeases);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "retention_lease_sync", request);
transportService.sendChildRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ final void publish(IndexShard indexShard, ReplicationCheckpoint checkpoint) {
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we have to execute under the system context so that if security is enabled the sync is authorized
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
PublishCheckpointRequest request = new PublishCheckpointRequest(checkpoint);
final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "segrep_publish_checkpoint", request);
final ReplicationTimer timer = new ReplicationTimer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -71,6 +72,7 @@ final class RemoteClusterConnection implements Closeable {
private final RemoteConnectionStrategy connectionStrategy;
private final String clusterAlias;
private final ThreadPool threadPool;
private final InternalThreadContextWrapper tcWrapper;
private volatile boolean skipUnavailable;
private final TimeValue initialConnectionTimeout;

Expand All @@ -91,6 +93,7 @@ final class RemoteClusterConnection implements Closeable {
this.skipUnavailable = RemoteClusterService.REMOTE_CLUSTER_SKIP_UNAVAILABLE.getConcreteSettingForNamespace(clusterAlias)
.get(settings);
this.threadPool = transportService.threadPool;
this.tcWrapper = InternalThreadContextWrapper.from(transportService.threadPool.getThreadContext());
initialConnectionTimeout = RemoteClusterService.REMOTE_INITIAL_CONNECTION_TIMEOUT_SETTING.get(settings);
}

Expand Down Expand Up @@ -136,7 +139,7 @@ void collectNodes(ActionListener<Function<String, DiscoveryNode>> listener) {
new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any existing context information
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();

final ClusterStateRequest request = new ClusterStateRequest();
request.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -160,6 +161,7 @@ public Writeable.Reader<RemoteConnectionInfo.ModeInfo> getReader() {

protected final TransportService transportService;
protected final RemoteConnectionManager connectionManager;
protected final InternalThreadContextWrapper tcWrapper;
protected final String clusterAlias;

RemoteConnectionStrategy(
Expand All @@ -170,6 +172,7 @@ public Writeable.Reader<RemoteConnectionInfo.ModeInfo> getReader() {
) {
this.clusterAlias = clusterAlias;
this.transportService = transportService;
this.tcWrapper = InternalThreadContextWrapper.from(transportService.threadPool.getThreadContext());
this.connectionManager = connectionManager;
this.maxPendingConnectionListeners = REMOTE_MAX_PENDING_CONNECTION_LISTENERS.get(settings);
connectionManager.addListener(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ private void collectRemoteNodes(Iterator<Supplier<DiscoveryNode>> seedNodes, Act
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
// we stash any context here since this is an internal execution and should not leak any
// existing context information.
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
transportService.sendRequest(
connection,
ClusterStateAction.NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
Expand Down Expand Up @@ -224,8 +225,9 @@ public void testUpdateTemplates() {

service.upgradesInProgress.set(additionsCount + deletionsCount + 2); // +2 to skip tryFinishUpgrade
final ThreadContext threadContext = threadPool.getThreadContext();
final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
service.upgradeTemplates(additions, deletions);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContext.StoredContext;
import org.opensearch.telemetry.Telemetry;
Expand Down Expand Up @@ -256,11 +257,12 @@ public void run() {
public void testSpanNotPropagatedToChildSystemThreadContext() {
final Span span = tracer.startSpan(SpanCreationContext.internal().name("test"));

final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext);
try (SpanScope scope = tracer.withSpanInScope(span)) {
try (StoredContext ignored = threadContext.stashContext()) {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span));
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.InternalThreadContextWrapper;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor;
import org.opensearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -133,8 +134,9 @@ public void run() {
taskInProgress = true;
scheduledNextTask = false;
final ThreadContext threadContext = threadPool.getThreadContext();
final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext);
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.markAsSystemContext();
tcWrapper.markAsSystemContext();
task.run();
}
if (waitForPublish == false) {
Expand Down

0 comments on commit bc19108

Please sign in to comment.