Skip to content

Commit

Permalink
Implement MySqlBatch.Prepare. Fixes #656
Browse files Browse the repository at this point in the history
  • Loading branch information
bgrainger committed Aug 8, 2019
1 parent 3befefe commit 0f95533
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 74 deletions.
3 changes: 0 additions & 3 deletions src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDict
do
{
var command = commandListPosition.Commands[commandListPosition.CommandIndex];
if (command.TryGetPreparedStatements() is object)
throw new InvalidOperationException("Can't send prepared statements as part of a concatenated batch.");

if (Log.IsDebugEnabled())
Log.Debug("Session{0} Preparing command payload; CommandText: {1}", command.Connection.Session.Id, command.CommandText);

Expand Down
2 changes: 0 additions & 2 deletions src/MySqlConnector/Core/IMySqlCommand.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using System;
using System.Data;
using System.Threading;
using MySql.Data.MySqlClient;
using MySqlConnector.Utilities;

namespace MySqlConnector.Core
{
Expand Down
58 changes: 56 additions & 2 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,65 @@ public void AbortCancel(ICancellableCommand command)

public bool IsCancelingQuery => m_state == State.CancelingQuery;

public void AddPreparedStatement(string commandText, PreparedStatements preparedStatements)
public async Task PrepareAsync(IMySqlCommand command, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
var statementPreparer = new StatementPreparer(command.CommandText, command.RawParameters, command.CreateStatementPreparerOptions());
var parsedStatements = statementPreparer.SplitStatements();

var columnsAndParameters = new ResizableArray<byte>();
var columnsAndParametersSize = 0;

var preparedStatements = new List<PreparedStatement>(parsedStatements.Statements.Count);
foreach (var statement in parsedStatements.Statements)
{
await SendAsync(new PayloadData(statement.StatementBytes), ioBehavior, cancellationToken).ConfigureAwait(false);
var payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
var response = StatementPrepareResponsePayload.Create(payload.AsSpan());

ColumnDefinitionPayload[] parameters = null;
if (response.ParameterCount > 0)
{
parameters = new ColumnDefinitionPayload[response.ParameterCount];
for (var i = 0; i < response.ParameterCount; i++)
{
payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
Utility.Resize(ref columnsAndParameters, columnsAndParametersSize + payload.ArraySegment.Count);
Buffer.BlockCopy(payload.ArraySegment.Array, payload.ArraySegment.Offset, columnsAndParameters.Array, columnsAndParametersSize, payload.ArraySegment.Count);
parameters[i] = ColumnDefinitionPayload.Create(new ResizableArraySegment<byte>(columnsAndParameters, columnsAndParametersSize, payload.ArraySegment.Count));
columnsAndParametersSize += payload.ArraySegment.Count;
}
if (!SupportsDeprecateEof)
{
payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
EofPayload.Create(payload.AsSpan());
}
}

ColumnDefinitionPayload[] columns = null;
if (response.ColumnCount > 0)
{
columns = new ColumnDefinitionPayload[response.ColumnCount];
for (var i = 0; i < response.ColumnCount; i++)
{
payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
Utility.Resize(ref columnsAndParameters, columnsAndParametersSize + payload.ArraySegment.Count);
Buffer.BlockCopy(payload.ArraySegment.Array, payload.ArraySegment.Offset, columnsAndParameters.Array, columnsAndParametersSize, payload.ArraySegment.Count);
columns[i] = ColumnDefinitionPayload.Create(new ResizableArraySegment<byte>(columnsAndParameters, columnsAndParametersSize, payload.ArraySegment.Count));
columnsAndParametersSize += payload.ArraySegment.Count;
}
if (!SupportsDeprecateEof)
{
payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
EofPayload.Create(payload.AsSpan());
}
}

preparedStatements.Add(new PreparedStatement(response.StatementId, statement, columns, parameters));
}

if (m_preparedStatements is null)
m_preparedStatements = new Dictionary<string, PreparedStatements>();
m_preparedStatements.Add(commandText, preparedStatements);
m_preparedStatements.Add(command.CommandText, new PreparedStatements(preparedStatements, parsedStatements));
}

public PreparedStatements TryGetPreparedStatement(string commandText)
Expand Down
66 changes: 63 additions & 3 deletions src/MySqlConnector/MySql.Data.MySqlClient/MySqlBatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private Task<DbDataReader> ExecuteReaderAsync(IOBehavior ioBehavior, Cancellatio
batchCommand.Batch = this;

var payloadCreator = Connection.Session.SupportsComMulti ? BatchedCommandPayloadCreator.Instance :
// TODO: IsPrepared ? SingleCommandPayloadCreator.Instance :
IsPrepared ? SingleCommandPayloadCreator.Instance :
ConcatenatedCommandPayloadCreator.Instance;
return CommandExecutor.ExecuteReaderAsync(BatchCommands, payloadCreator, CommandBehavior.Default, ioBehavior, cancellationToken);
}
Expand All @@ -138,9 +138,19 @@ private Task<DbDataReader> ExecuteReaderAsync(IOBehavior ioBehavior, Cancellatio

public override int Timeout { get; set; }

public override void Prepare() => throw new NotImplementedException();
public override void Prepare()
{
if (!NeedsPrepare(out var exception))
{
if (exception is object)
throw exception;
return;
}

DoPrepareAsync(IOBehavior.Synchronous, default).GetAwaiter().GetResult();
}

public override Task PrepareAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException();
public override Task PrepareAsync(CancellationToken cancellationToken = default) => PrepareAsync(AsyncIOBehavior, cancellationToken);

public override void Cancel() => Connection?.Cancel(this);

Expand Down Expand Up @@ -235,6 +245,56 @@ private bool IsValid(out Exception exception)
return exception is null;
}

private bool NeedsPrepare(out Exception exception)
{
exception = null;
if (Connection is null)
exception = new InvalidOperationException("Connection property must be non-null.");
else if (Connection.State != ConnectionState.Open)
exception = new InvalidOperationException("Connection must be Open; current state is {0}".FormatInvariant(Connection.State));
else if (BatchCommands.Count == 0)
exception = new InvalidOperationException("BatchCommands must contain a command");
else if (Connection?.HasActiveReader ?? false)
exception = new InvalidOperationException("Cannot call Prepare when there is an open DataReader for this command; it must be closed first.");

return exception is null && !Connection.IgnorePrepare;
}

private Task PrepareAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
{
if (!NeedsPrepare(out var exception))
return exception is null ? Utility.CompletedTask : Utility.TaskFromException(exception);

return DoPrepareAsync(ioBehavior, cancellationToken);
}

private async Task DoPrepareAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
{
foreach (IMySqlCommand batchCommand in BatchCommands)
{
if (batchCommand.CommandType != CommandType.Text)
throw new NotSupportedException("Only CommandType.Text is currently supported by MySqlBatch.Prepare");
((MySqlBatchCommand) batchCommand).Batch = this;

// don't prepare the same SQL twice
if (Connection.Session.TryGetPreparedStatement(batchCommand.CommandText) is null)
await Connection.Session.PrepareAsync(batchCommand, ioBehavior, cancellationToken).ConfigureAwait(false);
}
}

private bool IsPrepared
{
get
{
foreach (var command in BatchCommands)
{
if (Connection.Session.TryGetPreparedStatement(command.CommandText) is null)
return false;
}
return true;
}
}

private IOBehavior AsyncIOBehavior => Connection?.AsyncIOBehavior ?? IOBehavior.Asynchronous;

readonly int m_commandId;
Expand Down
66 changes: 2 additions & 64 deletions src/MySqlConnector/MySql.Data.MySqlClient/MySqlCommand.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using MySqlConnector.Core;
using MySqlConnector.Protocol;
using MySqlConnector.Protocol.Payloads;
using MySqlConnector.Protocol.Serialization;
using MySqlConnector.Utilities;

Expand Down Expand Up @@ -91,7 +88,7 @@ public override void Prepare()
return;
}

DoPrepareAsync(IOBehavior.Synchronous, default).GetAwaiter().GetResult();
Connection.Session.PrepareAsync(this, IOBehavior.Synchronous, default).GetAwaiter().GetResult();
}

