Skip to content

Commit

Permalink
Update contribution
Browse files Browse the repository at this point in the history
Closes gh-608
  • Loading branch information
rstoyanchev committed Apr 11, 2024
1 parent 3f5fc1a commit 74688ea
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 95 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.graphql.client;

import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
Expand All @@ -26,6 +27,7 @@
import org.springframework.http.HttpHeaders;
import org.springframework.http.codec.ClientCodecConfigurer;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.util.DefaultUriBuilderFactory;
Expand All @@ -49,7 +51,8 @@ final class DefaultWebSocketGraphQlClientBuilder

private final CodecConfigurer codecConfigurer;

private long keepalive;
@Nullable
private Duration keepAlive;

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient)}.
Expand All @@ -58,28 +61,13 @@ final class DefaultWebSocketGraphQlClientBuilder
this(toURI(url), client);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient, long)}.
*/
DefaultWebSocketGraphQlClientBuilder(String url, WebSocketClient client, long keepalive) {
this(toURI(url), client, keepalive);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient)}.
*/
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client) {
this(url, client, 0);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient, long)}.
*/
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client, long keepalive) {
this.url = url;
this.webSocketClient = client;
this.codecConfigurer = ClientCodecConfigurer.create();
this.keepalive = keepalive;
}

