Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SqlSequentialStream multipacket reads stalling and add covering test #603

Merged
merged 1 commit into from
Jun 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4517,22 +4517,31 @@ private Task<int> GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, b
SetTimeout(_defaultTimeoutMilliseconds);

// Try to read without any continuations (all the data may already be in the stateObj's buffer)
if (!TryGetBytesInternalSequential(context.columnIndex, context.buffer, context.index, context.length, out bytesRead))
bool filledBuffer = context._reader.TryGetBytesInternalSequential(
context.columnIndex,
context.buffer,
context.index + context.totalBytesRead,
context.length - context.totalBytesRead,
out bytesRead
);
context.totalBytesRead += bytesRead;
Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required");

if (!filledBuffer)
{
// This will be the 'state' for the callback
int totalBytesRead = bytesRead;

if (!isContinuation)
{
// This is the first async operation which is happening - setup the _currentTask and timeout
Debug.Assert(context._source==null, "context._source should not be non-null when trying to change to async");
source = new TaskCompletionSource<int>();
Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null);
if (original != null)
{
source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending()));
return source.Task;
}

context._source = source;
// Check if cancellation due to close is requested (this needs to be done after setting _currentTask)
if (_cancelAsyncOnCloseToken.IsCancellationRequested)
{
Expand Down Expand Up @@ -4561,7 +4570,7 @@ private Task<int> GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, b
}
else
{
Debug.Assert(context._source != null, "context.source should not be null when continuing");
Debug.Assert(context._source != null, "context._source should not be null when continuing");
// setup for cleanup/completing
retryTask.ContinueWith(
continuationAction: AAsyncCallContext<int>.s_completeCallback,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,97 @@ public static void RunAllTestsForSingleServer_TCP()
RunAllTestsForSingleServer(DataTestUtility.TCPConnectionString);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
public static async Task AsyncMultiPacketStreamRead()
{
int packetSize = 514; // force small packet size so we can quickly check multi packet reads

SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString);
builder.PacketSize = 514;
string connectionString = builder.ToString();

byte[] inputData = null;
byte[] outputData = null;
string tableName = DataTestUtility.GetUniqueNameForSqlServer("data");

using (SqlConnection connection = new SqlConnection(connectionString))
{
await connection.OpenAsync();

try
{
inputData = CreateBinaryTable(connection, tableName, packetSize);

using (SqlCommand command = new SqlCommand($"SELECT foo FROM {tableName}", connection))
using (SqlDataReader reader = await command.ExecuteReaderAsync(System.Data.CommandBehavior.SequentialAccess))
{
await reader.ReadAsync();

using (Stream stream = reader.GetStream(0))
using (CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(60)))
using (MemoryStream memory = new MemoryStream(16 * 1024))
{
await stream.CopyToAsync(memory, 37, cancellationTokenSource.Token); // prime number sized buffer to cause many cross packet partial reads
outputData = memory.ToArray();
}
}
}
finally
{
DataTestUtility.DropTable(connection, tableName);
}
}

Assert.NotNull(outputData);
int sharedLength = Math.Min(inputData.Length, outputData.Length);
if (sharedLength < outputData.Length)
{
Assert.False(true, $"output is longer than input, input={inputData.Length} bytes, output={outputData.Length} bytes");
}
if (sharedLength < inputData.Length)
{
Assert.False(true, $"input is longer than output, input={inputData.Length} bytes, output={outputData.Length} bytes");
}
for (int index = 0; index < sharedLength; index++)
{
if (inputData[index] != outputData[index]) // avoid formatting the output string unless there is a difference
{
Assert.True(false, $"input and output differ at index {index}, input={inputData[index]}, output={outputData[index]}");
}
}

}

private static byte[] CreateBinaryTable(SqlConnection connection, string tableName, int packetSize)
{
byte[] pattern = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13 };
byte[] data = new byte[packetSize * 10];
int position = 0;
while (position < data.Length)
{
int copyCount = Math.Min(pattern.Length, data.Length - position);
Array.Copy(pattern, 0, data, position, copyCount);
position += copyCount;
}

using (var cmd = connection.CreateCommand())
{
cmd.CommandText = $@"
IF OBJECT_ID('dbo.{tableName}', 'U') IS NOT NULL
DROP TABLE {tableName};
CREATE TABLE {tableName} (id INT, foo VARBINARY(MAX))
";
cmd.ExecuteNonQuery();

cmd.CommandText = $"INSERT INTO {tableName} (id, foo) VALUES (@id, @foo)";
cmd.Parameters.AddWithValue("id", 1);
cmd.Parameters.AddWithValue("foo", data);
cmd.ExecuteNonQuery();
}

return data;
}

private static void RunAllTestsForSingleServer(string connectionString, bool usingNamePipes = false)
{
RowBuffer(connectionString);
Expand Down Expand Up @@ -1811,7 +1902,7 @@ private static void TestXEventsStreaming(string connectionString)
SqlDataReader reader = cmd.ExecuteReader(System.Data.CommandBehavior.SequentialAccess);
for (int i = 0; i < streamXeventCount && reader.Read(); i++)
{
Int32 colType = reader.GetInt32(0);
int colType = reader.GetInt32(0);
int cb = (int)reader.GetBytes(1, 0, null, 0, 0);

byte[] bytes = new byte[cb];
Expand Down