Skip to content

Commit

Permalink
Extract GraphQlWebSocketMessage and WebSocketCodecDelegate
Browse files Browse the repository at this point in the history
Allows some reuse between WebFlux and WebMVC and between client and server.

See gh-10
  • Loading branch information
rstoyanchev committed Feb 24, 2022
1 parent afb0b64 commit 6c0bbdd
Show file tree
Hide file tree
Showing 6 changed files with 469 additions and 356 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

import graphql.ErrorType;
import graphql.ExecutionResult;
import graphql.GraphQLError;
import graphql.GraphqlErrorBuilder;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -35,23 +34,12 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.graphql.web.WebGraphQlHandler;
import org.springframework.graphql.web.WebInput;
import org.springframework.graphql.web.WebOutput;
import org.springframework.http.MediaType;
import org.springframework.http.codec.DecoderHttpMessageReader;
import org.springframework.http.codec.EncoderHttpMessageWriter;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.lang.Nullable;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
Expand All @@ -72,57 +60,36 @@ public class GraphQlWebSocketHandler implements WebSocketHandler {

private static final List<String> SUB_PROTOCOL_LIST = Arrays.asList("graphql-transport-ws", "graphql-ws");

static final ResolvableType MAP_RESOLVABLE_TYPE =
ResolvableType.forType(new ParameterizedTypeReference<Map<String, Object>>() {});


private final WebGraphQlHandler graphQlHandler;

private final Decoder<?> decoder;

private final Encoder<?> encoder;
private final WebSocketCodecDelegate codecDelegate;

private final Duration initTimeoutDuration;


/**
* Create a new instance.
* @param graphQlHandler common handler for GraphQL over WebSocket requests
* @param configurer codec configurer for JSON encoding and decoding
* @param codecConfigurer codec configurer for JSON encoding and decoding
* @param connectionInitTimeout the time within which the {@code CONNECTION_INIT} type
* message must be received.
*/
public GraphQlWebSocketHandler(
WebGraphQlHandler graphQlHandler, ServerCodecConfigurer configurer,
Duration connectionInitTimeout) {
WebGraphQlHandler graphQlHandler, CodecConfigurer codecConfigurer, Duration connectionInitTimeout) {

Assert.notNull(graphQlHandler, "WebGraphQlHandler is required");
this.graphQlHandler = graphQlHandler;
this.decoder = initDecoder(configurer);
this.encoder = initEncoder(configurer);
this.codecDelegate = new WebSocketCodecDelegate(codecConfigurer);
this.initTimeoutDuration = connectionInitTimeout;
}

private static Decoder<?> initDecoder(ServerCodecConfigurer configurer) {
return configurer.getReaders().stream()
.filter((reader) -> reader.canRead(MAP_RESOLVABLE_TYPE, MediaType.APPLICATION_JSON))
.map((reader) -> ((DecoderHttpMessageReader<?>) reader).getDecoder())
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("No JSON Decoder"));
}

private static Encoder<?> initEncoder(ServerCodecConfigurer configurer) {
return configurer.getWriters().stream()
.filter((writer) -> writer.canWrite(MAP_RESOLVABLE_TYPE, MediaType.APPLICATION_JSON))
.map((writer) -> ((EncoderHttpMessageWriter<?>) writer).getEncoder())
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("No JSON Encoder"));
}

@Override
public List<String> getSubProtocols() {
return SUB_PROTOCOL_LIST;
}


@Override
public Mono<Void> handle(WebSocketSession session) {
HandshakeInfo handshakeInfo = session.getHandshakeInfo();
Expand All @@ -145,62 +112,49 @@ public Mono<Void> handle(WebSocketSession session) {
Mono.empty()))
.subscribe();

