Skip to content
This repository has been archived by the owner on Dec 13, 2023. It is now read-only.

Commit

Permalink
Centralized management of envelopes
Browse files Browse the repository at this point in the history
- Move envelope splition from LargeMessageSlicer to FluxCumulateEnvelope
- Move envelope header encoding from WriteSubscriber to FluxCumulateEnvelope
- Move envelope identities generation from SequenceIdProvider to FluxCumulateEnvelope
- Group server messages as LoginClientMessage, which means it is used only by the login phase
- Use QueryFlow to control envelope identities of LoginClientMessage
- Remove SequenceIdProvider
  • Loading branch information
mirromutth committed Nov 11, 2020
1 parent 561b6b6 commit a610517
Show file tree
Hide file tree
Showing 43 changed files with 1,042 additions and 798 deletions.
2 changes: 1 addition & 1 deletion src/main/java/dev/miku/r2dbc/mysql/MySqlConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ public Mono<Boolean> validate(ValidationDepth depth) {
return Mono.just(false);
}

return client.exchange(PingMessage.getInstance(), PING_HANDLER)
return client.exchange(PingMessage.INSTANCE, PING_HANDLER)
.last()
.onErrorResume(e -> {
// `last` maybe emit a NoSuchElementException, exchange maybe emit exception by Netty.
Expand Down
38 changes: 25 additions & 13 deletions src/main/java/dev/miku/r2dbc/mysql/QueryFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import dev.miku.r2dbc.mysql.message.client.AuthResponse;
import dev.miku.r2dbc.mysql.message.client.ClientMessage;
import dev.miku.r2dbc.mysql.message.client.HandshakeResponse;
import dev.miku.r2dbc.mysql.message.client.LoginClientMessage;
import dev.miku.r2dbc.mysql.message.client.PrepareQueryMessage;
import dev.miku.r2dbc.mysql.message.client.PreparedCloseMessage;
import dev.miku.r2dbc.mysql.message.client.PreparedFetchMessage;
Expand Down Expand Up @@ -654,7 +655,7 @@ final class LoginExchangeable extends FluxExchangeable<Void> {

private static final int HANDSHAKE_VERSION = 10;

private final DirectProcessor<ClientMessage> requests = DirectProcessor.create();
private final DirectProcessor<LoginClientMessage> requests = DirectProcessor.create();

private final Client client;

Expand All @@ -677,6 +678,8 @@ final class LoginExchangeable extends FluxExchangeable<Void> {

private boolean sslCompleted;

private int lastEnvelopeId;

LoginExchangeable(
Client client, SslMode sslMode, String database, String user,
@Nullable CharSequence password, ConnectionContext context
Expand All @@ -702,15 +705,19 @@ public void accept(ServerMessage message, SynchronousSink<Void> sink) {
return;
}

// Ensures it will be initialized only once.
if (handshake) {
handshake = false;
if (message instanceof HandshakeRequest) {
int capabilities = initHandshake((HandshakeRequest) message);
HandshakeRequest request = (HandshakeRequest) message;
int capabilities = initHandshake(request);

lastEnvelopeId = request.getEnvelopeId() + 1;

if ((capabilities & Capabilities.SSL) == 0) {
requests.onNext(createHandshakeResponse(capabilities));
requests.onNext(createHandshakeResponse(lastEnvelopeId, capabilities));
} else {
requests.onNext(SslRequest.from(capabilities, context.getClientCollation().getId()));
requests.onNext(SslRequest.from(lastEnvelopeId, capabilities, context.getClientCollation().getId()));
}
} else {
sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" +
Expand All @@ -725,20 +732,25 @@ public void accept(ServerMessage message, SynchronousSink<Void> sink) {
sink.complete();
} else if (message instanceof SyntheticSslResponseMessage) {
sslCompleted = true;
requests.onNext(createHandshakeResponse(context.getCapabilities()));
requests.onNext(createHandshakeResponse(++lastEnvelopeId, context.getCapabilities()));
} else if (message instanceof AuthMoreDataMessage) {
if (((AuthMoreDataMessage) message).isFailed()) {
AuthMoreDataMessage msg = (AuthMoreDataMessage) message;
lastEnvelopeId = msg.getEnvelopeId() + 1;

if (msg.isFailed()) {
if (logger.isDebugEnabled()) {
logger.debug("Connection (id {}) fast authentication failed, auto-try to use full authentication", context.getConnectionId());
}
requests.onNext(createAuthResponse("full authentication"));

requests.onNext(createAuthResponse(lastEnvelopeId, "full authentication"));
}
// Otherwise success, wait until OK message or Error message.
} else if (message instanceof ChangeAuthMessage) {
ChangeAuthMessage msg = (ChangeAuthMessage) message;
lastEnvelopeId = msg.getEnvelopeId() + 1;
authProvider = MySqlAuthProvider.build(msg.getAuthType());
salt = msg.getSalt();
requests.onNext(createAuthResponse("change authentication"));
requests.onNext(createAuthResponse(lastEnvelopeId,"change authentication"));
} else {
sink.error(new R2dbcPermissionDeniedException("Unexpected message type '" +
message.getClass().getSimpleName() + "' in login phase"));
Expand All @@ -750,14 +762,14 @@ public void dispose() {
this.requests.onComplete();
}

private AuthResponse createAuthResponse(String phase) {
private AuthResponse createAuthResponse(int envelopeId, String phase) {
MySqlAuthProvider authProvider = getAndNextProvider();

if (authProvider.isSslNecessary() && !sslCompleted) {
throw new R2dbcPermissionDeniedException(formatAuthFails(authProvider.getType(), phase), CLI_SPECIFIC);
}

return new AuthResponse(authProvider.authentication(password, salt, context.getClientCollation()));
return new AuthResponse(envelopeId, authProvider.authentication(password, salt, context.getClientCollation()));
}

private int clientCapabilities(int serverCapabilities) {
Expand Down Expand Up @@ -825,7 +837,7 @@ private MySqlAuthProvider getAndNextProvider() {
return authProvider;
}

private HandshakeResponse createHandshakeResponse(int capabilities) {
private HandshakeResponse createHandshakeResponse(int envelopeId, int capabilities) {
MySqlAuthProvider authProvider = getAndNextProvider();

if (authProvider.isSslNecessary() && !sslCompleted) {
Expand All @@ -841,8 +853,8 @@ private HandshakeResponse createHandshakeResponse(int capabilities) {
authType = MySqlAuthProvider.CACHING_SHA2_PASSWORD;
}

return HandshakeResponse.from(capabilities, context.getClientCollation().getId(), user, authorization,
authType, database, ATTRIBUTES);
return HandshakeResponse.from(envelopeId, capabilities, context.getClientCollation().getId(),
user, authorization, authType, database, ATTRIBUTES);
}

private static String formatAuthFails(String authType, String phase) {
Expand Down
58 changes: 25 additions & 33 deletions src/main/java/dev/miku/r2dbc/mysql/client/MessageDuplexCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

package dev.miku.r2dbc.mysql.client;

import dev.miku.r2dbc.mysql.ConnectionContext;
import dev.miku.r2dbc.mysql.constant.Capabilities;
import dev.miku.r2dbc.mysql.message.client.ClientMessage;
import dev.miku.r2dbc.mysql.message.client.LoginClientMessage;
import dev.miku.r2dbc.mysql.message.client.PrepareQueryMessage;
import dev.miku.r2dbc.mysql.message.client.PreparedFetchMessage;
import dev.miku.r2dbc.mysql.ConnectionContext;
import dev.miku.r2dbc.mysql.message.client.ClientMessage;
import dev.miku.r2dbc.mysql.message.client.SslRequest;
import dev.miku.r2dbc.mysql.message.header.SequenceIdProvider;
import dev.miku.r2dbc.mysql.message.server.ColumnCountMessage;
import dev.miku.r2dbc.mysql.message.server.CompleteMessage;
import dev.miku.r2dbc.mysql.message.server.DecodeContext;
Expand All @@ -32,14 +32,16 @@
import dev.miku.r2dbc.mysql.message.server.ServerMessageDecoder;
import dev.miku.r2dbc.mysql.message.server.ServerStatusMessage;
import dev.miku.r2dbc.mysql.message.server.SyntheticMetadataMessage;
import dev.miku.r2dbc.mysql.util.OperatorUtils;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.ReferenceCountUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.util.annotation.Nullable;
import reactor.core.publisher.Flux;

import java.util.concurrent.atomic.AtomicBoolean;

Expand All @@ -54,10 +56,7 @@ final class MessageDuplexCodec extends ChannelDuplexHandler {

private static final Logger logger = LoggerFactory.getLogger(MessageDuplexCodec.class);

private DecodeContext decodeContext = DecodeContext.connection();

@Nullable
private SequenceIdProvider.Linkable linkableIdProvider;
private DecodeContext decodeContext = DecodeContext.login();

private final ConnectionContext context;

Expand All @@ -73,23 +72,11 @@ final class MessageDuplexCodec extends ChannelDuplexHandler {
this.requestQueue = requireNonNull(requestQueue, "requestQueue must not be null");
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof Lifecycle) {
if (Lifecycle.COMMAND == evt) {
// Message sequence id always from 0 in command phase.
this.linkableIdProvider = null;
}
} else {
super.userEventTriggered(ctx, evt);
}
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof ByteBuf) {
DecodeContext context = this.decodeContext;
ServerMessage message = decoder.decode((ByteBuf) msg, this.context, context, this.linkableIdProvider);
ServerMessage message = this.decoder.decode((ByteBuf) msg, this.context, context);

if (message != null) {
handleDecoded(ctx, message);
Expand All @@ -107,8 +94,23 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
if (msg instanceof ClientMessage) {
((ClientMessage) msg).encode(ctx.alloc(), this.context)
.subscribe(WriteSubscriber.create(ctx, promise, this.linkableIdProvider));
ByteBufAllocator allocator = ctx.alloc();

Flux<ByteBuf> encoded;
int envelopeId;

if (msg instanceof LoginClientMessage) {
LoginClientMessage message = (LoginClientMessage) msg;

encoded = Flux.from(message.encode(allocator, this.context));
envelopeId = message.getEnvelopeId();
} else {
encoded = Flux.from(((ClientMessage) msg).encode(allocator, this.context));
envelopeId = 0;
}

OperatorUtils.cumulateEnvelope(encoded, allocator, envelopeId)
.subscribe(new WriteSubscriber(ctx, promise));

if (msg instanceof PrepareQueryMessage) {
setDecodeContext(DecodeContext.prepareQuery());
Expand Down Expand Up @@ -139,16 +141,6 @@ public void channelInactive(ChannelHandlerContext ctx) {
ctx.fireChannelInactive();
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) {
this.linkableIdProvider = SequenceIdProvider.atomic();
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
this.linkableIdProvider = null;
}

private void handleDecoded(ChannelHandlerContext ctx, ServerMessage msg) {
if (msg instanceof ServerStatusMessage) {
this.context.setServerStatuses(((ServerStatusMessage) msg).getServerStatuses());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public Mono<Void> close() {
return;
}

requestQueue.submit(RequestTask.wrap(sink, Mono.fromRunnable(() -> requestProcessor.onNext(ExitMessage.getInstance()))));
requestQueue.submit(RequestTask.wrap(sink, Mono.fromRunnable(() -> requestProcessor.onNext(ExitMessage.INSTANCE))));
}).flatMap(identity()).onErrorResume(e -> {
logger.error("Exit message sending failed, force closing", e);
return Mono.empty();
Expand Down
20 changes: 1 addition & 19 deletions src/main/java/dev/miku/r2dbc/mysql/client/WriteSubscriber.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@

package dev.miku.r2dbc.mysql.client;

import dev.miku.r2dbc.mysql.constant.Envelopes;
import dev.miku.r2dbc.mysql.message.header.SequenceIdProvider;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import org.reactivestreams.Subscription;
import reactor.core.CoreSubscriber;
import reactor.util.annotation.Nullable;

/**
* An implementation of {@link CoreSubscriber} for {@link ChannelHandlerContext} write
Expand All @@ -37,12 +34,9 @@ final class WriteSubscriber implements CoreSubscriber<ByteBuf> {

private final ChannelPromise promise;

private final SequenceIdProvider provider;

private WriteSubscriber(ChannelHandlerContext ctx, ChannelPromise promise, SequenceIdProvider provider) {
WriteSubscriber(ChannelHandlerContext ctx, ChannelPromise promise) {
this.ctx = ctx;
this.promise = promise;
this.provider = provider;
}

@Override
Expand All @@ -52,9 +46,6 @@ public void onSubscribe(Subscription s) {

@Override
public void onNext(ByteBuf buf) {
ctx.write(ctx.alloc().buffer(Envelopes.PART_HEADER_SIZE, Envelopes.PART_HEADER_SIZE)
.writeMediumLE(buf.readableBytes())
.writeByte(provider.next()));
ctx.write(buf);
}

Expand All @@ -71,13 +62,4 @@ public void onComplete() {
promise.setSuccess();
ctx.flush();
}

static WriteSubscriber create(ChannelHandlerContext ctx, ChannelPromise promise, @Nullable SequenceIdProvider provider) {
if (provider == null) {
// Used by this message ByteBuf stream only, can be unsafe.
provider = SequenceIdProvider.unsafe();
}

return new WriteSubscriber(ctx, promise, provider);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package dev.miku.r2dbc.mysql.message.client;

import dev.miku.r2dbc.mysql.ConnectionContext;
import io.netty.buffer.ByteBuf;

import java.util.Arrays;
Expand All @@ -26,40 +25,53 @@
/**
* A message that contains only an authentication, used by full authentication or change authentication response.
*/
public final class AuthResponse extends EnvelopeClientMessage {
public final class AuthResponse extends SizedClientMessage implements LoginClientMessage {

private final int envelopeId;

private final byte[] authentication;

public AuthResponse(byte[] authentication) {
public AuthResponse(int envelopeId, byte[] authentication) {
this.envelopeId = envelopeId;
this.authentication = requireNonNull(authentication, "authentication must not be null");
}

@Override
public int getEnvelopeId() {
return envelopeId;
}

@Override
protected int size() {
return authentication.length;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof AuthResponse)) {
if (o == null || getClass() != o.getClass()) {
return false;
}

AuthResponse that = (AuthResponse) o;

return Arrays.equals(authentication, that.authentication);
return envelopeId == that.envelopeId && Arrays.equals(authentication, that.authentication);
}

@Override
public int hashCode() {
return Arrays.hashCode(authentication);
return 31 * envelopeId + Arrays.hashCode(authentication);
}

@Override
public String toString() {
return "AuthResponse{authentication=REDACTED}";
return "AuthResponse{envelopeId=" + envelopeId + ", authentication=REDACTED}";
}

@Override
protected void writeTo(ByteBuf buf, ConnectionContext context) {
protected void writeTo(ByteBuf buf) {
buf.writeBytes(authentication);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@
import dev.miku.r2dbc.mysql.ConnectionContext;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import org.reactivestreams.Publisher;
import reactor.core.CorePublisher;
import reactor.core.publisher.Flux;

/**
* A message sent from a MySQL client to a MySQL server.
*/
public interface ClientMessage {

/**
* Encode a message into a {@link ByteBuf} data buffer without envelope header.
* Encode a message into {@link ByteBuf}s.
*
* @param allocator the {@link ByteBufAllocator} to use to get a {@link ByteBuf} data buffer to write into.
* @param allocator the {@link ByteBufAllocator} that use to get {@link ByteBuf} to write into.
* @param context current MySQL connection context
* @return a {@link Publisher} that produces {@link ByteBuf}s sliced by {@code Envelopes.MAX_ENVELOPE_SIZE}, which containing the encoded message.
* @return a {@link Flux} that's produces the encoded {@link ByteBuf}s.
* @throws IllegalArgumentException if {@code allocator} or {@code context} is {@code null}.
*/
Publisher<ByteBuf> encode(ByteBufAllocator allocator, ConnectionContext context);
CorePublisher<ByteBuf> encode(ByteBufAllocator allocator, ConnectionContext context);
}
Loading

0 comments on commit a610517

Please sign in to comment.