Skip to content

Commit

Permalink
spring-projects#605 Support for client to send ping messages for subs…
Browse files Browse the repository at this point in the history
…criptions
  • Loading branch information
toby200 committed Feb 8, 2023
1 parent 1ae33a8 commit 8ead7b6
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 19 deletions.
7 changes: 7 additions & 0 deletions spring-graphql-docs/src/docs/asciidoc/includes/client.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ existing `WebSocketGraphQlClient` to create a new instance with customized setti
----

If you'd like the client to send regular graphql ping messages to the server, you can add these by adding `keepalive(long seconds)` to the builder
[source,java,indent=0,subs="verbatim,quotes"]
----
WebSocketGraphQlClient graphQlClient = WebSocketGraphQlClient.builder(url, client)
.keepalive(30)
.build();
----

[[client.websocketgraphqlclient.interceptor]]
==== Interceptor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,18 @@

package org.springframework.graphql.client;

import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import reactor.core.publisher.Mono;

import org.springframework.http.HttpHeaders;
import org.springframework.http.codec.ClientCodecConfigurer;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.util.DefaultUriBuilderFactory;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;


/**
Expand All @@ -51,6 +49,7 @@ final class DefaultWebSocketGraphQlClientBuilder

private final CodecConfigurer codecConfigurer;

private long keepalive;

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient)}.
Expand All @@ -59,13 +58,28 @@ final class DefaultWebSocketGraphQlClientBuilder
this(toURI(url), client);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient, long)}.
*/
DefaultWebSocketGraphQlClientBuilder(String url, WebSocketClient client, long keepalive) {
this(toURI(url), client, keepalive);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient)}.
*/
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client) {
this(url, client, 0);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient, long)}.
*/
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client, long keepalive) {
this.url = url;
this.webSocketClient = client;
this.codecConfigurer = ClientCodecConfigurer.create();
this.keepalive = keepalive;
}

