Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce retained references to ConnectionContext #260

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ public final class Capability {
TRANSACTIONS | SECURE_SALT | MULTI_STATEMENTS | MULTI_RESULTS | PS_MULTI_RESULTS |
PLUGIN_AUTH | CONNECT_ATTRS | VAR_INT_SIZED_AUTH | SESSION_TRACK | DEPRECATE_EOF | ZSTD_COMPRESS;

/**
* The default capabilities for a MySQL connection. It contains all client supported capabilities.
*/
public static final Capability DEFAULT = new Capability(ALL_SUPPORTED);

private final long bitmap;

/**
Expand Down Expand Up @@ -373,7 +378,8 @@ private Capability(long bitmap) {
* @return the {@link Capability} without unknown flags.
*/
public static Capability of(long capabilities) {
return new Capability(capabilities & ALL_SUPPORTED);
long c = capabilities & ALL_SUPPORTED;
return c == ALL_SUPPORTED ? DEFAULT : new Capability(c);
}

static final class Builder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,13 @@
/**
* The MySQL connection context considers the behavior of server or client.
* <p>
* WARNING: Do NOT change any data outside of this project, try to configure {@code ConnectionFactoryOptions}
* or {@code MySqlConnectionConfiguration} to control connection context and client behavior.
* WARNING: Do NOT change any data outside of this project, try to configure {@code ConnectionFactoryOptions} or
* {@code MySqlConnectionConfiguration} to control connection context and client behavior.
*/
public final class ConnectionContext implements CodecContext {

private static final ServerVersion NONE_VERSION = ServerVersion.create(0, 0, 0);

private volatile int connectionId = -1;

private volatile ServerVersion serverVersion = NONE_VERSION;

private final ZeroDateOption zeroDateOption;

@Nullable
Expand All @@ -50,20 +46,25 @@ public final class ConnectionContext implements CodecContext {

private final boolean preserveInstants;

private int connectionId = -1;

private ServerVersion serverVersion = NONE_VERSION;

private Capability capability = Capability.DEFAULT;

@Nullable
private ZoneId timeZone;

private boolean lockWaitTimeoutSupported = false;

/**
* Assume that the auto commit is always turned on, it will be set after handshake V10 request message, or
* OK message which means handshake V9 completed.
* Assume that the auto commit is always turned on, it will be set after handshake V10 request message, or OK
* message which means handshake V9 completed.
* <p>
* It would be updated multiple times, so {@code volatile} is required.
*/
private volatile short serverStatuses = ServerStatuses.AUTO_COMMIT;

@Nullable
private volatile Capability capability = null;

ConnectionContext(
ZeroDateOption zeroDateOption,
@Nullable Path localInfilePath,
Expand All @@ -78,33 +79,37 @@ public final class ConnectionContext implements CodecContext {
this.timeZone = timeZone;
}

/**
* Get the connection identifier that is specified by server.
*
* @return the connection identifier.
*/
public int getConnectionId() {
return connectionId;
}

/**
* Initializes this context.
*
* @param connectionId the connection identifier that is specified by server.
* @param version the server version.
* @param capability the connection capabilities.
*/
public void init(int connectionId, ServerVersion version, Capability capability) {
void init(int connectionId, ServerVersion version, Capability capability) {
this.connectionId = connectionId;
this.serverVersion = version;
this.capability = capability;
}

/**
* Get the connection identifier that is specified by server.
*
* @return the connection identifier.
*/
public int getConnectionId() {
return connectionId;
}

@Override
public ServerVersion getServerVersion() {
return serverVersion;
}

public Capability getCapability() {
return capability;
}

@Override
public CharCollation getClientCollation() {
return CharCollation.clientCharCollation();
Expand All @@ -123,7 +128,7 @@ public ZoneId getTimeZone() {
return timeZone;
}

public boolean isTimeZoneInitialized() {
boolean isTimeZoneInitialized() {
return timeZone != null;
}

Expand All @@ -133,9 +138,9 @@ public boolean isMariaDb() {
return (capability != null && capability.isMariaDb()) || serverVersion.isMariaDb();
}

void setTimeZone(ZoneId timeZone) {
void initTimeZone(ZoneId timeZone) {
if (isTimeZoneInitialized()) {
throw new IllegalStateException("Server timezone have been initialized");
throw new IllegalStateException("Connection timezone have been initialized");
}
this.timeZone = timeZone;
}
Expand Down Expand Up @@ -176,7 +181,7 @@ public boolean isLockWaitTimeoutSupported() {
/**
* Enables lock wait timeout supported when loading session variables.
*/
public void enableLockWaitTimeoutSupported() {
void enableLockWaitTimeoutSupported() {
this.lockWaitTimeoutSupported = true;
}

Expand All @@ -197,13 +202,4 @@ public short getServerStatuses() {
public void setServerStatuses(short serverStatuses) {
this.serverStatuses = serverStatuses;
}

/**
* Get the connection capability. Should use it after this context initialized.
*
* @return the connection capability.
*/
public Capability getCapability() {
return capability;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,11 @@ final class MySqlBatchingBatch implements MySqlBatch {

private final Codecs codecs;

private final ConnectionContext context;

private final StringJoiner queries = new StringJoiner(";");

MySqlBatchingBatch(Client client, Codecs codecs, ConnectionContext context) {
MySqlBatchingBatch(Client client, Codecs codecs) {
this.client = requireNonNull(client, "client must not be null");
this.codecs = requireNonNull(codecs, "codecs must not be null");
this.context = requireNonNull(context, "context must not be null");
}

@Override
Expand All @@ -65,7 +62,7 @@ public MySqlBatch add(String sql) {
@Override
public Flux<MySqlResult> execute() {
return QueryFlow.execute(client, getSql())
.map(messages -> MySqlSegmentResult.toResult(false, codecs, context, null, messages));
.map(messages -> MySqlSegmentResult.toResult(false, client, codecs, null, messages));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private static Mono<MySqlConnection> getMySqlConnection(
final String user,
final SslMode sslMode,
final Set<CompressionAlgorithm> compressionAlgorithms,
final int zstdCompressionLevel,
final int zstdLevel,
final ConnectionContext context,
final Extensions extensions,
final List<String> sessionVariables,
Expand All @@ -163,8 +163,7 @@ private static Mono<MySqlConnection> getMySqlConnection(
.flatMap(client -> {
// Lazy init database after handshake/login
String db = createDbIfNotExist ? "" : database;
return QueryFlow.login(client, sslMode, db, user, password, compressionAlgorithms,
zstdCompressionLevel, context);
return QueryFlow.login(client, sslMode, db, user, password, compressionAlgorithms, zstdLevel);
})
.flatMap(client -> {
ByteBufAllocator allocator = client.getByteBufAllocator();
Expand All @@ -175,7 +174,7 @@ private static Mono<MySqlConnection> getMySqlConnection(
extensions.forEach(CodecRegistrar.class, registrar ->
registrar.register(allocator, builder));

return MySqlSimpleConnection.init(client, builder.build(), context, db, queryCache.get(),
return MySqlSimpleConnection.init(client, builder.build(), db, queryCache.get(),
prepareCache, sessionVariables, prepare);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.asyncer.r2dbc.mysql.api.MySqlRow;
import io.asyncer.r2dbc.mysql.api.MySqlRowMetadata;
import io.asyncer.r2dbc.mysql.codec.CodecContext;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.message.FieldValue;
import io.r2dbc.spi.Row;
Expand All @@ -42,10 +43,13 @@ final class MySqlDataRow implements MySqlRow {
*/
private final boolean binary;

private final ConnectionContext context;
/**
* It can be retained because it is provided by the executed connection instead of the current connection.
*/
private final CodecContext context;

MySqlDataRow(FieldValue[] fields, MySqlRowDescriptor rowMetadata, Codecs codecs, boolean binary,
ConnectionContext context) {
CodecContext context) {
this.fields = requireNonNull(fields, "fields must not be null");
this.rowMetadata = requireNonNull(rowMetadata, "rowMetadata must not be null");
this.codecs = requireNonNull(codecs, "codecs must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.asyncer.r2dbc.mysql.api.MySqlResult;
import io.asyncer.r2dbc.mysql.api.MySqlRow;
import io.asyncer.r2dbc.mysql.client.Client;
import io.asyncer.r2dbc.mysql.codec.Codecs;
import io.asyncer.r2dbc.mysql.internal.util.NettyBufferUtils;
import io.asyncer.r2dbc.mysql.internal.util.OperatorUtils;
Expand Down Expand Up @@ -53,8 +54,8 @@
/**
* An implementation of {@link MySqlResult} representing the results of a query against the MySQL database.
* <p>
* A {@link Segment} provided by this implementation may be both {@link UpdateCount} and {@link RowSegment},
* see also {@link MySqlOkSegment}.
* A {@link Segment} provided by this implementation may be both {@link UpdateCount} and {@link RowSegment}, see also
* {@link MySqlOkSegment}.
*/
final class MySqlSegmentResult implements MySqlResult {

Expand Down Expand Up @@ -156,15 +157,15 @@ public <T> Flux<T> flatMap(Function<Result.Segment, ? extends Publisher<? extend
});
}

static MySqlResult toResult(boolean binary, Codecs codecs, ConnectionContext context,
@Nullable String syntheticKeyName, Flux<ServerMessage> messages) {
static MySqlResult toResult(boolean binary, Client client, Codecs codecs,
@Nullable String syntheticKeyName, Flux<ServerMessage> messages) {
requireNonNull(client, "client must not be null");
requireNonNull(codecs, "codecs must not be null");
requireNonNull(context, "context must not be null");
requireNonNull(messages, "messages must not be null");

return new MySqlSegmentResult(OperatorUtils.discardOnCancel(messages)
.doOnDiscard(ReferenceCounted.class, ReferenceCounted::release)
.handle(new MySqlSegments(binary, codecs, context, syntheticKeyName)));
.handle(new MySqlSegments(binary, client, codecs, syntheticKeyName)));
}

private static final class MySqlMessage implements Message {
Expand Down Expand Up @@ -269,9 +270,9 @@ private static final class MySqlSegments implements BiConsumer<ServerMessage, Sy

private final boolean binary;

private final Codecs codecs;
private final Client client;

private final ConnectionContext context;
private final Codecs codecs;

@Nullable
private final String syntheticKeyName;
Expand All @@ -280,11 +281,10 @@ private static final class MySqlSegments implements BiConsumer<ServerMessage, Sy

private MySqlRowDescriptor rowMetadata;

private MySqlSegments(boolean binary, Codecs codecs, ConnectionContext context,
@Nullable String syntheticKeyName) {
private MySqlSegments(boolean binary, Client client, Codecs codecs, @Nullable String syntheticKeyName) {
this.binary = binary;
this.client = client;
this.codecs = codecs;
this.context = context;
this.syntheticKeyName = syntheticKeyName;
}

Expand All @@ -310,7 +310,7 @@ public void accept(ServerMessage message, SynchronousSink<Segment> sink) {
ReferenceCountUtil.safeRelease(message);
}

sink.next(new MySqlRowSegment(fields, metadata, codecs, binary, context));
sink.next(new MySqlRowSegment(fields, metadata, codecs, binary, client.getContext()));
} else if (message instanceof SyntheticMetadataMessage) {
DefinitionMetadataMessage[] metadataMessages = ((SyntheticMetadataMessage) message).unwrap();

Expand All @@ -322,7 +322,7 @@ public void accept(ServerMessage message, SynchronousSink<Segment> sink) {
} else if (message instanceof OkMessage) {
OkMessage msg = (OkMessage) message;

if (MySqlStatementSupport.supportReturning(context) && msg.isEndOfRows()) {
if (MySqlStatementSupport.supportReturning(client.getContext()) && msg.isEndOfRows()) {
sink.next(new MySqlUpdateCount(rowCount.getAndSet(0)));
} else {
long rows = msg.getAffectedRows();
Expand Down
Loading
Loading