/**
Expand All @@ -91,7 +79,7 @@ final class DefaultWebSocketGraphQlClientBuilder
this.headers.putAll(transport.getHeaders());
this.webSocketClient = transport.getWebSocketClient();
this.codecConfigurer = transport.getCodecConfigurer();
this.keepalive = transport.getKeepAlive();
this.keepAlive = transport.getKeepAlive();
}


Expand Down Expand Up @@ -128,6 +116,12 @@ public DefaultWebSocketGraphQlClientBuilder codecConfigurer(Consumer<CodecConfig
return this;
}

@Override
public WebSocketGraphQlClient.Builder<DefaultWebSocketGraphQlClientBuilder> keepAlive(Duration keepalive) {
this.keepAlive = keepalive;
return this;
}

@Override
public WebSocketGraphQlClient build() {

Expand All @@ -136,18 +130,12 @@ public WebSocketGraphQlClient build() {
CodecDelegate.findJsonDecoder(this.codecConfigurer));

WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor(), this.keepalive);
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor(), this.keepAlive);

GraphQlClient graphQlClient = super.buildGraphQlClient(transport);
return new DefaultWebSocketGraphQlClient(graphQlClient, transport, getBuilderInitializer());
}

@Override
public WebSocketGraphQlClient.Builder<DefaultWebSocketGraphQlClientBuilder> keepalive(long keepalive) {
this.keepalive = keepalive;
return this;
}

private WebSocketGraphQlClientInterceptor getInterceptor() {

List<WebSocketGraphQlClientInterceptor> interceptors = getInterceptors().stream()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.graphql.client;

import java.net.URI;
import java.time.Duration;

import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -64,16 +65,6 @@ static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient) {
return builder(url, webSocketClient).build();
}

/**
* Create a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient, long keepalive) {
return builder(url, webSocketClient).keepalive(keepalive).build();
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
Expand All @@ -83,16 +74,6 @@ static Builder<?> builder(String url, WebSocketClient webSocketClient) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static Builder<?> builder(String url, WebSocketClient webSocketClient, long keepalive) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
Expand All @@ -102,31 +83,27 @@ static Builder<?> builder(URI url, WebSocketClient webSocketClient) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static Builder<?> builder(URI url, WebSocketClient webSocketClient, long keepalive) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
}


/**
* Builder for a GraphQL over WebSocket client.
* @param <B> the builder type
*/
interface Builder<B extends Builder<B>> extends WebGraphQlClient.Builder<B> {

/**
* Configure how frequently to send ping messages.
* <p>By default, this is not set, and ping messages are not sent.
* @param keepAlive the value to use
* @since 1.3
*/
Builder<B> keepAlive(Duration keepAlive);

/**
* Build the {@code WebSocketGraphQlClient}.
*/
@Override
WebSocketGraphQlClient build();

Builder<B> keepalive(long keepalive);

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,12 +68,13 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {

private final Mono<GraphQlSession> graphQlSessionMono;

private final long keepalive;
@Nullable
private final Duration keepAlive;


WebSocketGraphQlTransport(
URI url, @Nullable HttpHeaders headers, WebSocketClient client, CodecConfigurer codecConfigurer,
WebSocketGraphQlClientInterceptor interceptor, long keepalive) {
WebSocketGraphQlClientInterceptor interceptor, @Nullable Duration keepAlive) {

Assert.notNull(url, "URI is required");
Assert.notNull(client, "WebSocketClient is required");
Expand All @@ -83,9 +84,9 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
this.url = url;
this.headers.putAll((headers != null) ? headers : HttpHeaders.EMPTY);
this.webSocketClient = client;
this.keepalive = keepalive;
this.keepAlive = keepAlive;

this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor, keepalive);
this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor, keepAlive);

this.graphQlSessionMono = initGraphQlSession(this.url, this.headers, client, this.graphQlSessionHandler)
.cacheInvalidateWhen(GraphQlSession::notifyWhenClosed);
Expand Down Expand Up @@ -166,8 +167,9 @@ public Flux<GraphQlResponse> executeSubscription(GraphQlRequest request) {
return this.graphQlSessionMono.flatMapMany((session) -> session.executeSubscription(request));
}

public long getKeepAlive() {
return keepalive;
@Nullable
Duration getKeepAlive() {
return this.keepAlive;
}


Expand All @@ -191,15 +193,18 @@ private static class GraphQlSessionHandler implements WebSocketHandler {

private final AtomicBoolean stopped = new AtomicBoolean();

private final long keepalive;
@Nullable
private final Duration keepAlive;


GraphQlSessionHandler(
CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor,
@Nullable Duration keepAlive) {

GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor,
long keepalive) {
this.codecDelegate = new CodecDelegate(codecConfigurer);
this.interceptor = interceptor;
this.graphQlSessionSink = Sinks.unsafe().one();
this.keepalive = keepalive;
this.keepAlive = keepAlive;
}


Expand Down Expand Up @@ -257,7 +262,7 @@ public Mono<Void> handle(WebSocketSession session) {
session.send(connectionInitMono.concatWith(graphQlSession.getRequestFlux())
.map((message) -> this.codecDelegate.encode(session, message)));

Flux<Void> receiveCompletion = session.receive()
Mono<Void> receiveCompletion = session.receive()
.flatMap((webSocketMessage) -> {
if (sessionNotInitialized()) {
try {
Expand Down Expand Up @@ -303,20 +308,22 @@ public Mono<Void> handle(WebSocketSession session) {
}
}
return Mono.empty();
});

if (keepalive > 0) {
Duration keepAliveDuration = Duration.ofSeconds(keepalive);
receiveCompletion = receiveCompletion
.mergeWith(Flux.interval(keepAliveDuration, keepAliveDuration)
.flatMap(i -> {
graphQlSession.sendPing(null);
return Mono.empty();
})
);
})
.mergeWith((this.keepAlive != null) ?
Flux.interval(this.keepAlive, this.keepAlive)
.filter((aLong) -> graphQlSession.checkSentOrReceivedMessagesAndClear())
.doOnNext((aLong) -> graphQlSession.sendPing())
.then() :
Flux.empty())
.then();

if (this.keepAlive != null) {
Flux.interval(this.keepAlive, this.keepAlive)
.filter((aLong) -> graphQlSession.checkSentOrReceivedMessagesAndClear())
.doOnNext((aLong) -> graphQlSession.sendPing())
.subscribe();
}


return Mono.zip(sendCompletion, receiveCompletion.then()).then();
}

Expand Down Expand Up @@ -413,6 +420,8 @@ private static class GraphQlSession {

private final Map<String, RequestState> requestStateMap = new ConcurrentHashMap<>();

private boolean hasReceivedMessages;


GraphQlSession(WebSocketSession webSocketSession) {
this.connection = DisposableConnection.from(webSocketSession);
Expand Down Expand Up @@ -483,11 +492,16 @@ void sendPong(@Nullable Map<String, Object> payload) {
this.requestSink.sendRequest(message);
}

public void sendPing(@Nullable Map<String, Object> payload) {
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.ping(payload);
void sendPing() {
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.ping(null);
this.requestSink.sendRequest(message);
}

boolean checkSentOrReceivedMessagesAndClear() {
boolean received = this.hasReceivedMessages;
this.hasReceivedMessages = false;
return (this.requestSink.checkSentMessagesAndClear() || received);
}

// Inbound messages

Expand All @@ -504,6 +518,8 @@ void handleNext(GraphQlWebSocketMessage message) {
return;
}

this.hasReceivedMessages = true;

if (requestState instanceof SingleResponseRequestState) {
this.requestStateMap.remove(id);
}
Expand Down Expand Up @@ -631,6 +647,8 @@ private static final class RequestSink {
@Nullable
private FluxSink<GraphQlWebSocketMessage> requestSink;

private boolean hasSentMessages;

private final Flux<GraphQlWebSocketMessage> requestFlux = Flux.create((sink) -> {
Assert.state(this.requestSink == null, "Expected single subscriber only for outbound messages");
this.requestSink = sink;
Expand All @@ -642,9 +660,16 @@ Flux<GraphQlWebSocketMessage> getRequestFlux() {

void sendRequest(GraphQlWebSocketMessage message) {
Assert.state(this.requestSink != null, "Unexpected request before Flux is subscribed to");
this.hasSentMessages = true;
this.requestSink.next(message);
}

boolean checkSentMessagesAndClear() {
boolean result = this.hasSentMessages;
this.hasSentMessages = false;
return result;
}

}


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit 74688ea

Please sign in to comment.