/**
Expand All @@ -77,6 +91,7 @@ final class DefaultWebSocketGraphQlClientBuilder
this.headers.putAll(transport.getHeaders());
this.webSocketClient = transport.getWebSocketClient();
this.codecConfigurer = transport.getCodecConfigurer();
this.keepalive = transport.getKeepAlive();
}


Expand Down Expand Up @@ -121,18 +136,24 @@ public WebSocketGraphQlClient build() {
CodecDelegate.findJsonDecoder(this.codecConfigurer));

WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor());
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor(), this.keepalive);

GraphQlClient graphQlClient = super.buildGraphQlClient(transport);
return new DefaultWebSocketGraphQlClient(graphQlClient, transport, getBuilderInitializer());
}

@Override
public WebSocketGraphQlClient.Builder<DefaultWebSocketGraphQlClientBuilder> keepalive(long keepalive) {
this.keepalive = keepalive;
return this;
}

private WebSocketGraphQlClientInterceptor getInterceptor() {

List<WebSocketGraphQlClientInterceptor> interceptors = getInterceptors().stream()
.filter(interceptor -> interceptor instanceof WebSocketGraphQlClientInterceptor)
.map(interceptor -> (WebSocketGraphQlClientInterceptor) interceptor)
.collect(Collectors.toList());
.toList();

Assert.state(interceptors.size() <= 1,
"Only a single interceptor of type WebSocketGraphQlClientInterceptor may be configured");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient) {
return builder(url, webSocketClient).build();
}

/**
* Create a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient, long keepalive) {
return builder(url, webSocketClient).keepalive(keepalive).build();
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
Expand All @@ -73,6 +83,16 @@ static Builder<?> builder(String url, WebSocketClient webSocketClient) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static Builder<?> builder(String url, WebSocketClient webSocketClient, long keepalive) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
Expand All @@ -82,6 +102,16 @@ static Builder<?> builder(URI url, WebSocketClient webSocketClient) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static Builder<?> builder(URI url, WebSocketClient webSocketClient, long keepalive) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
}


/**
* Builder for a GraphQL over WebSocket client.
Expand All @@ -94,6 +124,8 @@ interface Builder<B extends Builder<B>> extends WebGraphQlClient.Builder<B> {
@Override
WebSocketGraphQlClient build();

Builder<B> keepalive(long keepalive);

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.graphql.client;

import java.net.URI;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -67,10 +68,12 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {

private final Mono<GraphQlSession> graphQlSessionMono;

private final long keepalive;


WebSocketGraphQlTransport(
URI url, @Nullable HttpHeaders headers, WebSocketClient client, CodecConfigurer codecConfigurer,
WebSocketGraphQlClientInterceptor interceptor) {
WebSocketGraphQlClientInterceptor interceptor, long keepalive) {

Assert.notNull(url, "URI is required");
Assert.notNull(client, "WebSocketClient is required");
Expand All @@ -80,8 +83,9 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
this.url = url;
this.headers.putAll(headers != null ? headers : HttpHeaders.EMPTY);
this.webSocketClient = client;
this.keepalive = keepalive;

this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor);
this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor, keepalive);

this.graphQlSessionMono = initGraphQlSession(this.url, this.headers, client, this.graphQlSessionHandler)
.cacheInvalidateWhen(GraphQlSession::notifyWhenClosed);
Expand Down Expand Up @@ -154,6 +158,10 @@ public Flux<GraphQlResponse> executeSubscription(GraphQlRequest request) {
return this.graphQlSessionMono.flatMapMany(session -> session.executeSubscription(request));
}

public long getKeepAlive() {
return keepalive;
}


/**
* Client {@code WebSocketHandler} for GraphQL that deals with WebSocket
Expand All @@ -175,11 +183,15 @@ private static class GraphQlSessionHandler implements WebSocketHandler {

private final AtomicBoolean stopped = new AtomicBoolean();

private final long keepalive;

GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor) {

GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor,
long keepalive) {
this.codecDelegate = new CodecDelegate(codecConfigurer);
this.interceptor = interceptor;
this.graphQlSessionSink = Sinks.unsafe().one();
this.keepalive = keepalive;
}


Expand Down Expand Up @@ -236,7 +248,7 @@ public Mono<Void> handle(WebSocketSession session) {
session.send(connectionInitMono.concatWith(graphQlSession.getRequestFlux())
.map(message -> this.codecDelegate.encode(session, message)));

Mono<Void> receiveCompletion = session.receive()
Flux<Void> receiveCompletion = session.receive()
.flatMap(webSocketMessage -> {
if (sessionNotInitialized()) {
try {
Expand Down Expand Up @@ -277,6 +289,8 @@ public Mono<Void> handle(WebSocketSession session) {
case COMPLETE:
graphQlSession.handleComplete(message);
break;
case PONG:
break;
default:
throw new IllegalStateException(
"Unexpected message type: '" + message.getType() + "'");
Expand All @@ -290,10 +304,21 @@ public Mono<Void> handle(WebSocketSession session) {
}
}
return Mono.empty();
})
.then();
});

if (keepalive > 0) {
Duration keepAliveDuration = Duration.ofSeconds(keepalive);
receiveCompletion = receiveCompletion
.mergeWith(Flux.interval(keepAliveDuration, keepAliveDuration)
.flatMap(i -> {
graphQlSession.sendPing(null);
return Mono.empty();
})
);
}


return Mono.zip(sendCompletion, receiveCompletion).then();
return Mono.zip(sendCompletion, receiveCompletion.then()).then();
}

private boolean sessionNotInitialized() {
Expand Down Expand Up @@ -454,6 +479,11 @@ public void sendPong(@Nullable Map<String, Object> payload) {
this.requestSink.sendRequest(message);
}

public void sendPing(@Nullable Map<String, Object> payload) {
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.ping(payload);
this.requestSink.sendRequest(message);
}


// Inbound messages

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ private Publisher<GraphQlWebSocketMessage> handleMessage(GraphQlWebSocketMessage
GraphQlWebSocketMessage.complete(id));
case COMPLETE:
return Flux.empty();
case PING:
return Mono.just(GraphQlWebSocketMessage.pong(null));
default:
return Flux.error(new IllegalStateException("Unexpected message: " + message));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class WebSocketGraphQlTransportTests {
private final static Duration TIMEOUT = Duration.ofSeconds(5);

private static final CodecDelegate CODEC_DELEGATE = new CodecDelegate(ClientCodecConfigurer.create());
public static final int KEEPALIVE = 1;


private final MockGraphQlWebSocketServer mockServer = new MockGraphQlWebSocketServer();
Expand Down Expand Up @@ -185,6 +186,22 @@ void pingHandling() {
GraphQlWebSocketMessage.subscribe("1", new DefaultGraphQlRequest("{Query1}")));
}

@Test
void pingSending() throws InterruptedException {

GraphQlRequest request = this.mockServer.expectOperation("{Sub1}").andStream(Flux.just(this.response1, response2));

StepVerifier.create(this.transport.executeSubscription(request))
.expectNext(this.response1, response2).expectComplete()
.verify(TIMEOUT);
Thread.sleep(KEEPALIVE*1000 + 50); // wait for ping

assertActualClientMessages(
GraphQlWebSocketMessage.connectionInit(null),
GraphQlWebSocketMessage.subscribe("1", request),
GraphQlWebSocketMessage.ping(null));
}

@Test
void start() {
MockGraphQlWebSocketServer handler = new MockGraphQlWebSocketServer();
Expand All @@ -210,7 +227,7 @@ public Mono<Void> handleConnectionAck(Map<String, Object> ackPayload) {


WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor);
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor, 3);

transport.start().block(TIMEOUT);

Expand Down Expand Up @@ -324,7 +341,7 @@ void errorDuringResponseHandling() {
private static WebSocketGraphQlTransport createTransport(WebSocketClient client) {
return new WebSocketGraphQlTransport(
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(),
new WebSocketGraphQlClientInterceptor() {});
new WebSocketGraphQlClientInterceptor() {}, KEEPALIVE);
}

private void assertActualClientMessages(GraphQlWebSocketMessage... expectedMessages) {
Expand Down

0 comments on commit 8ead7b6

Please sign in to comment.