return session.send(session.receive().flatMap((message) -> {
Map<String, Object> map = decode(message);
String id = (String) map.get("id");
MessageType messageType = MessageType.resolve((String) map.get("type"));
if (messageType == null) {
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
}
switch (messageType) {
case SUBSCRIBE:
if (!connectionInitProcessed.get()) {
return GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS);
}
if (id == null) {
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
}
WebInput input = new WebInput(
handshakeInfo.getUri(), handshakeInfo.getHeaders(), getPayload(map), null, id);
if (logger.isDebugEnabled()) {
logger.debug("Executing: " + input);
}
return this.graphQlHandler.handleRequest(input)
.flatMapMany((output) -> handleWebOutput(session, id, subscriptions, output))
.doOnTerminate(() -> subscriptions.remove(id));
case COMPLETE:
if (id != null) {
Subscription subscription = subscriptions.remove(id);
if (subscription != null) {
subscription.cancel();
return session.send(session.receive().flatMap(webSocketMessage -> {
GraphQlWebSocketMessage message = this.codecDelegate.decode(webSocketMessage);
String id = message.getId();
Map<String, Object> payload = message.getPayloadOrDefault(Collections.emptyMap());
switch (message.getType()) {
case "subscribe":
if (!connectionInitProcessed.get()) {
return GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS);
}
if (id == null) {
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
}
}
return this.graphQlHandler.handleWebSocketCompletion().thenMany(Flux.empty());
case CONNECTION_INIT:
if (!connectionInitProcessed.compareAndSet(false, true)) {
return GraphQlStatus.close(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
}
return this.graphQlHandler.handleWebSocketInitialization(getPayload(map))
.defaultIfEmpty(Collections.emptyMap())
.flatMapMany(ackPayload -> Flux.just(encode(session, null, MessageType.CONNECTION_ACK, ackPayload)))
.onErrorResume(ex -> GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS));
default:
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
WebInput input = new WebInput(
handshakeInfo.getUri(), handshakeInfo.getHeaders(), payload, null, id);
if (logger.isDebugEnabled()) {
logger.debug("Executing: " + input);
}
return this.graphQlHandler.handleRequest(input)
.flatMapMany((output) -> handleWebOutput(session, id, subscriptions, output))
.doOnTerminate(() -> subscriptions.remove(id));
case "complete":
if (id != null) {
Subscription subscription = subscriptions.remove(id);
if (subscription != null) {
subscription.cancel();
}
}
return this.graphQlHandler.handleWebSocketCompletion().thenMany(Flux.empty());
case "connection_init":
if (!connectionInitProcessed.compareAndSet(false, true)) {
return GraphQlStatus.close(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
}
return this.graphQlHandler.handleWebSocketInitialization(payload)
.defaultIfEmpty(Collections.emptyMap())
.map(ackPayload -> this.codecDelegate.encodeConnectionAckMessage(session, ackPayload))
.flux()
.onErrorResume(ex -> GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS));
default:
return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
}
}));
}

@SuppressWarnings({ "unchecked", "ConstantConditions" })
private Map<String, Object> decode(WebSocketMessage message) {
DataBuffer buffer = DataBufferUtils.retain(message.getPayload());
return (Map<String, Object>) this.decoder.decode(buffer, MAP_RESOLVABLE_TYPE, null, null);
}

@SuppressWarnings("unchecked")
private static Map<String, Object> getPayload(Map<String, Object> message) {
Map<String, Object> payload = (Map<String, Object>) message.get("payload");
return (payload != null ? payload : Collections.emptyMap());
}

@SuppressWarnings("unchecked")
private Flux<WebSocketMessage> handleWebOutput(WebSocketSession session, String id,
Expand Down Expand Up @@ -230,78 +184,17 @@ private Flux<WebSocketMessage> handleWebOutput(WebSocketSession session, String
}

return outputFlux
.map((result) -> {
Map<String, Object> dataMap = result.toSpecification();
return encode(session, id, MessageType.NEXT, dataMap);
})
.concatWith(Mono.fromCallable(() -> encode(session, id, MessageType.COMPLETE, null)))
.onErrorResume((ex) -> {
.map(result -> this.codecDelegate.encodeNextMessage(session, id, result))
.concatWith(Mono.fromCallable(() -> this.codecDelegate.encodeCompleteMessage(session, id)))
.onErrorResume(ex -> {
if (ex instanceof SubscriptionExistsException) {
CloseStatus status = new CloseStatus(4409, "Subscriber for " + id + " already exists");
return GraphQlStatus.close(session, status);
}
Map<String, Object> errorMap = GraphqlErrorBuilder.newError()
.errorType(ErrorType.DataFetchingException)
.message(ex.getMessage())
.build()
.toSpecification();
return Mono.just(encode(
session, id, MessageType.ERROR, Collections.singletonList(errorMap)));
return Mono.fromCallable(() -> this.codecDelegate.encodeErrorMessage(session, id, ex));
});
}

@SuppressWarnings("unchecked")
private <T> WebSocketMessage encode(WebSocketSession session, @Nullable String id, MessageType messageType,
@Nullable Object payload) {

Map<String, Object> payloadMap = new HashMap<>(3);
if (id != null) {
payloadMap.put("id", id);
}
payloadMap.put("type", messageType.getType());
if (payload != null) {
payloadMap.put("payload", payload);
}

DataBuffer buffer = ((Encoder<T>) this.encoder).encodeValue((T) payloadMap, session.bufferFactory(),
MAP_RESOLVABLE_TYPE, MimeTypeUtils.APPLICATION_JSON, null);

return new WebSocketMessage(WebSocketMessage.Type.TEXT, buffer);
}

private enum MessageType {

CONNECTION_INIT("connection_init"),
CONNECTION_ACK("connection_ack"),
SUBSCRIBE("subscribe"),
NEXT("next"),
ERROR("error"),
COMPLETE("complete");

private static final Map<String, MessageType> messageTypes = new HashMap<>(6);

static {
for (MessageType messageType : MessageType.values()) {
messageTypes.put(messageType.getType(), messageType);
}
}

private final String type;

MessageType(String type) {
this.type = type;
}

public String getType() {
return this.type;
}

@Nullable
public static MessageType resolve(@Nullable String type) {
return (type != null) ? messageTypes.get(type) : null;
}

}

private static class GraphQlStatus {

Expand All @@ -319,9 +212,9 @@ static <V> Flux<V> close(WebSocketSession session, CloseStatus status) {

}


@SuppressWarnings("serial")
private static class SubscriptionExistsException extends RuntimeException {

}

}
Loading

0 comments on commit 6c0bbdd

Please sign in to comment.