Skip to content

Commit

Permalink
Propagate context in WebMvc GraphQlWebSocketHandler
Browse files Browse the repository at this point in the history
See gh-342
  • Loading branch information
rstoyanchev committed Apr 12, 2022
1 parent e2948e9 commit 85dc2e5
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,21 +16,24 @@

package org.springframework.graphql.execution;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;


/**
* Default implementation of a composite accessor that is returned from
* {@link ThreadLocalAccessor#composite(List)}.
*
* @author Rossen Stoyanchev
* @since 1.0.0
*/
class CompositeThreadLocalAccessor implements ThreadLocalAccessor {

private final List<ThreadLocalAccessor> accessors;

CompositeThreadLocalAccessor(List<ThreadLocalAccessor> accessors) {
this.accessors = accessors;
this.accessors = new ArrayList<>(accessors);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,26 +96,34 @@ public WebGraphQlHandler build() {
.map(interceptor -> interceptor.apply(endOfChain))
.orElse(endOfChain);

ThreadLocalAccessor accessor = (CollectionUtils.isEmpty(this.accessors) ? null :
ThreadLocalAccessor.composite(this.accessors));

return new WebGraphQlHandler() {

@Override
public WebSocketGraphQlInterceptor getWebSocketInterceptor() {
return (webSocketInterceptor != null ?
webSocketInterceptor : new WebSocketGraphQlInterceptor() {});
}

@Nullable
@Override
public ThreadLocalAccessor getThreadLocalAccessor() {
return accessor;
}

@Override
public Mono<WebGraphQlResponse> handleRequest(WebGraphQlRequest request) {
return executionChain.next(request)
.contextWrite(context -> {
if (!CollectionUtils.isEmpty(accessors)) {
ThreadLocalAccessor accessor = ThreadLocalAccessor.composite(accessors);
if (accessor != null) {
return ReactorContextManager.extractThreadLocalValues(accessor, context);
}
return context;
});
}

@Override
public WebSocketGraphQlInterceptor webSocketInterceptor() {
return (webSocketInterceptor != null ?
webSocketInterceptor : new WebSocketGraphQlInterceptor() {});
}

};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import org.springframework.graphql.ExecutionGraphQlService;
import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.lang.Nullable;


/**
Expand All @@ -33,6 +34,19 @@
*/
public interface WebGraphQlHandler {

/**
* Return the single interceptor of type
* {@link WebSocketGraphQlInterceptor} among all the configured
* interceptors.
*/
WebSocketGraphQlInterceptor getWebSocketInterceptor();

/**
* Return the composite {@link ThreadLocalAccessor} that the handler is
* configured with.
*/
@Nullable
ThreadLocalAccessor getThreadLocalAccessor();

/**
* Execute the given request and return the response.
Expand All @@ -41,13 +55,6 @@ public interface WebGraphQlHandler {
*/
Mono<WebGraphQlResponse> handleRequest(WebGraphQlRequest request);

/**
* Return the single interceptor of type
* {@link WebSocketGraphQlInterceptor} among all the configured
* interceptors.
*/
WebSocketGraphQlInterceptor webSocketInterceptor();


/**
* Provides access to a builder to create a {@link WebGraphQlHandler} instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public GraphQlWebSocketHandler(
Assert.notNull(graphQlHandler, "WebGraphQlHandler is required");

this.graphQlHandler = graphQlHandler;
this.webSocketInterceptor = this.graphQlHandler.webSocketInterceptor();
this.webSocketInterceptor = this.graphQlHandler.getWebSocketInterceptor();
this.codecDelegate = new CodecDelegate(codecConfigurer);
this.initTimeoutDuration = connectionInitTimeout;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -43,6 +45,7 @@
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlRequest;
import org.springframework.graphql.server.WebGraphQlResponse;
Expand All @@ -53,15 +56,21 @@
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;

/**
* WebSocketHandler for GraphQL based on
Expand All @@ -81,7 +90,9 @@ public class GraphQlWebSocketHandler extends TextWebSocketHandler implements Sub

private final WebGraphQlHandler graphQlHandler;

private final WebSocketGraphQlInterceptor webSocketInterceptor;
private final ContextHandshakeInterceptor contextHandshakeInterceptor;

private final WebSocketGraphQlInterceptor webSocketGraphQlInterceptor;

private final Duration initTimeoutDuration;

Expand All @@ -103,7 +114,8 @@ public GraphQlWebSocketHandler(
Assert.notNull(converter, "HttpMessageConverter for JSON is required");

this.graphQlHandler = graphQlHandler;
this.webSocketInterceptor = this.graphQlHandler.webSocketInterceptor();
this.contextHandshakeInterceptor = new ContextHandshakeInterceptor(graphQlHandler.getThreadLocalAccessor());
this.webSocketGraphQlInterceptor = this.graphQlHandler.getWebSocketInterceptor();
this.initTimeoutDuration = connectionInitTimeout;
this.converter = converter;
}
Expand All @@ -113,6 +125,18 @@ public List<String> getSubProtocols() {
return SUB_PROTOCOL_LIST;
}

/**
* Return a {@link WebSocketHttpRequestHandler} that uses this instance as
* its {@link WebGraphQlHandler} and adds a {@link HandshakeInterceptor} to
* propagate context.
*/
public WebSocketHttpRequestHandler asWebSocketHttpRequestHandler(HandshakeHandler handshakeHandler) {
WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(this, handshakeHandler);
handler.setHandshakeInterceptors(Collections.singletonList(this.contextHandshakeInterceptor));
return handler;
}


@Override
public void afterConnectionEstablished(WebSocketSession session) {
if ("graphql-ws".equalsIgnoreCase(session.getAcceptedProtocol())) {
Expand All @@ -137,8 +161,15 @@ public void afterConnectionEstablished(WebSocketSession session) {

}

@SuppressWarnings({"unused", "try"})
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage webSocketMessage) throws Exception {
try (Closeable closeable = this.contextHandshakeInterceptor.restoreThreadLocalValue(session)) {
handleInternal(session, webSocketMessage);
}
}

private void handleInternal(WebSocketSession session, TextMessage webSocketMessage) throws IOException {
GraphQlWebSocketMessage message = decode(webSocketMessage);
String id = message.getId();
Map<String, Object> payload = message.getPayload();
Expand Down Expand Up @@ -174,7 +205,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket
if (subscription != null) {
subscription.cancel();
}
this.webSocketInterceptor.handleCancelledSubscription(session.getId(), id)
this.webSocketGraphQlInterceptor.handleCancelledSubscription(session.getId(), id)
.block(Duration.ofSeconds(10));
}
return;
Expand All @@ -183,7 +214,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket
GraphQlStatus.closeSession(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
return;
}
this.webSocketInterceptor.handleConnectionInitialization(session.getId(), payload)
this.webSocketGraphQlInterceptor.handleConnectionInitialization(session.getId(), payload)
.defaultIfEmpty(Collections.emptyMap())
.publishOn(sessionState.getScheduler()) // Serial blocking send via single thread
.doOnNext(ackPayload -> {
Expand Down Expand Up @@ -285,7 +316,7 @@ public void afterConnectionClosed(WebSocketSession session, CloseStatus closeSta
info.dispose();
Map<String, Object> connectionInitPayload = info.getConnectionInitPayload();
if (connectionInitPayload != null) {
this.webSocketInterceptor.handleConnectionClosed(id, closeStatus.getCode(), connectionInitPayload);
this.webSocketGraphQlInterceptor.handleConnectionClosed(id, closeStatus.getCode(), connectionInitPayload);
}
}
}
Expand All @@ -296,6 +327,57 @@ public boolean supportsPartialMessages() {
}


/**
* {@code HandshakeInterceptor} that propagates ThreadLocal context through
* the attributes map in {@code WebSocketSession}.
*/
private static class ContextHandshakeInterceptor implements HandshakeInterceptor {

private static final String SAVED_CONTEXT_KEY = ContextHandshakeInterceptor.class.getName();

@Nullable
private final ThreadLocalAccessor accessor;

ContextHandshakeInterceptor(@Nullable ThreadLocalAccessor accessor) {
this.accessor = accessor;
}

@Override
public boolean beforeHandshake(
ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
Map<String, Object> attributes) {

if (this.accessor != null) {
Map<String, Object> valuesMap = new LinkedHashMap<>();
this.accessor.extractValues(valuesMap);
attributes.put(SAVED_CONTEXT_KEY, valuesMap);
}
return true;
}

@Override
public void afterHandshake(
ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
@Nullable Exception exception) {
}

@SuppressWarnings("unchecked")
public Closeable restoreThreadLocalValue(WebSocketSession session) {
if (this.accessor != null) {
Map<String, Object> valuesMap = (Map<String, Object>) session.getAttributes().get(SAVED_CONTEXT_KEY);
// Uncomment when Boot is updated to use HandshakeInterceptor
// Assert.state(valuesMap != null, "No context");
if (valuesMap != null) {
this.accessor.restoreValues(valuesMap);
return () -> this.accessor.resetValues(valuesMap);
}
}
return () -> {};
}

}


private static class GraphQlStatus {

private static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,8 @@

import org.springframework.graphql.BookSource;
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.execution.ThreadLocalAccessor;
import org.springframework.lang.Nullable;

public abstract class WebSocketHandlerTestSupport {

Expand Down Expand Up @@ -65,6 +67,12 @@ public abstract class WebSocketHandlerTestSupport {


protected WebGraphQlHandler initHandler(WebGraphQlInterceptor... interceptors) {
return initHandler(null, interceptors);
}

protected WebGraphQlHandler initHandler(
@Nullable ThreadLocalAccessor accessor, WebGraphQlInterceptor... interceptors) {

return GraphQlSetup.schemaResource(BookSource.schema)
.queryFetcher("bookById", environment -> {
Long id = Long.parseLong(environment.getArgument("id"));
Expand All @@ -75,6 +83,7 @@ protected WebGraphQlHandler initHandler(WebGraphQlInterceptor... interceptors) {
return Flux.fromIterable(BookSource.books())
.filter((book) -> book.getAuthor().getFullName().contains(author));
})
.threadLocalAccessor(accessor)
.interceptor(interceptors)
.toWebGraphQlHandler();
}
Expand Down
Loading

0 comments on commit 85dc2e5

Please sign in to comment.