Skip to content

Commit

Permalink
Insert WebSocket in the names of GraphQlMessage[Type]
Browse files Browse the repository at this point in the history
Those are specific to the GraphQL over WebSocket protocol.

See gh-339
  • Loading branch information
rstoyanchev committed Mar 28, 2022
1 parent eb9a369 commit 3090328
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.graphql.server.support.GraphQlMessage;
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ClientCodecConfigurer;
import org.springframework.http.codec.CodecConfigurer;
Expand All @@ -42,7 +42,7 @@
*/
final class CodecDelegate {

private static final ResolvableType MESSAGE_TYPE = ResolvableType.forClass(GraphQlMessage.class);
private static final ResolvableType MESSAGE_TYPE = ResolvableType.forClass(GraphQlWebSocketMessage.class);


private final CodecConfigurer codecConfigurer;
Expand Down Expand Up @@ -104,7 +104,7 @@ public CodecConfigurer getCodecConfigurer() {


@SuppressWarnings("unchecked")
public <T> WebSocketMessage encode(WebSocketSession session, GraphQlMessage message) {
public <T> WebSocketMessage encode(WebSocketSession session, GraphQlWebSocketMessage message) {

DataBuffer buffer = ((Encoder<T>) this.encoder).encodeValue(
(T) message, session.bufferFactory(), MESSAGE_TYPE, MimeTypeUtils.APPLICATION_JSON, null);
Expand All @@ -113,9 +113,9 @@ public <T> WebSocketMessage encode(WebSocketSession session, GraphQlMessage mess
}

@SuppressWarnings("ConstantConditions")
public GraphQlMessage decode(WebSocketMessage webSocketMessage) {
public GraphQlWebSocketMessage decode(WebSocketMessage webSocketMessage) {
DataBuffer buffer = DataBufferUtils.retain(webSocketMessage.getPayload());
return (GraphQlMessage) this.decoder.decode(buffer, MESSAGE_TYPE, null, null);
return (GraphQlWebSocketMessage) this.decoder.decode(buffer, MESSAGE_TYPE, null, null);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
import org.springframework.graphql.GraphQlRequest;
import org.springframework.graphql.GraphQlResponse;
import org.springframework.graphql.ResponseError;
import org.springframework.graphql.server.support.GraphQlMessage;
import org.springframework.graphql.server.support.GraphQlMessageType;
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
import org.springframework.graphql.server.support.GraphQlWebSocketMessageType;
import org.springframework.http.HttpHeaders;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -226,9 +226,9 @@ public Mono<Void> handle(WebSocketSession session) {
GraphQlSession graphQlSession = new GraphQlSession(session);
registerCloseStatusHandling(graphQlSession, session);

Mono<GraphQlMessage> connectionInitMono = this.interceptor.connectionInitPayload()
Mono<GraphQlWebSocketMessage> connectionInitMono = this.interceptor.connectionInitPayload()
.defaultIfEmpty(Collections.emptyMap())
.map(GraphQlMessage::connectionInit);
.map(GraphQlWebSocketMessage::connectionInit);

Mono<Void> sendCompletion =
session.send(connectionInitMono.concatWith(graphQlSession.getRequestFlux())
Expand All @@ -238,8 +238,8 @@ public Mono<Void> handle(WebSocketSession session) {
.flatMap(webSocketMessage -> {
if (sessionNotInitialized()) {
try {
GraphQlMessage message = this.codecDelegate.decode(webSocketMessage);
Assert.state(message.resolvedType() == GraphQlMessageType.CONNECTION_ACK,
GraphQlWebSocketMessage message = this.codecDelegate.decode(webSocketMessage);
Assert.state(message.resolvedType() == GraphQlWebSocketMessageType.CONNECTION_ACK,
() -> "Unexpected message before connection_ack: " + message);
return this.interceptor.handleConnectionAck(message.getPayload())
.then(Mono.defer(() -> {
Expand All @@ -261,7 +261,7 @@ public Mono<Void> handle(WebSocketSession session) {
}
else {
try {
GraphQlMessage message = this.codecDelegate.decode(webSocketMessage);
GraphQlWebSocketMessage message = this.codecDelegate.decode(webSocketMessage);
switch (message.resolvedType()) {
case NEXT:
graphQlSession.handleNext(message);
Expand Down Expand Up @@ -378,7 +378,7 @@ private static class GraphQlSession {

private final AtomicLong requestIndex = new AtomicLong();

private final Sinks.Many<GraphQlMessage> requestSink = Sinks.many().unicast().onBackpressureBuffer();
private final Sinks.Many<GraphQlWebSocketMessage> requestSink = Sinks.many().unicast().onBackpressureBuffer();

private final Map<String, ResponseState> responseMap = new ConcurrentHashMap<>();

Expand All @@ -393,14 +393,14 @@ private static class GraphQlSession {
/**
* Return the {@code Flux} of GraphQL requests to send as WebSocket messages.
*/
public Flux<GraphQlMessage> getRequestFlux() {
public Flux<GraphQlWebSocketMessage> getRequestFlux() {
return this.requestSink.asFlux();
}

public Mono<GraphQlResponse> execute(GraphQlRequest request) {
String id = String.valueOf(this.requestIndex.incrementAndGet());
try {
GraphQlMessage message = GraphQlMessage.subscribe(id, request);
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.subscribe(id, request);
ResponseState state = new ResponseState(request);
this.responseMap.put(id, state);
trySend(message);
Expand All @@ -415,7 +415,7 @@ public Mono<GraphQlResponse> execute(GraphQlRequest request) {
public Flux<GraphQlResponse> executeSubscription(GraphQlRequest request) {
String id = String.valueOf(this.requestIndex.incrementAndGet());
try {
GraphQlMessage message = GraphQlMessage.subscribe(id, request);
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.subscribe(id, request);
SubscriptionState state = new SubscriptionState(request);
this.subscriptionMap.put(id, state);
trySend(message);
Expand All @@ -428,13 +428,13 @@ public Flux<GraphQlResponse> executeSubscription(GraphQlRequest request) {
}

public void sendPong(@Nullable Map<String, Object> payload) {
GraphQlMessage message = GraphQlMessage.pong(payload);
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.pong(payload);
trySend(message);
}

// TODO: queue to serialize sending?

private void trySend(GraphQlMessage message) {
private void trySend(GraphQlWebSocketMessage message) {
Sinks.EmitResult emitResult = null;
for (int i = 0; i < 100; i++) {
emitResult = this.requestSink.tryEmitNext(message);
Expand All @@ -449,7 +449,7 @@ private void stopSubscription(String id) {
SubscriptionState state = this.subscriptionMap.remove(id);
if (state != null) {
try {
trySend(GraphQlMessage.complete(id));
trySend(GraphQlWebSocketMessage.complete(id));
}
catch (Exception ex) {
if (logger.isErrorEnabled()) {
Expand All @@ -465,7 +465,7 @@ private void stopSubscription(String id) {
/**
* Handle a "next" message and route to its recipient.
*/
public void handleNext(GraphQlMessage message) {
public void handleNext(GraphQlWebSocketMessage message) {
String id = message.getId();
ResponseState responseState = this.responseMap.remove(id);
SubscriptionState subscriptionState = this.subscriptionMap.get(id);
Expand Down Expand Up @@ -496,7 +496,7 @@ public void handleNext(GraphQlMessage message) {
* Handle an "error" message, turning it into an {@link GraphQlResponse}
* for single responses, or signaling an error for streams.
*/
public void handleError(GraphQlMessage message) {
public void handleError(GraphQlWebSocketMessage message) {
String id = message.getId();
ResponseState responseState = this.responseMap.remove(id);
SubscriptionState subscriptionState = this.subscriptionMap.remove(id);
Expand Down Expand Up @@ -529,7 +529,7 @@ public void handleError(GraphQlMessage message) {
/**
* Handle a "complete" message.
*/
public void handleComplete(GraphQlMessage message) {
public void handleComplete(GraphQlWebSocketMessage message) {
ResponseState responseState = this.responseMap.remove(message.getId());
SubscriptionState subscriptionState = this.subscriptionMap.remove(message.getId());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
* @since 1.0.0
* @see <a href="https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md">GraphQL Over WebSocket Protocol</a>
*/
public class GraphQlMessage {
public class GraphQlWebSocketMessage {

@Nullable
private String id;

@Nullable
private GraphQlMessageType type;
private GraphQlWebSocketMessageType type;

@Nullable
private Object payload;
Expand All @@ -50,7 +50,7 @@ public class GraphQlMessage {
/**
* Private constructor. See static factory methods.
*/
private GraphQlMessage(@Nullable String id, GraphQlMessageType type, @Nullable Object payload) {
private GraphQlWebSocketMessage(@Nullable String id, GraphQlWebSocketMessageType type, @Nullable Object payload) {
Assert.notNull(type, "GraphQlMessageType is required");
Assert.isTrue(payload != null || type.doesNotRequirePayload(), "Payload is required for [" + type + "]");
this.id = id;
Expand All @@ -63,8 +63,8 @@ private GraphQlMessage(@Nullable String id, GraphQlMessageType type, @Nullable O
* Constructor for deserialization.
*/
@SuppressWarnings("unused")
GraphQlMessage() {
this.type = GraphQlMessageType.NOT_SPECIFIED;
GraphQlWebSocketMessage() {
this.type = GraphQlWebSocketMessageType.NOT_SPECIFIED;
}


Expand All @@ -88,7 +88,7 @@ public String getType() {
/**
* Return the message type as an emum.
*/
public GraphQlMessageType resolvedType() {
public GraphQlWebSocketMessageType resolvedType() {
Assert.state(this.type != null, "GraphQlWebSocketMessage does not have a type");
return this.type;
}
Expand All @@ -111,7 +111,7 @@ public void setId(@Nullable String id) {
}

public void setType(String type) {
this.type = GraphQlMessageType.fromValue(type);
this.type = GraphQlWebSocketMessageType.fromValue(type);
}

public void setPayload(@Nullable Object payload) {
Expand All @@ -129,10 +129,10 @@ public int hashCode() {

@Override
public boolean equals(Object o) {
if (!(o instanceof GraphQlMessage)) {
if (!(o instanceof GraphQlWebSocketMessage)) {
return false;
}
GraphQlMessage other = (GraphQlMessage) o;
GraphQlWebSocketMessage other = (GraphQlWebSocketMessage) o;
return (ObjectUtils.nullSafeEquals(this.type, other.type) &&
(ObjectUtils.nullSafeEquals(this.id, other.id) || (this.id == null && other.id == null)) &&
(ObjectUtils.nullSafeEquals(getPayload(), other.getPayload())));
Expand All @@ -151,71 +151,71 @@ public String toString() {
* Create a {@code "connection_init"} client message.
* @param payload an optional payload
*/
public static GraphQlMessage connectionInit(@Nullable Object payload) {
return new GraphQlMessage(null, GraphQlMessageType.CONNECTION_INIT, payload);
public static GraphQlWebSocketMessage connectionInit(@Nullable Object payload) {
return new GraphQlWebSocketMessage(null, GraphQlWebSocketMessageType.CONNECTION_INIT, payload);
}

/**
* Create a {@code "connection_ack"} server message.
* @param payload an optional payload
*/
public static GraphQlMessage connectionAck(@Nullable Object payload) {
return new GraphQlMessage(null, GraphQlMessageType.CONNECTION_ACK, payload);
public static GraphQlWebSocketMessage connectionAck(@Nullable Object payload) {
return new GraphQlWebSocketMessage(null, GraphQlWebSocketMessageType.CONNECTION_ACK, payload);
}

/**
* Create a {@code "subscribe"} client message.
* @param id unique request id
* @param request the request to add as the message payload
*/
public static GraphQlMessage subscribe(String id, GraphQlRequest request) {
public static GraphQlWebSocketMessage subscribe(String id, GraphQlRequest request) {
Assert.notNull(request, "GraphQlRequest is required");
return new GraphQlMessage(id, GraphQlMessageType.SUBSCRIBE, request.toMap());
return new GraphQlWebSocketMessage(id, GraphQlWebSocketMessageType.SUBSCRIBE, request.toMap());
}

/**
* Create a {@code "next"} server message.
* @param id unique request id
* @param responseMap the response map
*/
public static GraphQlMessage next(String id, Map<String, Object> responseMap) {
public static GraphQlWebSocketMessage next(String id, Map<String, Object> responseMap) {
Assert.notNull(responseMap, "'responseMap' is required");
return new GraphQlMessage(id, GraphQlMessageType.NEXT, responseMap);
return new GraphQlWebSocketMessage(id, GraphQlWebSocketMessageType.NEXT, responseMap);
}

/**
* Create an {@code "error"} server message.
* @param id unique request id
* @param error the error to add as the message payload
*/
public static GraphQlMessage error(String id, GraphQLError error) {
public static GraphQlWebSocketMessage error(String id, GraphQLError error) {
Assert.notNull(error, "GraphQlError is required");
List<Map<String, Object>> errors = Collections.singletonList(error.toSpecification());
return new GraphQlMessage(id, GraphQlMessageType.ERROR, errors);
return new GraphQlWebSocketMessage(id, GraphQlWebSocketMessageType.ERROR, errors);
}

/**
* Create a {@code "complete"} server message.
* @param id unique request id
*/
public static GraphQlMessage complete(String id) {
return new GraphQlMessage(id, GraphQlMessageType.COMPLETE, null);
public static GraphQlWebSocketMessage complete(String id) {
return new GraphQlWebSocketMessage(id, GraphQlWebSocketMessageType.COMPLETE, null);
}

/**
* Create a {@code "ping"} client or server message.
* @param payload an optional payload
*/
public static GraphQlMessage ping(@Nullable Object payload) {
return new GraphQlMessage(null, GraphQlMessageType.PING, payload);
public static GraphQlWebSocketMessage ping(@Nullable Object payload) {
return new GraphQlWebSocketMessage(null, GraphQlWebSocketMessageType.PING, payload);
}

/**
* Create a {@code "pong"} client or server message.
* @param payload an optional payload
*/
public static GraphQlMessage pong(@Nullable Object payload) {
return new GraphQlMessage(null, GraphQlMessageType.PONG, payload);
public static GraphQlWebSocketMessage pong(@Nullable Object payload) {
return new GraphQlWebSocketMessage(null, GraphQlWebSocketMessageType.PONG, payload);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
* @since 1.0.0
* @see <a href="https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md">GraphQL Over WebSocket Protocol</a>
*/
public enum GraphQlMessageType {
public enum GraphQlWebSocketMessageType {

CONNECTION_INIT("connection_init", false),

Expand All @@ -48,7 +48,7 @@ public enum GraphQlMessageType {
NOT_SPECIFIED("", false);


private static final GraphQlMessageType[] VALUES;
private static final GraphQlWebSocketMessageType[] VALUES;

static {
VALUES = values();
Expand All @@ -60,7 +60,7 @@ public enum GraphQlMessageType {
private final boolean requiresPayload;


GraphQlMessageType(String value, boolean requiresPayload) {
GraphQlWebSocketMessageType(String value, boolean requiresPayload) {
this.value = value;
this.requiresPayload = requiresPayload;
}
Expand All @@ -81,8 +81,8 @@ public boolean doesNotRequirePayload() {
}


public static GraphQlMessageType fromValue(String value) {
for (GraphQlMessageType type : VALUES) {
public static GraphQlWebSocketMessageType fromValue(String value) {
for (GraphQlWebSocketMessageType type : VALUES) {
if (type.value.equals(value)) {
return type;
}
Expand Down
Loading

0 comments on commit 3090328

Please sign in to comment.