diff --git a/spring-graphql/src/main/java/org/springframework/graphql/execution/CompositeThreadLocalAccessor.java b/spring-graphql/src/main/java/org/springframework/graphql/execution/CompositeThreadLocalAccessor.java index 7ae4d3ba7..55ddc3bba 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/execution/CompositeThreadLocalAccessor.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/execution/CompositeThreadLocalAccessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,21 +16,24 @@ package org.springframework.graphql.execution; +import java.util.ArrayList; import java.util.List; import java.util.Map; + /** * Default implementation of a composite accessor that is returned from * {@link ThreadLocalAccessor#composite(List)}. * * @author Rossen Stoyanchev + * @since 1.0.0 */ class CompositeThreadLocalAccessor implements ThreadLocalAccessor { private final List accessors; CompositeThreadLocalAccessor(List accessors) { - this.accessors = accessors; + this.accessors = new ArrayList<>(accessors); } @Override diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/DefaultWebGraphQlHandlerBuilder.java b/spring-graphql/src/main/java/org/springframework/graphql/server/DefaultWebGraphQlHandlerBuilder.java index e1fa7353f..774ddae9a 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/DefaultWebGraphQlHandlerBuilder.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/DefaultWebGraphQlHandlerBuilder.java @@ -96,26 +96,34 @@ public WebGraphQlHandler build() { .map(interceptor -> interceptor.apply(endOfChain)) .orElse(endOfChain); + ThreadLocalAccessor accessor = (CollectionUtils.isEmpty(this.accessors) ? null : + ThreadLocalAccessor.composite(this.accessors)); + return new WebGraphQlHandler() { + @Override + public WebSocketGraphQlInterceptor getWebSocketInterceptor() { + return (webSocketInterceptor != null ? + webSocketInterceptor : new WebSocketGraphQlInterceptor() {}); + } + + @Nullable + @Override + public ThreadLocalAccessor getThreadLocalAccessor() { + return accessor; + } + @Override public Mono handleRequest(WebGraphQlRequest request) { return executionChain.next(request) .contextWrite(context -> { - if (!CollectionUtils.isEmpty(accessors)) { - ThreadLocalAccessor accessor = ThreadLocalAccessor.composite(accessors); + if (accessor != null) { return ReactorContextManager.extractThreadLocalValues(accessor, context); } return context; }); } - @Override - public WebSocketGraphQlInterceptor webSocketInterceptor() { - return (webSocketInterceptor != null ? - webSocketInterceptor : new WebSocketGraphQlInterceptor() {}); - } - }; } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/WebGraphQlHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/WebGraphQlHandler.java index 392d2c389..02471343e 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/WebGraphQlHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/WebGraphQlHandler.java @@ -22,6 +22,7 @@ import org.springframework.graphql.ExecutionGraphQlService; import org.springframework.graphql.execution.ThreadLocalAccessor; +import org.springframework.lang.Nullable; /** @@ -33,6 +34,19 @@ */ public interface WebGraphQlHandler { + /** + * Return the single interceptor of type + * {@link WebSocketGraphQlInterceptor} among all the configured + * interceptors. + */ + WebSocketGraphQlInterceptor getWebSocketInterceptor(); + + /** + * Return the composite {@link ThreadLocalAccessor} that the handler is + * configured with. + */ + @Nullable + ThreadLocalAccessor getThreadLocalAccessor(); /** * Execute the given request and return the response. @@ -41,13 +55,6 @@ public interface WebGraphQlHandler { */ Mono handleRequest(WebGraphQlRequest request); - /** - * Return the single interceptor of type - * {@link WebSocketGraphQlInterceptor} among all the configured - * interceptors. - */ - WebSocketGraphQlInterceptor webSocketInterceptor(); - /** * Provides access to a builder to create a {@link WebGraphQlHandler} instance. diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java index e4b7a5fa2..ab2cf86f5 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlWebSocketHandler.java @@ -83,7 +83,7 @@ public GraphQlWebSocketHandler( Assert.notNull(graphQlHandler, "WebGraphQlHandler is required"); this.graphQlHandler = graphQlHandler; - this.webSocketInterceptor = this.graphQlHandler.webSocketInterceptor(); + this.webSocketInterceptor = this.graphQlHandler.getWebSocketInterceptor(); this.codecDelegate = new CodecDelegate(codecConfigurer); this.initTimeoutDuration = connectionInitTimeout; } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java index 87b0c381c..f4f397f20 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandler.java @@ -18,6 +18,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.Closeable; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -25,6 +26,7 @@ import java.time.Duration; import java.util.Arrays; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -43,6 +45,7 @@ import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; +import org.springframework.graphql.execution.ThreadLocalAccessor; import org.springframework.graphql.server.WebGraphQlHandler; import org.springframework.graphql.server.WebGraphQlRequest; import org.springframework.graphql.server.WebGraphQlResponse; @@ -53,15 +56,21 @@ import org.springframework.http.HttpOutputMessage; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServerHttpRequest; +import org.springframework.http.server.ServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.SubProtocolCapable; import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator; import org.springframework.web.socket.handler.TextWebSocketHandler; +import org.springframework.web.socket.server.HandshakeHandler; +import org.springframework.web.socket.server.HandshakeInterceptor; +import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; /** * WebSocketHandler for GraphQL based on @@ -81,7 +90,9 @@ public class GraphQlWebSocketHandler extends TextWebSocketHandler implements Sub private final WebGraphQlHandler graphQlHandler; - private final WebSocketGraphQlInterceptor webSocketInterceptor; + private final ContextHandshakeInterceptor contextHandshakeInterceptor; + + private final WebSocketGraphQlInterceptor webSocketGraphQlInterceptor; private final Duration initTimeoutDuration; @@ -103,7 +114,8 @@ public GraphQlWebSocketHandler( Assert.notNull(converter, "HttpMessageConverter for JSON is required"); this.graphQlHandler = graphQlHandler; - this.webSocketInterceptor = this.graphQlHandler.webSocketInterceptor(); + this.contextHandshakeInterceptor = new ContextHandshakeInterceptor(graphQlHandler.getThreadLocalAccessor()); + this.webSocketGraphQlInterceptor = this.graphQlHandler.getWebSocketInterceptor(); this.initTimeoutDuration = connectionInitTimeout; this.converter = converter; } @@ -113,6 +125,18 @@ public List getSubProtocols() { return SUB_PROTOCOL_LIST; } + /** + * Return a {@link WebSocketHttpRequestHandler} that uses this instance as + * its {@link WebGraphQlHandler} and adds a {@link HandshakeInterceptor} to + * propagate context. + */ + public WebSocketHttpRequestHandler asWebSocketHttpRequestHandler(HandshakeHandler handshakeHandler) { + WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(this, handshakeHandler); + handler.setHandshakeInterceptors(Collections.singletonList(this.contextHandshakeInterceptor)); + return handler; + } + + @Override public void afterConnectionEstablished(WebSocketSession session) { if ("graphql-ws".equalsIgnoreCase(session.getAcceptedProtocol())) { @@ -137,8 +161,15 @@ public void afterConnectionEstablished(WebSocketSession session) { } + @SuppressWarnings({"unused", "try"}) @Override protected void handleTextMessage(WebSocketSession session, TextMessage webSocketMessage) throws Exception { + try (Closeable closeable = this.contextHandshakeInterceptor.restoreThreadLocalValue(session)) { + handleInternal(session, webSocketMessage); + } + } + + private void handleInternal(WebSocketSession session, TextMessage webSocketMessage) throws IOException { GraphQlWebSocketMessage message = decode(webSocketMessage); String id = message.getId(); Map payload = message.getPayload(); @@ -174,7 +205,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket if (subscription != null) { subscription.cancel(); } - this.webSocketInterceptor.handleCancelledSubscription(session.getId(), id) + this.webSocketGraphQlInterceptor.handleCancelledSubscription(session.getId(), id) .block(Duration.ofSeconds(10)); } return; @@ -183,7 +214,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket GraphQlStatus.closeSession(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS); return; } - this.webSocketInterceptor.handleConnectionInitialization(session.getId(), payload) + this.webSocketGraphQlInterceptor.handleConnectionInitialization(session.getId(), payload) .defaultIfEmpty(Collections.emptyMap()) .publishOn(sessionState.getScheduler()) // Serial blocking send via single thread .doOnNext(ackPayload -> { @@ -285,7 +316,7 @@ public void afterConnectionClosed(WebSocketSession session, CloseStatus closeSta info.dispose(); Map connectionInitPayload = info.getConnectionInitPayload(); if (connectionInitPayload != null) { - this.webSocketInterceptor.handleConnectionClosed(id, closeStatus.getCode(), connectionInitPayload); + this.webSocketGraphQlInterceptor.handleConnectionClosed(id, closeStatus.getCode(), connectionInitPayload); } } } @@ -296,6 +327,57 @@ public boolean supportsPartialMessages() { } + /** + * {@code HandshakeInterceptor} that propagates ThreadLocal context through + * the attributes map in {@code WebSocketSession}. + */ + private static class ContextHandshakeInterceptor implements HandshakeInterceptor { + + private static final String SAVED_CONTEXT_KEY = ContextHandshakeInterceptor.class.getName(); + + @Nullable + private final ThreadLocalAccessor accessor; + + ContextHandshakeInterceptor(@Nullable ThreadLocalAccessor accessor) { + this.accessor = accessor; + } + + @Override + public boolean beforeHandshake( + ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, + Map attributes) { + + if (this.accessor != null) { + Map valuesMap = new LinkedHashMap<>(); + this.accessor.extractValues(valuesMap); + attributes.put(SAVED_CONTEXT_KEY, valuesMap); + } + return true; + } + + @Override + public void afterHandshake( + ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, + @Nullable Exception exception) { + } + + @SuppressWarnings("unchecked") + public Closeable restoreThreadLocalValue(WebSocketSession session) { + if (this.accessor != null) { + Map valuesMap = (Map) session.getAttributes().get(SAVED_CONTEXT_KEY); + // Uncomment when Boot is updated to use HandshakeInterceptor + // Assert.state(valuesMap != null, "No context"); + if (valuesMap != null) { + this.accessor.restoreValues(valuesMap); + return () -> this.accessor.resetValues(valuesMap); + } + } + return () -> {}; + } + + } + + private static class GraphQlStatus { private static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message"); diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/WebSocketHandlerTestSupport.java b/spring-graphql/src/test/java/org/springframework/graphql/server/WebSocketHandlerTestSupport.java index f3324f263..b84078eeb 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/WebSocketHandlerTestSupport.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/WebSocketHandlerTestSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ import org.springframework.graphql.BookSource; import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.execution.ThreadLocalAccessor; +import org.springframework.lang.Nullable; public abstract class WebSocketHandlerTestSupport { @@ -65,6 +67,12 @@ public abstract class WebSocketHandlerTestSupport { protected WebGraphQlHandler initHandler(WebGraphQlInterceptor... interceptors) { + return initHandler(null, interceptors); + } + + protected WebGraphQlHandler initHandler( + @Nullable ThreadLocalAccessor accessor, WebGraphQlInterceptor... interceptors) { + return GraphQlSetup.schemaResource(BookSource.schema) .queryFetcher("bookById", environment -> { Long id = Long.parseLong(environment.getArgument("id")); @@ -75,6 +83,7 @@ protected WebGraphQlHandler initHandler(WebGraphQlInterceptor... interceptors) { return Flux.fromIterable(BookSource.books()) .filter((book) -> book.getAuthor().getFullName().contains(author)); }) + .threadLocalAccessor(accessor) .interceptor(interceptors) .toWebGraphQlHandler(); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java index 00f97c31d..7afad4153 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlWebSocketHandlerTests.java @@ -35,18 +35,21 @@ import reactor.test.StepVerifier; import org.springframework.graphql.GraphQlSetup; +import org.springframework.graphql.TestThreadLocalAccessor; +import org.springframework.graphql.execution.ThreadLocalAccessor; +import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor; import org.springframework.graphql.server.WebGraphQlHandler; import org.springframework.graphql.server.WebGraphQlInterceptor; import org.springframework.graphql.server.WebSocketGraphQlInterceptor; +import org.springframework.graphql.server.WebSocketHandlerTestSupport; import org.springframework.graphql.server.support.GraphQlWebSocketMessage; import org.springframework.graphql.server.support.GraphQlWebSocketMessageType; -import org.springframework.graphql.server.ConsumeOneAndNeverCompleteInterceptor; -import org.springframework.graphql.server.WebSocketHandlerTestSupport; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpInputMessage; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter; +import org.springframework.lang.Nullable; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; @@ -371,6 +374,45 @@ void errorMessagePayloadIsCorrectArray() throws Exception { .verify(TIMEOUT); } + @Test + void contextPropagation() throws Exception { + ThreadLocal threadLocal = new ThreadLocal<>(); + threadLocal.set("foo"); + + WebGraphQlInterceptor threadLocalInterceptor = (request, chain) -> { + assertThat(threadLocal.get()).isEqualTo("foo"); + return chain.next(request); + }; + + GraphQlWebSocketHandler handler = initWebSocketHandler( + new TestThreadLocalAccessor<>(threadLocal), threadLocalInterceptor); + + // Use HandshakeInterceptor to capture ThreadLocal context + handler.asWebSocketHttpRequestHandler((request, response, wsHandler, attributes) -> false) + .getHandshakeInterceptors().get(0) + .beforeHandshake(null, null, null, this.session.getAttributes()); + + // Context should propagate, if message is handled on different thread + Thread thread = new Thread(() -> { + try { + handle(handler, + new TextMessage("{\"type\":\"connection_init\"}"), + new TextMessage(BOOK_QUERY)); + } + catch (Exception ex) { + throw new IllegalStateException(ex); + } + }); + thread.start(); + + StepVerifier.create(this.session.getOutput()) + .expectNextCount(2) + .consumeNextWith((message) -> assertMessageType(message, GraphQlWebSocketMessageType.COMPLETE)) + .then(this.session::close) // Complete output Flux + .expectComplete() + .verify(TIMEOUT); + } + private void handle(GraphQlWebSocketHandler handler, TextMessage... textMessages) throws Exception { handler.afterConnectionEstablished(this.session); for (TextMessage message : textMessages) { @@ -379,8 +421,15 @@ private void handle(GraphQlWebSocketHandler handler, TextMessage... textMessages } private GraphQlWebSocketHandler initWebSocketHandler(WebGraphQlInterceptor... interceptors) { + return initWebSocketHandler(null, interceptors); + } + + private GraphQlWebSocketHandler initWebSocketHandler( + @Nullable ThreadLocalAccessor accessor, WebGraphQlInterceptor... interceptors) { + try { - return new GraphQlWebSocketHandler(initHandler(interceptors), converter, Duration.ofSeconds(60)); + return new GraphQlWebSocketHandler( + initHandler(accessor, interceptors), converter, Duration.ofSeconds(60)); } catch (Exception ex) { throw new IllegalStateException(ex); diff --git a/spring-graphql/src/testFixtures/java/org/springframework/graphql/GraphQlSetup.java b/spring-graphql/src/testFixtures/java/org/springframework/graphql/GraphQlSetup.java index c1724a647..df712b7bf 100644 --- a/spring-graphql/src/testFixtures/java/org/springframework/graphql/GraphQlSetup.java +++ b/spring-graphql/src/testFixtures/java/org/springframework/graphql/GraphQlSetup.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,10 +35,11 @@ import org.springframework.graphql.execution.GraphQlSource; import org.springframework.graphql.execution.RuntimeWiringConfigurer; import org.springframework.graphql.execution.ThreadLocalAccessor; -import org.springframework.graphql.server.webflux.GraphQlHttpHandler; import org.springframework.graphql.server.WebGraphQlHandler; -import org.springframework.graphql.server.WebGraphQlSetup; import org.springframework.graphql.server.WebGraphQlInterceptor; +import org.springframework.graphql.server.WebGraphQlSetup; +import org.springframework.graphql.server.webflux.GraphQlHttpHandler; +import org.springframework.lang.Nullable; /** * Workflow for GraphQL tests setup that starts with {@link GraphQlSource.Builder} @@ -141,8 +142,10 @@ public WebGraphQlSetup interceptor(WebGraphQlInterceptor... interceptors) { } @Override - public WebGraphQlSetup threadLocalAccessor(ThreadLocalAccessor... accessors) { - this.accessors.addAll(Arrays.asList(accessors)); + public WebGraphQlSetup threadLocalAccessor(@Nullable ThreadLocalAccessor accessor) { + if (accessor != null) { + this.accessors.add(accessor); + } return this; } diff --git a/spring-graphql/src/testFixtures/java/org/springframework/graphql/server/WebGraphQlSetup.java b/spring-graphql/src/testFixtures/java/org/springframework/graphql/server/WebGraphQlSetup.java index 627f574d9..5f1982528 100644 --- a/spring-graphql/src/testFixtures/java/org/springframework/graphql/server/WebGraphQlSetup.java +++ b/spring-graphql/src/testFixtures/java/org/springframework/graphql/server/WebGraphQlSetup.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ import org.springframework.graphql.execution.ThreadLocalAccessor; import org.springframework.graphql.server.webflux.GraphQlHttpHandler; +import org.springframework.lang.Nullable; /** * Workflow that results in the creation of a {@link WebGraphQlHandler} or @@ -28,7 +29,7 @@ public interface WebGraphQlSetup { WebGraphQlSetup interceptor(WebGraphQlInterceptor... interceptors); - WebGraphQlSetup threadLocalAccessor(ThreadLocalAccessor... accessors); + WebGraphQlSetup threadLocalAccessor(@Nullable ThreadLocalAccessor accessor); WebGraphQlHandler toWebGraphQlHandler();