#if !NETSTANDARD2_1 && !NETCOREAPP3_0
Expand All @@ -115,7 +112,7 @@ private Task PrepareAsync(IOBehavior ioBehavior, CancellationToken cancellationT
if (!NeedsPrepare(out var exception))
return exception is null ? Utility.CompletedTask : Utility.TaskFromException(exception);

return DoPrepareAsync(ioBehavior, cancellationToken);
return Connection.Session.PrepareAsync(this, ioBehavior, cancellationToken);
}

private bool NeedsPrepare(out Exception exception)
Expand Down Expand Up @@ -143,65 +140,6 @@ private bool NeedsPrepare(out Exception exception)
return Connection.Session.TryGetPreparedStatement(CommandText) is null;
}

private async Task DoPrepareAsync(IOBehavior ioBehavior, CancellationToken cancellationToken)
{
var statementPreparer = new StatementPreparer(CommandText, m_parameterCollection, ((IMySqlCommand) this).CreateStatementPreparerOptions());
var parsedStatements = statementPreparer.SplitStatements();

var columnsAndParameters = new ResizableArray<byte>();
var columnsAndParametersSize = 0;

var preparedStatements = new List<PreparedStatement>(parsedStatements.Statements.Count);
foreach (var statement in parsedStatements.Statements)
{
await Connection.Session.SendAsync(new PayloadData(statement.StatementBytes), ioBehavior, cancellationToken).ConfigureAwait(false);
var payload = await Connection.Session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
var response = StatementPrepareResponsePayload.Create(payload.AsSpan());

ColumnDefinitionPayload[] parameters = null;
if (response.ParameterCount > 0)
{
parameters = new ColumnDefinitionPayload[response.ParameterCount];
for (var i = 0; i < response.ParameterCount; i++)
{
payload = await Connection.Session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
Utility.Resize(ref columnsAndParameters, columnsAndParametersSize + payload.ArraySegment.Count);
Buffer.BlockCopy(payload.ArraySegment.Array, payload.ArraySegment.Offset, columnsAndParameters.Array, columnsAndParametersSize, payload.ArraySegment.Count);
parameters[i] = ColumnDefinitionPayload.Create(new ResizableArraySegment<byte>(columnsAndParameters, columnsAndParametersSize, payload.ArraySegment.Count));
columnsAndParametersSize += payload.ArraySegment.Count;
}
if (!Connection.Session.SupportsDeprecateEof)
{
payload = await Connection.Session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
EofPayload.Create(payload.AsSpan());
}
}

ColumnDefinitionPayload[] columns = null;
if (response.ColumnCount > 0)
{
columns = new ColumnDefinitionPayload[response.ColumnCount];
for (var i = 0; i < response.ColumnCount; i++)
{
payload = await Connection.Session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
Utility.Resize(ref columnsAndParameters, columnsAndParametersSize + payload.ArraySegment.Count);
Buffer.BlockCopy(payload.ArraySegment.Array, payload.ArraySegment.Offset, columnsAndParameters.Array, columnsAndParametersSize, payload.ArraySegment.Count);
columns[i] = ColumnDefinitionPayload.Create(new ResizableArraySegment<byte>(columnsAndParameters, columnsAndParametersSize, payload.ArraySegment.Count));
columnsAndParametersSize += payload.ArraySegment.Count;
}
if (!Connection.Session.SupportsDeprecateEof)
{
payload = await Connection.Session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false);
EofPayload.Create(payload.AsSpan());
}
}

preparedStatements.Add(new PreparedStatement(response.StatementId, statement, columns, parameters));
}

