diff --git a/CHANGELOG.md b/CHANGELOG.md index e88a084f7d7f6..b6bf68335ae8b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Bump `org.apache.commons:commons-lang3` from 3.14.0 to 3.15.0 ([#14861](https://github.com/opensearch-project/OpenSearch/pull/14861)) ### Changed +- Make ThreadContext.markAsSystemContext package-private ([#14988](https://github.com/opensearch-project/OpenSearch/pull/14988)) ### Deprecated diff --git a/server/src/main/java/org/opensearch/action/support/replication/TransportReplicationAction.java b/server/src/main/java/org/opensearch/action/support/replication/TransportReplicationAction.java index 49a96603f6802..cbf1633726b86 100644 --- a/server/src/main/java/org/opensearch/action/support/replication/TransportReplicationAction.java +++ b/server/src/main/java/org/opensearch/action/support/replication/TransportReplicationAction.java @@ -63,6 +63,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.common.util.concurrent.InternalThreadContextWrapper; import org.opensearch.core.Assertions; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; @@ -142,6 +143,7 @@ public abstract class TransportReplicationAction< public static final String REPLICA_ACTION_SUFFIX = "[r]"; protected final ThreadPool threadPool; + protected volatile InternalThreadContextWrapper tcWrapper; protected final TransportService transportService; protected final ClusterService clusterService; protected final ShardStateAction shardStateAction; @@ -239,6 +241,9 @@ protected TransportReplicationAction( ) { super(actionName, actionFilters, transportService.getTaskManager()); this.threadPool = threadPool; + if (threadPool != null) { + this.tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext()); + } this.transportService = transportService; this.clusterService = clusterService; this.indicesService = indicesService; diff --git a/server/src/main/java/org/opensearch/cluster/service/ClusterApplierService.java b/server/src/main/java/org/opensearch/cluster/service/ClusterApplierService.java index 6234427445754..4a2c6e9424410 100644 --- a/server/src/main/java/org/opensearch/cluster/service/ClusterApplierService.java +++ b/server/src/main/java/org/opensearch/cluster/service/ClusterApplierService.java @@ -58,6 +58,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.InternalThreadContextWrapper; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor; import org.opensearch.common.util.concurrent.ThreadContext; @@ -104,6 +105,7 @@ public class ClusterApplierService extends AbstractLifecycleComponent implements private final ClusterSettings clusterSettings; protected final ThreadPool threadPool; + protected volatile InternalThreadContextWrapper tcWrapper; private volatile TimeValue slowTaskLoggingThreshold; @@ -173,6 +175,7 @@ protected synchronized void doStart() { Objects.requireNonNull(nodeConnectionsService, "please set the node connection service before starting"); Objects.requireNonNull(state.get(), "please set initial state before starting"); threadPoolExecutor = createThreadPoolExecutor(); + tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext()); } protected PrioritizedOpenSearchThreadPoolExecutor createThreadPoolExecutor() { @@ -396,7 +399,7 @@ private void submitStateUpdateTask( final ThreadContext threadContext = threadPool.getThreadContext(); final Supplier supplier = threadContext.newRestorableContext(true); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); final UpdateTask updateTask = new UpdateTask( config.priority(), source, diff --git a/server/src/main/java/org/opensearch/cluster/service/MasterService.java b/server/src/main/java/org/opensearch/cluster/service/MasterService.java index 4ab8255df7658..ca16036086d1b 100644 --- a/server/src/main/java/org/opensearch/cluster/service/MasterService.java +++ b/server/src/main/java/org/opensearch/cluster/service/MasterService.java @@ -63,6 +63,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.CountDown; import org.opensearch.common.util.concurrent.FutureUtils; +import org.opensearch.common.util.concurrent.InternalThreadContextWrapper; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor; import org.opensearch.common.util.concurrent.ThreadContext; @@ -134,6 +135,7 @@ public class MasterService extends AbstractLifecycleComponent { private volatile TimeValue slowTaskLoggingThreshold; protected final ThreadPool threadPool; + protected volatile InternalThreadContextWrapper tcWrapper; private volatile PrioritizedOpenSearchThreadPoolExecutor threadPoolExecutor; private volatile Batcher taskBatcher; @@ -190,6 +192,7 @@ protected synchronized void doStart() { Objects.requireNonNull(clusterStateSupplier, "please set a cluster state supplier before starting"); threadPoolExecutor = createThreadPoolExecutor(); taskBatcher = new Batcher(logger, threadPoolExecutor, clusterManagerTaskThrottler); + tcWrapper = InternalThreadContextWrapper.from(threadPool.getThreadContext()); } protected PrioritizedOpenSearchThreadPoolExecutor createThreadPoolExecutor() { @@ -1022,7 +1025,7 @@ public void submitStateUpdateTasks( final ThreadContext threadContext = threadPool.getThreadContext(); final Supplier supplier = threadContext.newRestorableContext(true); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); List safeTasks = tasks.entrySet() .stream() diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/InternalThreadContextWrapper.java b/server/src/main/java/org/opensearch/common/util/concurrent/InternalThreadContextWrapper.java new file mode 100644 index 0000000000000..a244e570149ba --- /dev/null +++ b/server/src/main/java/org/opensearch/common/util/concurrent/InternalThreadContextWrapper.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util.concurrent; + +import java.util.Objects; + +/** + * Wrapper around the ThreadContext to expose methods to the core repo without + * exposing them to plugins + * + * @opensearch.internal + */ +public class InternalThreadContextWrapper { + private final ThreadContext threadContext; + + private InternalThreadContextWrapper(final ThreadContext threadContext) { + this.threadContext = threadContext; + } + + public static InternalThreadContextWrapper from(ThreadContext threadContext) { + return new InternalThreadContextWrapper(threadContext); + } + + public void markAsSystemContext() { + Objects.requireNonNull(threadContext, "threadContext cannot be null"); + threadContext.markAsSystemContext(); + } +} diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java index 906a27e9f398c..7b276b7d97167 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java @@ -555,7 +555,7 @@ boolean isDefaultContext() { * Marks this thread context as an internal system context. This signals that actions in this context are issued * by the system itself rather than by a user action. */ - public void markAsSystemContext() { + void markAsSystemContext() { threadLocal.set(threadLocal.get().setSystemContext(propagators)); } diff --git a/server/src/main/java/org/opensearch/index/seqno/GlobalCheckpointSyncAction.java b/server/src/main/java/org/opensearch/index/seqno/GlobalCheckpointSyncAction.java index c6a1f5f27a875..ab26f63210e62 100644 --- a/server/src/main/java/org/opensearch/index/seqno/GlobalCheckpointSyncAction.java +++ b/server/src/main/java/org/opensearch/index/seqno/GlobalCheckpointSyncAction.java @@ -98,7 +98,7 @@ public GlobalCheckpointSyncAction( public void updateGlobalCheckpointForShard(final ShardId shardId) { final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); execute(new Request(shardId), ActionListener.wrap(r -> {}, e -> { if (ExceptionsHelper.unwrap(e, AlreadyClosedException.class, IndexShardClosedException.class) == null) { logger.info(new ParameterizedMessage("{} global checkpoint sync failed", shardId), e); diff --git a/server/src/main/java/org/opensearch/index/seqno/RetentionLeaseBackgroundSyncAction.java b/server/src/main/java/org/opensearch/index/seqno/RetentionLeaseBackgroundSyncAction.java index 5fa0a1a6459e7..075ecb008babe 100644 --- a/server/src/main/java/org/opensearch/index/seqno/RetentionLeaseBackgroundSyncAction.java +++ b/server/src/main/java/org/opensearch/index/seqno/RetentionLeaseBackgroundSyncAction.java @@ -122,7 +122,7 @@ final void backgroundSync(ShardId shardId, String primaryAllocationId, long prim final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { // we have to execute under the system context so that if security is enabled the sync is authorized - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); final Request request = new Request(shardId, retentionLeases); final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "retention_lease_background_sync", request); transportService.sendChildRequest( diff --git a/server/src/main/java/org/opensearch/index/seqno/RetentionLeaseSyncAction.java b/server/src/main/java/org/opensearch/index/seqno/RetentionLeaseSyncAction.java index ca3c7e1d49700..fc75426ef758a 100644 --- a/server/src/main/java/org/opensearch/index/seqno/RetentionLeaseSyncAction.java +++ b/server/src/main/java/org/opensearch/index/seqno/RetentionLeaseSyncAction.java @@ -137,7 +137,7 @@ final void sync( final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { // we have to execute under the system context so that if security is enabled the sync is authorized - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); final Request request = new Request(shardId, retentionLeases); final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "retention_lease_sync", request); transportService.sendChildRequest( diff --git a/server/src/main/java/org/opensearch/indices/replication/checkpoint/PublishCheckpointAction.java b/server/src/main/java/org/opensearch/indices/replication/checkpoint/PublishCheckpointAction.java index 8f39aa194b06c..8220ff4426789 100644 --- a/server/src/main/java/org/opensearch/indices/replication/checkpoint/PublishCheckpointAction.java +++ b/server/src/main/java/org/opensearch/indices/replication/checkpoint/PublishCheckpointAction.java @@ -113,7 +113,7 @@ final void publish(IndexShard indexShard, ReplicationCheckpoint checkpoint) { final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { // we have to execute under the system context so that if security is enabled the sync is authorized - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); PublishCheckpointRequest request = new PublishCheckpointRequest(checkpoint); final ReplicationTask task = (ReplicationTask) taskManager.register("transport", "segrep_publish_checkpoint", request); final ReplicationTimer timer = new ReplicationTimer(); diff --git a/server/src/main/java/org/opensearch/transport/RemoteClusterConnection.java b/server/src/main/java/org/opensearch/transport/RemoteClusterConnection.java index 8a5f6dfffb036..3a47822ae0d78 100644 --- a/server/src/main/java/org/opensearch/transport/RemoteClusterConnection.java +++ b/server/src/main/java/org/opensearch/transport/RemoteClusterConnection.java @@ -39,6 +39,7 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.InternalThreadContextWrapper; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.util.io.IOUtils; import org.opensearch.core.action.ActionListener; @@ -71,6 +72,7 @@ final class RemoteClusterConnection implements Closeable { private final RemoteConnectionStrategy connectionStrategy; private final String clusterAlias; private final ThreadPool threadPool; + private final InternalThreadContextWrapper tcWrapper; private volatile boolean skipUnavailable; private final TimeValue initialConnectionTimeout; @@ -91,6 +93,7 @@ final class RemoteClusterConnection implements Closeable { this.skipUnavailable = RemoteClusterService.REMOTE_CLUSTER_SKIP_UNAVAILABLE.getConcreteSettingForNamespace(clusterAlias) .get(settings); this.threadPool = transportService.threadPool; + this.tcWrapper = InternalThreadContextWrapper.from(transportService.threadPool.getThreadContext()); initialConnectionTimeout = RemoteClusterService.REMOTE_INITIAL_CONNECTION_TIMEOUT_SETTING.get(settings); } @@ -136,7 +139,7 @@ void collectNodes(ActionListener> listener) { new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { // we stash any context here since this is an internal execution and should not leak any existing context information - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); final ClusterStateRequest request = new ClusterStateRequest(); request.clear(); diff --git a/server/src/main/java/org/opensearch/transport/RemoteConnectionStrategy.java b/server/src/main/java/org/opensearch/transport/RemoteConnectionStrategy.java index f2c159d1380e8..07428ad29282a 100644 --- a/server/src/main/java/org/opensearch/transport/RemoteConnectionStrategy.java +++ b/server/src/main/java/org/opensearch/transport/RemoteConnectionStrategy.java @@ -43,6 +43,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.common.util.concurrent.InternalThreadContextWrapper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.Writeable; @@ -160,6 +161,7 @@ public Writeable.Reader getReader() { protected final TransportService transportService; protected final RemoteConnectionManager connectionManager; + protected final InternalThreadContextWrapper tcWrapper; protected final String clusterAlias; RemoteConnectionStrategy( @@ -170,6 +172,7 @@ public Writeable.Reader getReader() { ) { this.clusterAlias = clusterAlias; this.transportService = transportService; + this.tcWrapper = InternalThreadContextWrapper.from(transportService.getThreadPool().getThreadContext()); this.connectionManager = connectionManager; this.maxPendingConnectionListeners = REMOTE_MAX_PENDING_CONNECTION_LISTENERS.get(settings); connectionManager.addListener(this); diff --git a/server/src/main/java/org/opensearch/transport/SniffConnectionStrategy.java b/server/src/main/java/org/opensearch/transport/SniffConnectionStrategy.java index 07ba96b135189..257f59b6c26b5 100644 --- a/server/src/main/java/org/opensearch/transport/SniffConnectionStrategy.java +++ b/server/src/main/java/org/opensearch/transport/SniffConnectionStrategy.java @@ -349,7 +349,7 @@ private void collectRemoteNodes(Iterator> seedNodes, Act try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { // we stash any context here since this is an internal execution and should not leak any // existing context information. - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); transportService.sendRequest( connection, ClusterStateAction.NAME, diff --git a/server/src/test/java/org/opensearch/cluster/metadata/TemplateUpgradeServiceTests.java b/server/src/test/java/org/opensearch/cluster/metadata/TemplateUpgradeServiceTests.java index 36d984b7eb99b..4b0cf0aa65268 100644 --- a/server/src/test/java/org/opensearch/cluster/metadata/TemplateUpgradeServiceTests.java +++ b/server/src/test/java/org/opensearch/cluster/metadata/TemplateUpgradeServiceTests.java @@ -46,6 +46,7 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.collect.Tuple; +import org.opensearch.common.util.concurrent.InternalThreadContextWrapper; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; @@ -224,8 +225,9 @@ public void testUpdateTemplates() { service.upgradesInProgress.set(additionsCount + deletionsCount + 2); // +2 to skip tryFinishUpgrade final ThreadContext threadContext = threadPool.getThreadContext(); + final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); service.upgradeTemplates(additions, deletions); } diff --git a/server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java b/server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java index bf11bcaf39a96..a5da137cc4690 100644 --- a/server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java +++ b/server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java @@ -10,6 +10,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.InternalThreadContextWrapper; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.util.concurrent.ThreadContext.StoredContext; import org.opensearch.telemetry.Telemetry; @@ -256,11 +257,12 @@ public void run() { public void testSpanNotPropagatedToChildSystemThreadContext() { final Span span = tracer.startSpan(SpanCreationContext.internal().name("test")); + final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext); try (SpanScope scope = tracer.withSpanInScope(span)) { try (StoredContext ignored = threadContext.stashContext()) { assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue()))); assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span)); - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue())); } } diff --git a/server/src/test/java/org/opensearch/transport/RemoteConnectionStrategyTests.java b/server/src/test/java/org/opensearch/transport/RemoteConnectionStrategyTests.java index e2acbcff3db16..7d9f5dd2a8ead 100644 --- a/server/src/test/java/org/opensearch/transport/RemoteConnectionStrategyTests.java +++ b/server/src/test/java/org/opensearch/transport/RemoteConnectionStrategyTests.java @@ -36,17 +36,21 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class RemoteConnectionStrategyTests extends OpenSearchTestCase { public void testStrategyChangeMeansThatStrategyMustBeRebuilt() { ClusterConnectionManager connectionManager = new ClusterConnectionManager(Settings.EMPTY, mock(Transport.class)); RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager); + TransportService mockTransportService = mock(TransportService.class); + when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class)); FakeConnectionStrategy first = new FakeConnectionStrategy( "cluster-alias", - mock(TransportService.class), + mockTransportService, remoteConnectionManager, RemoteConnectionStrategy.ConnectionStrategy.PROXY ); @@ -60,9 +64,11 @@ public void testStrategyChangeMeansThatStrategyMustBeRebuilt() { public void testSameStrategyChangeMeansThatStrategyDoesNotNeedToBeRebuilt() { ClusterConnectionManager connectionManager = new ClusterConnectionManager(Settings.EMPTY, mock(Transport.class)); RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager); + TransportService mockTransportService = mock(TransportService.class); + when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class)); FakeConnectionStrategy first = new FakeConnectionStrategy( "cluster-alias", - mock(TransportService.class), + mockTransportService, remoteConnectionManager, RemoteConnectionStrategy.ConnectionStrategy.PROXY ); @@ -78,9 +84,11 @@ public void testChangeInConnectionProfileMeansTheStrategyMustBeRebuilt() { assertEquals(TimeValue.MINUS_ONE, connectionManager.getConnectionProfile().getPingInterval()); assertEquals(false, connectionManager.getConnectionProfile().getCompressionEnabled()); RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager("cluster-alias", connectionManager); + TransportService mockTransportService = mock(TransportService.class); + when(mockTransportService.getThreadPool()).thenReturn(mock(ThreadPool.class)); FakeConnectionStrategy first = new FakeConnectionStrategy( "cluster-alias", - mock(TransportService.class), + mockTransportService, remoteConnectionManager, RemoteConnectionStrategy.ConnectionStrategy.PROXY ); diff --git a/test/framework/src/main/java/org/opensearch/cluster/service/FakeThreadPoolClusterManagerService.java b/test/framework/src/main/java/org/opensearch/cluster/service/FakeThreadPoolClusterManagerService.java index 53ef595c7931e..91585bb4bee1e 100644 --- a/test/framework/src/main/java/org/opensearch/cluster/service/FakeThreadPoolClusterManagerService.java +++ b/test/framework/src/main/java/org/opensearch/cluster/service/FakeThreadPoolClusterManagerService.java @@ -41,6 +41,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.InternalThreadContextWrapper; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor; import org.opensearch.common.util.concurrent.ThreadContext; @@ -133,8 +134,9 @@ public void run() { taskInProgress = true; scheduledNextTask = false; final ThreadContext threadContext = threadPool.getThreadContext(); + final InternalThreadContextWrapper tcWrapper = InternalThreadContextWrapper.from(threadContext); try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { - threadContext.markAsSystemContext(); + tcWrapper.markAsSystemContext(); task.run(); } if (waitForPublish == false) {