Skip to content

Commit

Permalink
Use ThreadContextAccess
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Perkins <[email protected]>
  • Loading branch information
cwperks committed Jul 31, 2024
1 parent fccc486 commit c0324b4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.action.ActionType;
import org.opensearch.action.support.ContextPreservingActionListener;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;

Expand Down Expand Up @@ -65,7 +66,11 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
ActionListener<Response> listener
) {
final Supplier<ThreadContext.StoredContext> supplier = in().threadPool().getThreadContext().newRestorableContext(false);
try (ThreadContext.StoredContext ignore = in().threadPool().getThreadContext().stashWithOrigin(origin)) {
try (
ThreadContext.StoredContext ignore = ThreadContextAccess.doPrivileged(
() -> in().threadPool().getThreadContext().stashWithOrigin(origin)
)
) {
super.doExecute(action, request, new ContextPreservingActionListener<>(supplier, listener));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@
import org.opensearch.common.action.ActionFuture;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextAccess;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.bytes.BytesReference;
Expand Down Expand Up @@ -2148,7 +2149,9 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
ActionListener<Response> listener
) {
ThreadContext threadContext = threadPool().getThreadContext();
try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(headers)) {
try (
ThreadContext.StoredContext ctx = ThreadContextAccess.doPrivileged(() -> threadContext.stashAndMergeHeaders(headers))

Check warning on line 2153 in server/src/main/java/org/opensearch/client/support/AbstractClient.java

View check run for this annotation

Codecov / codecov/patch

server/src/main/java/org/opensearch/client/support/AbstractClient.java#L2153

Added line #L2153 was not covered by tests
) {
super.doExecute(action, request, listener);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ public void testStashWithOrigin() {
}

assertNull(threadContext.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME));
try (ThreadContext.StoredContext storedContext = threadContext.stashWithOrigin(origin)) {
try (ThreadContext.StoredContext storedContext = ThreadContextAccess.doPrivileged(() -> threadContext.stashWithOrigin(origin))) {
assertEquals(origin, threadContext.getTransient(ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("bar"));
Expand All @@ -231,7 +231,7 @@ public void testStashAndMerge() {
HashMap<String, String> toMerge = new HashMap<>();
toMerge.put("foo", "baz");
toMerge.put("simon", "says");
try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) {
try (ThreadContext.StoredContext ctx = ThreadContextAccess.doPrivileged(() -> threadContext.stashAndMergeHeaders(toMerge))) {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("says", threadContext.getHeader("simon"));
assertNull(threadContext.getTransient("ctx.foo"));
Expand Down Expand Up @@ -493,7 +493,13 @@ public void testStashAndMergeWithModifiedDefaults() {
ThreadContext threadContext = new ThreadContext(build);
HashMap<String, String> toMerge = new HashMap<>();
toMerge.put("default", "2");
try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) {
ThreadContext finalThreadContext1 = threadContext;
HashMap<String, String> finalToMerge1 = toMerge;
try (
ThreadContext.StoredContext ctx = ThreadContextAccess.doPrivileged(
() -> finalThreadContext1.stashAndMergeHeaders(finalToMerge1)
)
) {
assertEquals("2", threadContext.getHeader("default"));
}

Expand All @@ -502,7 +508,13 @@ public void testStashAndMergeWithModifiedDefaults() {
threadContext.putHeader("default", "4");
toMerge = new HashMap<>();
toMerge.put("default", "2");
try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) {
ThreadContext finalThreadContext2 = threadContext;
HashMap<String, String> finalToMerge2 = toMerge;
try (
ThreadContext.StoredContext ctx = ThreadContextAccess.doPrivileged(
() -> finalThreadContext2.stashAndMergeHeaders(finalToMerge2)
)
) {
assertEquals("4", threadContext.getHeader("default"));
}
}
Expand Down

0 comments on commit c0324b4

Please sign in to comment.