From 492fe7b44d3325effd789c7b597fad9507afd51b Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Fri, 15 Feb 2019 20:51:23 +0000 Subject: [PATCH 1/4] rework SNIPacket and usage --- .../System/Data/SqlClient/SNI/SNIHandle.cs | 2 + .../Data/SqlClient/SNI/SNIMarsConnection.cs | 13 +- .../Data/SqlClient/SNI/SNIMarsHandle.cs | 83 +++---- .../System/Data/SqlClient/SNI/SNINpHandle.cs | 2 +- .../SqlClient/SNI/SNIPacket.NetCoreApp.cs | 8 +- .../SqlClient/SNI/SNIPacket.NetStandard.cs | 8 +- .../System/Data/SqlClient/SNI/SNIPacket.cs | 223 +++++++----------- .../src/System/Data/SqlClient/SNI/SNIProxy.cs | 15 +- .../System/Data/SqlClient/SNI/SNITcpHandle.cs | 2 +- .../Data/SqlClient/SNI/SslOverTdsStream.cs | 92 +++++--- .../Data/SqlClient/TdsParserStateObject.cs | 4 +- .../TdsParserStateObjectFactory.Windows.cs | 8 +- .../SqlClient/TdsParserStateObjectManaged.cs | 149 +++--------- .../SqlClient/TdsParserStateObjectNative.cs | 2 +- 14 files changed, 234 insertions(+), 377 deletions(-) diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIHandle.cs index b38cb2411667..b9370fa7a981 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIHandle.cs @@ -84,6 +84,8 @@ internal abstract class SNIHandle /// public abstract Guid ConnectionId { get; } + public virtual bool SMUXEnabled => false; + #if DEBUG /// /// Test handle for killing underlying connection diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsConnection.cs index 989e9a2b94a9..79a17795b8da 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -105,6 +105,11 @@ public uint SendAsync(SNIPacket packet, SNIAsyncCallback callback) /// SNI error code public uint ReceiveAsync(ref SNIPacket packet) { + if (packet != null) + { + packet.Release(); + packet = null; + } lock (this) { return _lowerHandle.ReceiveAsync(ref packet); @@ -133,7 +138,7 @@ public void HandleReceiveError(SNIPacket packet) { handle.HandleReceiveError(packet); } - packet?.Dispose(); + packet?.Release(); } /// @@ -183,8 +188,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) if (bytesTaken == 0) { - packet.Dispose(); - packet = null; sniErrorCode = ReceiveAsync(ref packet); if (sniErrorCode == TdsEnums.SNI_SUCCESS_IO_PENDING) @@ -214,8 +217,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) if (_dataBytesLeft > 0) { - packet.Dispose(); - packet = null; sniErrorCode = ReceiveAsync(ref packet); if (sniErrorCode == TdsEnums.SNI_SUCCESS_IO_PENDING) @@ -271,8 +272,6 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) { if (packet.DataLeft == 0) { - packet.Dispose(); - packet = null; sniErrorCode = ReceiveAsync(ref packet); if (sniErrorCode == TdsEnums.SNI_SUCCESS_IO_PENDING) diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs index 6ba903947386..0d767d443dab 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs @@ -11,7 +11,7 @@ namespace System.Data.SqlClient.SNI /// /// MARS handle /// - internal class SNIMarsHandle : SNIHandle + internal sealed class SNIMarsHandle : SNIHandle { private const uint ACK_THRESHOLD = 2; @@ -33,27 +33,11 @@ internal class SNIMarsHandle : SNIHandle private uint _sequenceNumber; private SNIError _connectionError; - /// - /// Connection ID - /// - public override Guid ConnectionId - { - get - { - return _connectionId; - } - } + public override Guid ConnectionId => _connectionId; - /// - /// Handle status - /// - public override uint Status - { - get - { - return _status; - } - } + public override uint Status => _status; + + public override bool SMUXEnabled => true; /// /// Dispose object @@ -93,20 +77,18 @@ public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, object call /// SMUX header flags private void SendControlPacket(SNISMUXFlags flags) { - Span headerBytes = stackalloc byte[SNISMUXHeader.HEADER_LENGTH]; + SNIPacket packet = new SNIPacket(0,reserveMuxHeader:true); lock (this) { - GetSMUXHeaderBytes(0, flags, headerBytes); + SetupSMUXHeader(0, flags); + packet.SetHeader(_currentHeader); } - - SNIPacket packet = new SNIPacket(SNISMUXHeader.HEADER_LENGTH); - packet.AppendData(headerBytes); - _connection.Send(packet); } - private void GetSMUXHeaderBytes(int length, SNISMUXFlags flags, Span bytes) + private void SetupSMUXHeader(int length, SNISMUXFlags flags) { + Debug.Assert(Monitor.IsEntered(this), "must take lock on self before updating mux header"); _currentHeader.SMID = 83; _currentHeader.flags = (byte)flags; _currentHeader.sessionId = _sessionId; @@ -114,27 +96,20 @@ private void GetSMUXHeaderBytes(int length, SNISMUXFlags flags, Span bytes _currentHeader.sequenceNumber = ((flags == SNISMUXFlags.SMUX_FIN) || (flags == SNISMUXFlags.SMUX_ACK)) ? _sequenceNumber - 1 : _sequenceNumber++; _currentHeader.highwater = _receiveHighwater; _receiveHighwaterLastAck = _currentHeader.highwater; - - _currentHeader.Write(bytes); } /// /// Generate a packet with SMUX header /// /// SNI packet - /// Encapsulated SNI packet - private SNIPacket GetSMUXEncapsulatedPacket(SNIPacket packet) + /// The packet with the SMUx header set. + private SNIPacket SetPacketSMUXHeader(SNIPacket packet) { - uint xSequenceNumber = _sequenceNumber; - Span header = stackalloc byte[SNISMUXHeader.HEADER_LENGTH]; - GetSMUXHeaderBytes(packet.Length, SNISMUXFlags.SMUX_DATA, header); + Debug.Assert(packet.MuxHeaderReserved, "attempting to mux packet without mux reservation"); - - SNIPacket smuxPacket = new SNIPacket(SNISMUXHeader.HEADER_LENGTH + packet.Length); - smuxPacket.AppendData(header); - smuxPacket.AppendPacket(packet); - packet.Dispose(); - return smuxPacket; + SetupSMUXHeader(packet.Length, SNISMUXFlags.SMUX_DATA); + packet.SetHeader(_currentHeader); + return packet; } /// @@ -144,6 +119,8 @@ private SNIPacket GetSMUXEncapsulatedPacket(SNIPacket packet) /// SNI error code public override uint Send(SNIPacket packet) { + Debug.Assert(packet.MuxHeaderReserved, "attempting to send muxed packet without mux reservation in Send"); + while (true) { lock (this) @@ -161,9 +138,13 @@ public override uint Send(SNIPacket packet) _ackEvent.Reset(); } } - SNIPacket encapsulatedPacket = GetSMUXEncapsulatedPacket(packet); - return _connection.Send(encapsulatedPacket); + SNIPacket muxedPacket = null; + lock (this) + { + muxedPacket = SetPacketSMUXHeader(packet); + } + return _connection.Send(muxedPacket); } /// @@ -174,6 +155,7 @@ public override uint Send(SNIPacket packet) /// SNI error code private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback) { + Debug.Assert(packet.MuxHeaderReserved, "attempting to send muxed packet without mux reservation in InternalSendAsync"); lock (this) { if (_sequenceNumber >= _sendHighwater) @@ -181,18 +163,9 @@ private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback) return TdsEnums.SNI_QUEUE_FULL; } - SNIPacket encapsulatedPacket = GetSMUXEncapsulatedPacket(packet); - - if (callback != null) - { - encapsulatedPacket.SetCompletionCallback(callback); - } - else - { - encapsulatedPacket.SetCompletionCallback(HandleSendComplete); - } - - return _connection.SendAsync(encapsulatedPacket, callback); + SNIPacket muxedPacket = SetPacketSMUXHeader(packet); + muxedPacket.SetCompletionCallback(callback??HandleSendComplete); + return _connection.SendAsync(muxedPacket, callback); } } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNINpHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNINpHandle.cs index 67581e8d32c4..526289f2c01f 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNINpHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNINpHandle.cs @@ -15,7 +15,7 @@ namespace System.Data.SqlClient.SNI /// /// Named Pipe connection handle /// - internal class SNINpHandle : SNIHandle + internal sealed class SNINpHandle : SNIHandle { internal const string DefaultPipePath = @"sql\query"; // e.g. \\HOSTNAME\pipe\sql\query private const int MAX_PIPE_INSTANCES = 255; diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs index 6e5cab47e390..7e149c77b267 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs @@ -45,7 +45,7 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask< cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); } - ValueTask vt = stream.ReadAsync(new Memory(_data, 0, _capacity), CancellationToken.None); + ValueTask vt = stream.ReadAsync(new Memory(_data, _header, _capacity), CancellationToken.None); if (vt.IsCompletedSuccessfully) { @@ -88,11 +88,11 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider if (disposeAfter) { - packet.Dispose(); + packet.Release(); } } - ValueTask vt = stream.WriteAsync(new Memory(_data, 0, _length), CancellationToken.None); + ValueTask vt = stream.WriteAsync(new Memory(_data, _header, _length), CancellationToken.None); if (vt.IsCompletedSuccessfully) { @@ -103,7 +103,7 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider if (disposeAfterWriteAsync) { - Dispose(); + Release(); } // Completed diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs index bfa48ac17b61..6906ff8d8638 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs @@ -45,7 +45,7 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); } - Task t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None); + Task t = stream.ReadAsync(_data, _header, _capacity, CancellationToken.None); if ((t.Status & TaskStatus.RanToCompletion) != 0) { @@ -88,11 +88,11 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider if (disposeAfter) { - packet.Dispose(); + packet.Release(); } } - Task t = stream.WriteAsync(_data, 0, _length, CancellationToken.None); + Task t = stream.WriteAsync(_data, _header, _length, CancellationToken.None); if ((t.Status & TaskStatus.RanToCompletion) != 0) { @@ -103,7 +103,7 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider if (disposeAfterWriteAsync) { - Dispose(); + Release(); } // Completed diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs index cfc53aa0cbbb..8bacb0bf1d34 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Buffers; +using System.Diagnostics; using System.IO; using System.Threading; using System.Threading.Tasks; @@ -12,37 +13,28 @@ namespace System.Data.SqlClient.SNI /// /// SNI Packet /// - internal partial class SNIPacket : IDisposable, IEquatable + internal sealed partial class SNIPacket { - private byte[] _data; - private int _length; - private int _capacity; - private int _offset; - private string _description; - private SNIAsyncCallback _completionCallback; - private bool _isBufferFromArrayPool; - - public SNIPacket() { } - - public SNIPacket(int capacity) + [Flags] + private enum SNIPacketFlags : uint { - Allocate(capacity); + None = 0, + ArrayFromPool = 1, + MuxHeaderReserved = 2, + MuxHeaderWritten = 4, } - /// - /// Packet description (used for debugging) - /// - public string Description - { - get - { - return _description; - } + private int _length; // the length of the data in the data segment, advanced by Append-ing data, does not include smux header length + private int _capacity; // the total capacity requested, if the array is rented this may be less than the _data.Length, does not include smux header length + private int _offset; // the start point of the data in the data segment, advanced by Take-ing data + private int _header; // the amount of space at the start of the array reserved for the smux header, this is zeroed in SetHeader + private SNIPacketFlags _flags; + private byte[] _data; + private SNIAsyncCallback _completionCallback; - set - { - _description = value; - } + public SNIPacket(int capacity, bool reserveMuxHeader=false) + { + Allocate(capacity, reserveMuxHeader); } /// @@ -58,15 +50,11 @@ public string Description /// /// Packet validity /// - public bool IsInvalid => (_data == null); + public bool IsInvalid => _data is null; - /// - /// Packet data - /// - public void Dispose() - { - Release(); - } + public bool MuxHeaderReserved => ((_flags & SNIPacketFlags.MuxHeaderReserved) == SNIPacketFlags.MuxHeaderReserved); + + public bool MuxHeaderWritten => ((_flags & SNIPacketFlags.MuxHeaderWritten) == SNIPacketFlags.MuxHeaderWritten); /// /// Set async completion callback @@ -90,68 +78,54 @@ public void InvokeCompletionCallback(uint sniErrorCode) /// Allocate space for data /// /// Length of byte array to be allocated - public void Allocate(int capacity) + public void Allocate(int capacity, bool reserveMuxHeader) { - if (_data != null && _data.Length < capacity) + SNIPacketFlags flags = reserveMuxHeader ? SNIPacketFlags.MuxHeaderReserved : SNIPacketFlags.None; + int headerCapacity = reserveMuxHeader ? SNISMUXHeader.HEADER_LENGTH : 0; + int totalCapacity = headerCapacity + capacity; + if (_data != null) { - if (_isBufferFromArrayPool) + if (_data.Length < totalCapacity) { - ArrayPool.Shared.Return(_data); + Array.Clear(_data, 0, _header + _length); + if ((_flags & SNIPacketFlags.ArrayFromPool) == SNIPacketFlags.ArrayFromPool) + { + ArrayPool.Shared.Return(_data, clearArray: false); + _flags &= ~SNIPacketFlags.ArrayFromPool; + } + _data = null; + } + else + { + // if the current array is big enough and rented keep it + flags |= (_flags & SNIPacketFlags.ArrayFromPool); } - _data = null; } if (_data == null) { - _data = ArrayPool.Shared.Rent(capacity); - _isBufferFromArrayPool = true; + _data = ArrayPool.Shared.Rent(totalCapacity); + flags |= SNIPacketFlags.ArrayFromPool; // set local not instance because it will be assigned after this block } + _flags = flags; _capacity = capacity; _length = 0; _offset = 0; + _header = headerCapacity; } /// - /// Clone packet - /// - /// Cloned packet - public SNIPacket Clone() - { - SNIPacket packet = new SNIPacket(_capacity); - Buffer.BlockCopy(_data, 0, packet._data, 0, _capacity); - packet._length = _length; - packet._description = _description; - packet._completionCallback = _completionCallback; - - return packet; - } - - /// - /// Get packet data + /// Read packet data into a buffer without removing it from the packet /// /// Buffer - /// Data in packet + /// Number of bytes read from the packet into the buffer public void GetData(byte[] buffer, ref int dataSize) { - Buffer.BlockCopy(_data, 0, buffer, 0, _length); + Buffer.BlockCopy(_data, _header, buffer, 0, _length); // read from dataSize = _length; } - /// - /// Set packet data - /// - /// Data - /// Length - public void SetData(byte[] data, int length) - { - _data = data; - _length = length; - _capacity = data.Length; - _offset = 0; - _isBufferFromArrayPool = false; - } - /// /// Take data from another packet /// @@ -160,7 +134,7 @@ public void SetData(byte[] data, int length) /// Amount of data taken public int TakeData(SNIPacket packet, int size) { - int dataSize = TakeData(packet._data, packet._length, size); + int dataSize = TakeData(packet._data, packet._header + packet._length, size); packet._length += dataSize; return dataSize; } @@ -172,7 +146,7 @@ public int TakeData(SNIPacket packet, int size) /// Size public void AppendData(byte[] data, int size) { - Buffer.BlockCopy(data, 0, _data, _length, size); + Buffer.BlockCopy(data, 0, _data, _header + _length, size); _length += size; } @@ -183,21 +157,11 @@ public void AppendData(ReadOnlySpan data) } /// - /// Append another packet - /// - /// Packet - public void AppendPacket(SNIPacket packet) - { - Buffer.BlockCopy(packet._data, 0, _data, _length, packet._length); - _length += packet._length; - } - - /// - /// Take data from packet and advance offset + /// Read data from the packet into the buffer at dataOffset for zize and then remove that data from the packet /// /// Buffer - /// Data offset - /// Size + /// Data offset to write data at + /// Number of bytes to read from the packet into the buffer /// public int TakeData(byte[] buffer, int dataOffset, int size) { @@ -211,11 +175,31 @@ public int TakeData(byte[] buffer, int dataOffset, int size) size = _length - _offset; } - Buffer.BlockCopy(_data, _offset, buffer, dataOffset, size); + Buffer.BlockCopy(_data, (_header + _offset), buffer, dataOffset, size); _offset += size; return size; } + + /// + /// Set the MARS SMUX header information for this packet + /// + /// the header to write into the packet + public void SetHeader(SNISMUXHeader header) + { + Debug.Assert(header != null, "writing null mux header to packet"); + + Debug.Assert(_offset == 0, "writing mux header to partially read packet"); + Debug.Assert(_header == SNISMUXHeader.HEADER_LENGTH, "writing mux header to partially incorrectly sized reserved region"); + Debug.Assert(((_flags & SNIPacketFlags.MuxHeaderReserved) == SNIPacketFlags.MuxHeaderReserved), "writing mux heaser to non-mux packet"); + + header.Write(_data.AsSpan(0, SNISMUXHeader.HEADER_LENGTH)); + _capacity += _header; + _length += _header; + _header = 0; + _flags |= SNIPacketFlags.MuxHeaderWritten; + } + /// /// Release packet /// @@ -223,24 +207,19 @@ public void Release() { if (_data != null) { - if(_isBufferFromArrayPool) + Array.Clear(_data, 0, _header + _length); + if ((_flags & SNIPacketFlags.ArrayFromPool) == SNIPacketFlags.ArrayFromPool) { - ArrayPool.Shared.Return(_data); + ArrayPool.Shared.Return(_data, clearArray: false); + _flags &= ~SNIPacketFlags.ArrayFromPool; } _data = null; _capacity = 0; } - Reset(); - } - - /// - /// Reset packet - /// - public void Reset() - { _length = 0; _offset = 0; - _description = null; + _header = 0; + _flags = SNIPacketFlags.None; _completionCallback = null; } @@ -250,7 +229,7 @@ public void Reset() /// Stream to read from public void ReadFromStream(Stream stream) { - _length = stream.Read(_data, 0, _capacity); + _length = stream.Read(_data, _header, _capacity); } /// @@ -259,48 +238,10 @@ public void ReadFromStream(Stream stream) /// Stream to write to public void WriteToStream(Stream stream) { - stream.Write(_data, 0, _length); - } - - /// - /// Get hash code - /// - /// Hash code - public override int GetHashCode() - { - return base.GetHashCode(); + stream.Write(_data, _header, _length); } - /// - /// Check packet equality - /// - /// - /// true if equal - public override bool Equals(object obj) - { - SNIPacket packet = obj as SNIPacket; - - if (packet != null) - { - return Equals(packet); - } - - return false; - } + } - /// - /// Check packet equality - /// - /// - /// true if equal - public bool Equals(SNIPacket packet) - { - if (packet != null) - { - return ReferenceEquals(packet, this); - } - return false; - } - } } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIProxy.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIProxy.cs index b61f623a7442..38bb17eb13a0 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIProxy.cs @@ -227,18 +227,19 @@ public uint GetConnectionId(SNIHandle handle, ref Guid clientConnectionId) /// SNI error status public uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync) { - SNIPacket clonedPacket = packet.Clone(); + Debug.Assert(handle is SNIMarsHandle || !packet.MuxHeaderReserved, $"handle type and mux reservation do no match, handle={handle.GetType().Name}, packet.MuxHeaderReserved={packet.MuxHeaderReserved}"); + uint result; if (sync) { - result = handle.Send(clonedPacket); - clonedPacket.Dispose(); + result = handle.Send(packet); + packet.Release(); } else { - result = handle.SendAsync(clonedPacket, true); + result = handle.SendAsync(packet, true); } - + return result; } @@ -406,8 +407,6 @@ private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, objec return new SNITCPHandle(hostName, port, timerExpire, callbackObject, parallel); } - - /// /// Creates an SNINpHandle object /// @@ -446,7 +445,7 @@ public uint ReadAsync(SNIHandle handle, out SNIPacket packet) /// Length public void PacketSetData(SNIPacket packet, byte[] data, int length) { - packet.SetData(data, length); + packet.AppendData(data, length); } /// diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs index 3618b3fc8730..96c47fffecce 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs @@ -19,7 +19,7 @@ namespace System.Data.SqlClient.SNI /// /// TCP connection handle /// - internal class SNITCPHandle : SNIHandle + internal sealed class SNITCPHandle : SNIHandle { private readonly string _targetServer; private readonly object _callbackObject; diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SslOverTdsStream.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SslOverTdsStream.cs index cb274689ff4f..9a9bd8292bfb 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SslOverTdsStream.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SslOverTdsStream.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Buffers; using System.IO; using System.IO.Pipes; using System.Threading; @@ -88,44 +89,71 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel /// private async Task ReadInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async) { - int readBytes = 0; - byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count]; - if (_encapsulate) { - if (_packetBytes == 0) - { - // Account for split packets - while (readBytes < TdsEnums.HEADER_LEN) - { - readBytes += async ? - await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) : - _stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes); - } + return await ReadInternalEncapsulate(buffer, offset, count, token, async); + } + else if (async) + { + return await ReadInternalAsync(buffer, offset, count, token); + } + else + { + return ReadInternalSync(buffer, offset, count); + } + } - _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]; - _packetBytes -= TdsEnums.HEADER_LEN; - } + private async Task ReadInternalEncapsulate(byte[] buffer, int offset, int count, CancellationToken token, bool async) + { + int readBytes = 0; + byte[] packetData = ArrayPool.Shared.Rent(count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count); - if (count > _packetBytes) + if (_packetBytes == 0) + { + // Account for split packets + while (readBytes < TdsEnums.HEADER_LEN) { - count = _packetBytes; + readBytes += (async ? + await ReadInternalAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token) : + ReadInternalSync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes) + ); } - } - readBytes = async ? - await _stream.ReadAsync(packetData, 0, count, token).ConfigureAwait(false) : - _stream.Read(packetData, 0, count); + _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]; + _packetBytes -= TdsEnums.HEADER_LEN; + } - if (_encapsulate) + if (count > _packetBytes) { - _packetBytes -= readBytes; + count = _packetBytes; } + + readBytes = (async ? + await ReadInternalAsync(packetData, 0, count, token) : + ReadInternalSync(packetData, 0, count) + ); + + _packetBytes -= readBytes; + Buffer.BlockCopy(packetData, 0, buffer, offset, readBytes); + + Array.Clear(packetData, 0, readBytes); + ArrayPool.Shared.Return(packetData, clearArray: false); + return readBytes; } + private async Task ReadInternalAsync(byte[] buffer, int offset, int count, CancellationToken token) + { + return await _stream.ReadAsync(buffer, 0, count, token).ConfigureAwait(false); + } + + private int ReadInternalSync(byte[] buffer, int offset, int count) + { + return _stream.Read(buffer, 0, count); + } + /// /// The internal write method calls Sync APIs when Async flag is false /// @@ -153,11 +181,13 @@ private async Task WriteInternal(byte[] buffer, int offset, int count, Cancellat count -= currentCount; // Prepend buffer data with TDS prelogin header - byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount]; + int combinedLength = TdsEnums.HEADER_LEN + currentCount; + byte[] combinedBuffer = ArrayPool.Shared.Rent(combinedLength); // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a // partial packet (whether or not count != 0). // + combinedBuffer[7] = 0; // touch this first for the jit bounds check combinedBuffer[0] = PRELOGIN_PACKET_TYPE; combinedBuffer[1] = (byte)(count > 0 ? 0 : 1); combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100); @@ -165,21 +195,21 @@ private async Task WriteInternal(byte[] buffer, int offset, int count, Cancellat combinedBuffer[4] = 0; combinedBuffer[5] = 0; combinedBuffer[6] = 0; - combinedBuffer[7] = 0; - for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++) - { - combinedBuffer[i] = buffer[currentOffset + (i - TdsEnums.HEADER_LEN)]; - } + Array.Copy(buffer, currentOffset, combinedBuffer, TdsEnums.HEADER_LEN, (combinedLength - TdsEnums.HEADER_LEN)); if (async) { - await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false); + await _stream.WriteAsync(combinedBuffer, 0, combinedLength, token).ConfigureAwait(false); } else { - _stream.Write(combinedBuffer, 0, combinedBuffer.Length); + _stream.Write(combinedBuffer, 0, combinedLength); } + + Array.Clear(combinedBuffer, 0, combinedLength); + ArrayPool.Shared.Return(combinedBuffer); + } else { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs index 6e527a828ec5..87ba7c9b75b1 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs @@ -775,7 +775,7 @@ private void ResetCancelAndProcessAttention() protected abstract uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize); - internal abstract PacketHandle GetResetWritePacket(); + internal abstract PacketHandle GetResetWritePacket(int dataSize); internal abstract void ClearAllWritePackets(); @@ -3420,7 +3420,7 @@ internal void SendAttention(bool mustTakeWriteLock = false) private Task WriteSni(bool canAccumulate) { // Prepare packet, and write to packet. - PacketHandle packet = GetResetWritePacket(); + PacketHandle packet = GetResetWritePacket(_outBytesUsed); SetBufferSecureStrings(); SetPacketData(packet, _outBuff, _outBytesUsed); diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs index fb57c55f556e..fc1d1587c90e 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs @@ -19,14 +19,14 @@ internal sealed class TdsParserStateObjectFactory //private static bool shouldUseLegacyNetorking; //public static bool UseManagedSNI { get; } = AppContext.TryGetSwitch(UseLegacyNetworkingOnWindows, out shouldUseLegacyNetorking) ? !shouldUseLegacyNetorking : true; -#if DEBUG +//#if DEBUG private static Lazy useManagedSNIOnWindows = new Lazy( () => bool.TrueString.Equals(Environment.GetEnvironmentVariable("System.Data.SqlClient.UseManagedSNIOnWindows"), StringComparison.InvariantCultureIgnoreCase) ); public static bool UseManagedSNI => useManagedSNIOnWindows.Value; -#else - public static bool UseManagedSNI { get; } = false; -#endif +//#else +// public static bool UseManagedSNI { get; } = false; +//#endif public EncryptionOptions EncryptionOptions { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs index 151d4e554aa0..68468025d29c 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -11,17 +11,13 @@ namespace System.Data.SqlClient.SNI { internal class TdsParserStateObjectManaged : TdsParserStateObject { - private SNIMarsConnection _marsConnection = null; - private SNIHandle _sessionHandle = null; // the SNI handle we're to work on - private SNIPacket _sniPacket = null; // Will have to re-vamp this for MARS - internal SNIPacket _sniAsyncAttnPacket = null; // Packet to use to send Attn - private readonly Dictionary _pendingWritePackets = new Dictionary(); // Stores write packets that have been sent to SNI, but have not yet finished writing (i.e. we are waiting for SNI's callback) - - private readonly WritePacketCache _writePacketCache = new WritePacketCache(); // Store write packets that are ready to be re-used + private SNIMarsConnection _marsConnection; + private SNIHandle _sessionHandle; + private SspiClientContextStatus _sspiClientContextStatus; public TdsParserStateObjectManaged(TdsParser parser) : base(parser) { } - internal SspiClientContextStatus sspiClientContextStatus = new SspiClientContextStatus(); + internal TdsParserStateObjectManaged(TdsParser parser, TdsParserStateObject physicalConnection, bool async) : base(parser, physicalConnection, async) @@ -81,27 +77,17 @@ protected override void RemovePacketFromPendingList(PacketHandle packet) internal override void Dispose() { - SNIPacket packetHandle = _sniPacket; SNIHandle sessionHandle = _sessionHandle; - SNIPacket asyncAttnPacket = _sniAsyncAttnPacket; - _sniPacket = null; _sessionHandle = null; - _sniAsyncAttnPacket = null; _marsConnection = null; DisposeCounters(); - if (null != sessionHandle || null != packetHandle) + if (sessionHandle != null) { - packetHandle?.Dispose(); - asyncAttnPacket?.Dispose(); - - if (sessionHandle != null) - { - sessionHandle.Dispose(); - DecrementPendingCallbacks(true); // Will dispose of GC handle. - } + sessionHandle.Dispose(); + DecrementPendingCallbacks(true); // Will dispose of GC handle. } DisposePacketCache(); @@ -109,11 +95,7 @@ internal override void Dispose() internal override void DisposePacketCache() { - lock (_writePacketLockObject) - { - _writePacketCache.Dispose(); - // Do not set _writePacketCache to null, just in case a WriteAsyncCallback completes after this point - } + } protected override void FreeGcHandle(int remaining, bool release) @@ -144,7 +126,7 @@ internal override bool IsPacketEmpty(PacketHandle packet) internal override void ReleasePacket(PacketHandle syncReadPacket) { - syncReadPacket.ManagedPacket?.Dispose(); + syncReadPacket.ManagedPacket?.Release(); } internal override uint CheckConnection() @@ -162,13 +144,17 @@ internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) internal override PacketHandle CreateAndSetAttentionPacket() { - if (_sniAsyncAttnPacket == null) - { - SNIPacket attnPacket = new SNIPacket(); - SetPacketData(PacketHandle.FromManagedPacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN); - _sniAsyncAttnPacket = attnPacket; - } - return PacketHandle.FromManagedPacket(_sniAsyncAttnPacket); + //if (_sniAsyncAttnPacket == null) + //{ + // SNIPacket attnPacket = new SNIPacket(); + // SetPacketData(PacketHandle.FromManagedPacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN); + // _sniAsyncAttnPacket = attnPacket; + //} + //return PacketHandle.FromManagedPacket(_sniAsyncAttnPacket); + + PacketHandle packetHandle = GetResetWritePacket(TdsEnums.HEADER_LEN); + SetPacketData(packetHandle, SQL.AttentionHeader, TdsEnums.HEADER_LEN); + return packetHandle; } internal override uint WritePacket(PacketHandle packet, bool sync) @@ -192,34 +178,16 @@ internal override bool IsValidPacket(PacketHandle packet) ); } - internal override PacketHandle GetResetWritePacket() + internal override PacketHandle GetResetWritePacket(int dataSize) { - if (_sniPacket != null) - { - _sniPacket.Reset(); - } - else - { - lock (_writePacketLockObject) - { - _sniPacket = _writePacketCache.Take(Handle); - } - } - return PacketHandle.FromManagedPacket(_sniPacket); + var packet = new SNIPacket(dataSize, _sessionHandle.SMUXEnabled); + Debug.Assert(packet.MuxHeaderReserved == _sessionHandle.SMUXEnabled, "failed to reserve mux header"); + return PacketHandle.FromManagedPacket(packet); } internal override void ClearAllWritePackets() { - if (_sniPacket != null) - { - _sniPacket.Dispose(); - _sniPacket = null; - } - lock (_writePacketLockObject) - { - Debug.Assert(_pendingWritePackets.Count == 0 && _asyncWriteCount == 0, "Should not clear all write packets if there are packets pending"); - _writePacketCache.Clear(); - } + Debug.Assert(_asyncWriteCount == 0, "Should not clear all write packets if there are packets pending"); } internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) => SNIProxy.Singleton.PacketSetData(packet.ManagedPacket, buffer, bytesUsed); @@ -239,76 +207,21 @@ internal override uint EnableMars(ref uint info) return TdsEnums.SNI_ERROR; } - internal override uint EnableSsl(ref uint info)=> SNIProxy.Singleton.EnableSsl(Handle, info); + internal override uint EnableSsl(ref uint info) => SNIProxy.Singleton.EnableSsl(Handle, info); internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.Singleton.SetConnectionBufferSize(Handle, unsignedPacketSize); internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer) { - SNIProxy.Singleton.GenSspiClientContext(sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer); + if (_sspiClientContextStatus==null) + { + _sspiClientContextStatus = new SspiClientContextStatus(); + } + SNIProxy.Singleton.GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer); sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0); return 0; } internal override uint WaitForSSLHandShakeToComplete() => 0; - - internal sealed class WritePacketCache : IDisposable - { - private bool _disposed; - private Stack _packets; - - public WritePacketCache() - { - _disposed = false; - _packets = new Stack(); - } - - public SNIPacket Take(SNIHandle sniHandle) - { - SNIPacket packet; - if (_packets.Count > 0) - { - // Success - reset the packet - packet = _packets.Pop(); - packet.Reset(); - } - else - { - // Failed to take a packet - create a new one - packet = new SNIPacket(); - } - return packet; - } - - public void Add(SNIPacket packet) - { - if (!_disposed) - { - _packets.Push(packet); - } - else - { - // If we're disposed, then get rid of any packets added to us - packet.Dispose(); - } - } - - public void Clear() - { - while (_packets.Count > 0) - { - _packets.Pop().Dispose(); - } - } - - public void Dispose() - { - if (!_disposed) - { - _disposed = true; - Clear(); - } - } - } } } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs index 5c43bdb07902..f50961698b4c 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs @@ -254,7 +254,7 @@ internal override bool IsValidPacket(PacketHandle packetPointer) ); } - internal override PacketHandle GetResetWritePacket() + internal override PacketHandle GetResetWritePacket(int dataSize) { if (_sniPacket != null) { From 5aea9de2ff0733bf148e4ae08f369acb1092bbee Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Tue, 26 Mar 2019 00:53:05 +0000 Subject: [PATCH 2/4] address feedback --- .../src/System/Data/SqlClient/SNI/SNIPacket.cs | 8 ++++---- .../System/Data/SqlClient/TdsParserStateObjectManaged.cs | 8 -------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs index 8bacb0bf1d34..e9da6e54b993 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs @@ -32,8 +32,8 @@ private enum SNIPacketFlags : uint private byte[] _data; private SNIAsyncCallback _completionCallback; - public SNIPacket(int capacity, bool reserveMuxHeader=false) - { + public SNIPacket(int capacity, bool reserveMuxHeader = false) + { Allocate(capacity, reserveMuxHeader); } @@ -98,7 +98,7 @@ public void Allocate(int capacity, bool reserveMuxHeader) else { // if the current array is big enough and rented keep it - flags |= (_flags & SNIPacketFlags.ArrayFromPool); + flags |= (_flags & SNIPacketFlags.ArrayFromPool); } } @@ -175,7 +175,7 @@ public int TakeData(byte[] buffer, int dataOffset, int size) size = _length - _offset; } - Buffer.BlockCopy(_data, (_header + _offset), buffer, dataOffset, size); + Buffer.BlockCopy(_data, _header + _offset, buffer, dataOffset, size); _offset += size; return size; } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs index 68468025d29c..858bd7abbff1 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -144,14 +144,6 @@ internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) internal override PacketHandle CreateAndSetAttentionPacket() { - //if (_sniAsyncAttnPacket == null) - //{ - // SNIPacket attnPacket = new SNIPacket(); - // SetPacketData(PacketHandle.FromManagedPacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN); - // _sniAsyncAttnPacket = attnPacket; - //} - //return PacketHandle.FromManagedPacket(_sniAsyncAttnPacket); - PacketHandle packetHandle = GetResetWritePacket(TdsEnums.HEADER_LEN); SetPacketData(packetHandle, SQL.AttentionHeader, TdsEnums.HEADER_LEN); return packetHandle; From d0ee9127723431e5f19e9bb1302aeafeeab2b2ac Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Wed, 27 Mar 2019 20:18:05 +0000 Subject: [PATCH 3/4] remove debug ifdef --- .../Data/SqlClient/TdsParserStateObjectFactory.Windows.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs index fc1d1587c90e..fb57c55f556e 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs @@ -19,14 +19,14 @@ internal sealed class TdsParserStateObjectFactory //private static bool shouldUseLegacyNetorking; //public static bool UseManagedSNI { get; } = AppContext.TryGetSwitch(UseLegacyNetworkingOnWindows, out shouldUseLegacyNetorking) ? !shouldUseLegacyNetorking : true; -//#if DEBUG +#if DEBUG private static Lazy useManagedSNIOnWindows = new Lazy( () => bool.TrueString.Equals(Environment.GetEnvironmentVariable("System.Data.SqlClient.UseManagedSNIOnWindows"), StringComparison.InvariantCultureIgnoreCase) ); public static bool UseManagedSNI => useManagedSNIOnWindows.Value; -//#else -// public static bool UseManagedSNI { get; } = false; -//#endif +#else + public static bool UseManagedSNI { get; } = false; +#endif public EncryptionOptions EncryptionOptions { From d64696bda4068e87d93af8eeb2de7459b8d30fa3 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Sat, 30 Mar 2019 12:10:40 +0000 Subject: [PATCH 4/4] address feedback --- .../System/Data/SqlClient/SNI/SNIHandle.cs | 2 +- .../Data/SqlClient/SNI/SNIMarsConnection.cs | 2 +- .../Data/SqlClient/SNI/SNIMarsHandle.cs | 18 ++- .../System/Data/SqlClient/SNI/SNINpHandle.cs | 12 +- .../SqlClient/SNI/SNIPacket.NetCoreApp.cs | 12 +- .../SqlClient/SNI/SNIPacket.NetStandard.cs | 12 +- .../System/Data/SqlClient/SNI/SNIPacket.cs | 153 ++++++------------ .../src/System/Data/SqlClient/SNI/SNIProxy.cs | 2 - .../System/Data/SqlClient/SNI/SNITcpHandle.cs | 10 +- .../SqlClient/TdsParserStateObjectManaged.cs | 8 +- 10 files changed, 86 insertions(+), 145 deletions(-) diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIHandle.cs index b9370fa7a981..00af92b6df74 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIHandle.cs @@ -84,7 +84,7 @@ internal abstract class SNIHandle /// public abstract Guid ConnectionId { get; } - public virtual bool SMUXEnabled => false; + public virtual int ReserveHeaderSize => 0; #if DEBUG /// diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsConnection.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsConnection.cs index 79a17795b8da..fe513b5c2c00 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsConnection.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsConnection.cs @@ -202,7 +202,7 @@ public void HandleReceiveComplete(SNIPacket packet, uint sniErrorCode) _currentHeader.Read(_headerBytes); _dataBytesLeft = (int)_currentHeader.length; - _currentPacket = new SNIPacket((int)_currentHeader.length); + _currentPacket = new SNIPacket(headerSize: 0, dataSize: (int)_currentHeader.length); } currentHeader = _currentHeader; diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs index 0d767d443dab..4d96adf4ff25 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs @@ -37,7 +37,7 @@ internal sealed class SNIMarsHandle : SNIHandle public override uint Status => _status; - public override bool SMUXEnabled => true; + public override int ReserveHeaderSize => SNISMUXHeader.HEADER_LENGTH; /// /// Dispose object @@ -77,11 +77,12 @@ public SNIMarsHandle(SNIMarsConnection connection, ushort sessionId, object call /// SMUX header flags private void SendControlPacket(SNISMUXFlags flags) { - SNIPacket packet = new SNIPacket(0,reserveMuxHeader:true); + SNIPacket packet = new SNIPacket(headerSize: SNISMUXHeader.HEADER_LENGTH, dataSize: 0); lock (this) { SetupSMUXHeader(0, flags); - packet.SetHeader(_currentHeader); + _currentHeader.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); + packet.SetHeaderActive(); } _connection.Send(packet); } @@ -105,10 +106,11 @@ private void SetupSMUXHeader(int length, SNISMUXFlags flags) /// The packet with the SMUx header set. private SNIPacket SetPacketSMUXHeader(SNIPacket packet) { - Debug.Assert(packet.MuxHeaderReserved, "attempting to mux packet without mux reservation"); + Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to mux packet without mux reservation"); SetupSMUXHeader(packet.Length, SNISMUXFlags.SMUX_DATA); - packet.SetHeader(_currentHeader); + _currentHeader.Write(packet.GetHeaderBuffer(SNISMUXHeader.HEADER_LENGTH)); + packet.SetHeaderActive(); return packet; } @@ -119,7 +121,7 @@ private SNIPacket SetPacketSMUXHeader(SNIPacket packet) /// SNI error code public override uint Send(SNIPacket packet) { - Debug.Assert(packet.MuxHeaderReserved, "attempting to send muxed packet without mux reservation in Send"); + Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to send muxed packet without mux reservation in Send"); while (true) { @@ -155,7 +157,7 @@ public override uint Send(SNIPacket packet) /// SNI error code private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback) { - Debug.Assert(packet.MuxHeaderReserved, "attempting to send muxed packet without mux reservation in InternalSendAsync"); + Debug.Assert(packet.ReservedHeaderSize == SNISMUXHeader.HEADER_LENGTH, "mars handle attempting to send muxed packet without mux reservation in InternalSendAsync"); lock (this) { if (_sequenceNumber >= _sendHighwater) @@ -164,7 +166,7 @@ private uint InternalSendAsync(SNIPacket packet, SNIAsyncCallback callback) } SNIPacket muxedPacket = SetPacketSMUXHeader(packet); - muxedPacket.SetCompletionCallback(callback??HandleSendComplete); + muxedPacket.SetCompletionCallback(callback ?? HandleSendComplete); return _connection.SendAsync(muxedPacket, callback); } } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNINpHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNINpHandle.cs index 526289f2c01f..381d3ec8d92d 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNINpHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNINpHandle.cs @@ -22,7 +22,7 @@ internal sealed class SNINpHandle : SNIHandle private readonly string _targetServer; private readonly object _callbackObject; - + private Stream _stream; private NamedPipeClientStream _pipeStream; private SslOverTdsStream _sslOverTdsStream; @@ -61,13 +61,13 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object _pipeStream.Connect((int)ts.TotalMilliseconds); } } - catch(TimeoutException te) + catch (TimeoutException te) { SNICommon.ReportSNIError(SNIProviders.NP_PROV, SNICommon.ConnOpenFailedError, te); _status = TdsEnums.SNI_ERROR; return; } - catch(IOException ioe) + catch (IOException ioe) { SNICommon.ReportSNIError(SNIProviders.NP_PROV, SNICommon.ConnOpenFailedError, ioe); _status = TdsEnums.SNI_ERROR; @@ -150,7 +150,7 @@ public override uint Receive(out SNIPacket packet, int timeout) packet = null; try { - packet = new SNIPacket(_bufferSize); + packet = new SNIPacket(headerSize: 0, dataSize: _bufferSize); packet.ReadFromStream(_stream); if (packet.Length == 0) @@ -174,8 +174,8 @@ public override uint Receive(out SNIPacket packet, int timeout) public override uint ReceiveAsync(ref SNIPacket packet) { - packet = new SNIPacket(_bufferSize); - + packet = new SNIPacket(headerSize: 0, dataSize: _bufferSize); + try { packet.ReadFromStreamAsync(_stream, _receiveCallback); diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs index 7e149c77b267..5fd9e445996d 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs @@ -24,8 +24,8 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask< bool error = false; try { - packet._length = await valueTask.ConfigureAwait(false); - if (packet._length == 0) + packet._dataLength = await valueTask.ConfigureAwait(false); + if (packet._dataLength == 0) { SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty); error = true; @@ -45,13 +45,13 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask< cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); } - ValueTask vt = stream.ReadAsync(new Memory(_data, _header, _capacity), CancellationToken.None); + ValueTask vt = stream.ReadAsync(new Memory(_data, _headerLength, _dataCapacity), CancellationToken.None); if (vt.IsCompletedSuccessfully) { - _length = vt.Result; + _dataLength = vt.Result; // Zero length to go via async local function as is error condition - if (_length > 0) + if (_dataLength > 0) { callback(this, TdsEnums.SNI_SUCCESS); @@ -92,7 +92,7 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider } } - ValueTask vt = stream.WriteAsync(new Memory(_data, _header, _length), CancellationToken.None); + ValueTask vt = stream.WriteAsync(new Memory(_data, _headerLength, _dataLength), CancellationToken.None); if (vt.IsCompletedSuccessfully) { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs index 6906ff8d8638..ce03faf5d7a9 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.NetStandard.cs @@ -24,8 +24,8 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task bool error = false; try { - packet._length = await task.ConfigureAwait(false); - if (packet._length == 0) + packet._dataLength = await task.ConfigureAwait(false); + if (packet._dataLength == 0) { SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, 0, SNICommon.ConnTerminatedError, string.Empty); error = true; @@ -45,13 +45,13 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); } - Task t = stream.ReadAsync(_data, _header, _capacity, CancellationToken.None); + Task t = stream.ReadAsync(_data, _headerLength, _dataCapacity, CancellationToken.None); if ((t.Status & TaskStatus.RanToCompletion) != 0) { - _length = t.Result; + _dataLength = t.Result; // Zero length to go via async local function as is error condition - if (_length > 0) + if (_dataLength > 0) { callback(this, TdsEnums.SNI_SUCCESS); @@ -92,7 +92,7 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider } } - Task t = stream.WriteAsync(_data, _header, _length, CancellationToken.None); + Task t = stream.WriteAsync(_data, _headerLength, _dataLength, CancellationToken.None); if ((t.Status & TaskStatus.RanToCompletion) != 0) { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs index e9da6e54b993..2979f64bbbc9 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs @@ -10,51 +10,37 @@ namespace System.Data.SqlClient.SNI { - /// - /// SNI Packet - /// internal sealed partial class SNIPacket { - [Flags] - private enum SNIPacketFlags : uint - { - None = 0, - ArrayFromPool = 1, - MuxHeaderReserved = 2, - MuxHeaderWritten = 4, - } - - private int _length; // the length of the data in the data segment, advanced by Append-ing data, does not include smux header length - private int _capacity; // the total capacity requested, if the array is rented this may be less than the _data.Length, does not include smux header length - private int _offset; // the start point of the data in the data segment, advanced by Take-ing data - private int _header; // the amount of space at the start of the array reserved for the smux header, this is zeroed in SetHeader - private SNIPacketFlags _flags; + private int _dataLength; // the length of the data in the data segment, advanced by Append-ing data, does not include smux header length + private int _dataCapacity; // the total capacity requested, if the array is rented this may be less than the _data.Length, does not include smux header length + private int _dataOffset; // the start point of the data in the data segment, advanced by Take-ing data + private int _headerLength; // the amount of space at the start of the array reserved for the smux header, this is zeroed in SetHeader + // _headerOffset is not needed because it is always 0 private byte[] _data; private SNIAsyncCallback _completionCallback; - public SNIPacket(int capacity, bool reserveMuxHeader = false) + public SNIPacket(int headerSize, int dataSize) { - Allocate(capacity, reserveMuxHeader); + Allocate(headerSize, dataSize); } /// /// Length of data left to process /// - public int DataLeft => (_length - _offset); + public int DataLeft => (_dataLength - _dataOffset); /// /// Length of data /// - public int Length => _length; + public int Length => _dataLength; /// /// Packet validity /// public bool IsInvalid => _data is null; - public bool MuxHeaderReserved => ((_flags & SNIPacketFlags.MuxHeaderReserved) == SNIPacketFlags.MuxHeaderReserved); - - public bool MuxHeaderWritten => ((_flags & SNIPacketFlags.MuxHeaderWritten) == SNIPacketFlags.MuxHeaderWritten); + public int ReservedHeaderSize => _headerLength; /// /// Set async completion callback @@ -77,42 +63,14 @@ public void InvokeCompletionCallback(uint sniErrorCode) /// /// Allocate space for data /// - /// Length of byte array to be allocated - public void Allocate(int capacity, bool reserveMuxHeader) + /// Length of byte array to be allocated + private void Allocate(int headerLength, int dataLength) { - SNIPacketFlags flags = reserveMuxHeader ? SNIPacketFlags.MuxHeaderReserved : SNIPacketFlags.None; - int headerCapacity = reserveMuxHeader ? SNISMUXHeader.HEADER_LENGTH : 0; - int totalCapacity = headerCapacity + capacity; - if (_data != null) - { - if (_data.Length < totalCapacity) - { - Array.Clear(_data, 0, _header + _length); - if ((_flags & SNIPacketFlags.ArrayFromPool) == SNIPacketFlags.ArrayFromPool) - { - ArrayPool.Shared.Return(_data, clearArray: false); - _flags &= ~SNIPacketFlags.ArrayFromPool; - } - _data = null; - } - else - { - // if the current array is big enough and rented keep it - flags |= (_flags & SNIPacketFlags.ArrayFromPool); - } - } - - if (_data == null) - { - _data = ArrayPool.Shared.Rent(totalCapacity); - flags |= SNIPacketFlags.ArrayFromPool; // set local not instance because it will be assigned after this block - } - - _flags = flags; - _capacity = capacity; - _length = 0; - _offset = 0; - _header = headerCapacity; + _data = ArrayPool.Shared.Rent(headerLength + dataLength); + _dataCapacity = dataLength; + _dataLength = 0; + _dataOffset = 0; + _headerLength = headerLength; } /// @@ -122,8 +80,8 @@ public void Allocate(int capacity, bool reserveMuxHeader) /// Number of bytes read from the packet into the buffer public void GetData(byte[] buffer, ref int dataSize) { - Buffer.BlockCopy(_data, _header, buffer, 0, _length); // read from - dataSize = _length; + Buffer.BlockCopy(_data, _headerLength, buffer, 0, _dataLength); // read from + dataSize = _dataLength; } /// @@ -134,8 +92,8 @@ public void GetData(byte[] buffer, ref int dataSize) /// Amount of data taken public int TakeData(SNIPacket packet, int size) { - int dataSize = TakeData(packet._data, packet._header + packet._length, size); - packet._length += dataSize; + int dataSize = TakeData(packet._data, packet._headerLength + packet._dataLength, size); + packet._dataLength += dataSize; return dataSize; } @@ -146,14 +104,8 @@ public int TakeData(SNIPacket packet, int size) /// Size public void AppendData(byte[] data, int size) { - Buffer.BlockCopy(data, 0, _data, _header + _length, size); - _length += size; - } - - public void AppendData(ReadOnlySpan data) - { - data.CopyTo(_data.AsSpan(_length)); - _length += data.Length; + Buffer.BlockCopy(data, 0, _data, _headerLength + _dataLength, size); + _dataLength += size; } /// @@ -165,39 +117,35 @@ public void AppendData(ReadOnlySpan data) /// public int TakeData(byte[] buffer, int dataOffset, int size) { - if (_offset >= _length) + if (_dataOffset >= _dataLength) { return 0; } - if (_offset + size > _length) + if (_dataOffset + size > _dataLength) { - size = _length - _offset; + size = _dataLength - _dataOffset; } - Buffer.BlockCopy(_data, _header + _offset, buffer, dataOffset, size); - _offset += size; + Buffer.BlockCopy(_data, _headerLength + _dataOffset, buffer, dataOffset, size); + _dataOffset += size; return size; } - - /// - /// Set the MARS SMUX header information for this packet - /// - /// the header to write into the packet - public void SetHeader(SNISMUXHeader header) + public Span GetHeaderBuffer(int headerSize) { - Debug.Assert(header != null, "writing null mux header to packet"); - - Debug.Assert(_offset == 0, "writing mux header to partially read packet"); - Debug.Assert(_header == SNISMUXHeader.HEADER_LENGTH, "writing mux header to partially incorrectly sized reserved region"); - Debug.Assert(((_flags & SNIPacketFlags.MuxHeaderReserved) == SNIPacketFlags.MuxHeaderReserved), "writing mux heaser to non-mux packet"); + Debug.Assert(_dataOffset == 0, "requested packet header buffer from partially consumed packet"); + Debug.Assert(headerSize > 0, "requested packet header buffer of 0 length"); + Debug.Assert(_headerLength == headerSize, "requested packet header of headerSize which is not equal to the _headerSize reservation"); + return _data.AsSpan(0, headerSize); + } - header.Write(_data.AsSpan(0, SNISMUXHeader.HEADER_LENGTH)); - _capacity += _header; - _length += _header; - _header = 0; - _flags |= SNIPacketFlags.MuxHeaderWritten; + public void SetHeaderActive() + { + Debug.Assert(_headerLength > 0, "requested to set header active when it is not reserved or is already active"); + _dataCapacity += _headerLength; + _dataLength += _headerLength; + _headerLength = 0; } /// @@ -207,19 +155,14 @@ public void Release() { if (_data != null) { - Array.Clear(_data, 0, _header + _length); - if ((_flags & SNIPacketFlags.ArrayFromPool) == SNIPacketFlags.ArrayFromPool) - { - ArrayPool.Shared.Return(_data, clearArray: false); - _flags &= ~SNIPacketFlags.ArrayFromPool; - } + Array.Clear(_data, 0, _headerLength + _dataLength); + ArrayPool.Shared.Return(_data, clearArray: false); _data = null; - _capacity = 0; + _dataCapacity = 0; } - _length = 0; - _offset = 0; - _header = 0; - _flags = SNIPacketFlags.None; + _dataLength = 0; + _dataOffset = 0; + _headerLength = 0; _completionCallback = null; } @@ -229,7 +172,7 @@ public void Release() /// Stream to read from public void ReadFromStream(Stream stream) { - _length = stream.Read(_data, _header, _capacity); + _dataLength = stream.Read(_data, _headerLength, _dataCapacity); } /// @@ -238,7 +181,7 @@ public void ReadFromStream(Stream stream) /// Stream to write to public void WriteToStream(Stream stream) { - stream.Write(_data, _header, _length); + stream.Write(_data, _headerLength, _dataLength); } } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIProxy.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIProxy.cs index e3f47dbaf898..8bb3b37b0e26 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIProxy.cs @@ -220,8 +220,6 @@ public uint GetConnectionId(SNIHandle handle, ref Guid clientConnectionId) /// SNI error status public uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync) { - Debug.Assert(handle is SNIMarsHandle || !packet.MuxHeaderReserved, $"handle type and mux reservation do no match, handle={handle.GetType().Name}, packet.MuxHeaderReserved={packet.MuxHeaderReserved}"); - uint result; if (sync) { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs index 96c47fffecce..805c98b8649d 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs @@ -25,7 +25,7 @@ internal sealed class SNITCPHandle : SNIHandle private readonly object _callbackObject; private readonly Socket _socket; private NetworkStream _tcpStream; - + private Stream _stream; private SslStream _sslStream; private SslOverTdsStream _sslOverTdsStream; @@ -144,7 +144,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba { _socket = Connect(serverName, port, ts); } - + if (_socket == null || !_socket.Connected) { if (_socket != null) @@ -224,7 +224,7 @@ void Cancel() { sockets[i] = new Socket(ipAddresses[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp); // enable keep-alive on socket - SNITcpHandle.SetKeepAliveValues(ref sockets[i]); + SNITcpHandle.SetKeepAliveValues(ref sockets[i]); sockets[i].Connect(ipAddresses[i], port); if (sockets[i] != null) // sockets[i] can be null if cancel callback is executed during connect() { @@ -470,7 +470,7 @@ public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds) return TdsEnums.SNI_WAIT_TIMEOUT; } - packet = new SNIPacket(_bufferSize); + packet = new SNIPacket(headerSize: 0, dataSize: _bufferSize); packet.ReadFromStream(_stream); if (packet.Length == 0) @@ -540,7 +540,7 @@ public override uint SendAsync(SNIPacket packet, bool disposePacketAfterSendAsyn /// SNI error code public override uint ReceiveAsync(ref SNIPacket packet) { - packet = new SNIPacket(_bufferSize); + packet = new SNIPacket(headerSize: 0, dataSize: _bufferSize); try { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs index 858bd7abbff1..8ebe1246f6f3 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -17,8 +17,6 @@ internal class TdsParserStateObjectManaged : TdsParserStateObject public TdsParserStateObjectManaged(TdsParser parser) : base(parser) { } - - internal TdsParserStateObjectManaged(TdsParser parser, TdsParserStateObject physicalConnection, bool async) : base(parser, physicalConnection, async) { } @@ -172,8 +170,8 @@ internal override bool IsValidPacket(PacketHandle packet) internal override PacketHandle GetResetWritePacket(int dataSize) { - var packet = new SNIPacket(dataSize, _sessionHandle.SMUXEnabled); - Debug.Assert(packet.MuxHeaderReserved == _sessionHandle.SMUXEnabled, "failed to reserve mux header"); + var packet = new SNIPacket(headerSize: _sessionHandle.ReserveHeaderSize, dataSize: dataSize); + Debug.Assert(packet.ReservedHeaderSize == _sessionHandle.ReserveHeaderSize, "failed to reserve header"); return PacketHandle.FromManagedPacket(packet); } @@ -205,7 +203,7 @@ internal override uint EnableMars(ref uint info) internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer) { - if (_sspiClientContextStatus==null) + if (_sspiClientContextStatus == null) { _sspiClientContextStatus = new SspiClientContextStatus(); }