Connection.Session.AddPreparedStatement(CommandText, new PreparedStatements(preparedStatements, parsedStatements));
}

public override string CommandText
{
get => m_commandText;
Expand Down
43 changes: 43 additions & 0 deletions tests/SideBySide/BatchTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,49 @@ public void ExecuteInvalidSqlBatch()
}
}
}
[Fact]
public void PrepareNeedsConnection()
{
using (var batch = new MySqlBatch
{
BatchCommands =
{
new MySqlBatchCommand("SELECT 1;"),
},
})
{
Assert.Throws<InvalidOperationException>(() => batch.Prepare());
}
}

[Fact]
public void PrepareNeedsOpenConnection()
{
using (var connection = new MySqlConnection(AppConfig.ConnectionString))
using (var batch = new MySqlBatch(connection)
{
BatchCommands =
{
new MySqlBatchCommand("SELECT 1;"),
},
})
{
Assert.Throws<InvalidOperationException>(() => batch.Prepare());
}
}

[Fact]
public void PrepareNeedsCommands()
{
using (var connection = new MySqlConnection(AppConfig.ConnectionString))
{
connection.Open();
using (var batch = new MySqlBatch(connection))
{
Assert.Throws<InvalidOperationException>(() => batch.Prepare());
}
}
}

private static string GetIgnoreCommandTransactionConnectionString() =>
new MySqlConnectionStringBuilder(AppConfig.ConnectionString)
Expand Down

0 comments on commit 0f95533

Please sign in to comment.