Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use an IBufferWriter<byte> to write the outgoing SSPI blob #2452

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down Expand Up @@ -471,17 +471,20 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
{
sendLength = (uint)outBuff.Length;

fixed (byte* pin_serverUserName = &serverUserName[0])
fixed (byte* pInBuff = inBuff)
fixed (byte* pOutBuff = outBuff)
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
pInBuff,
(uint)inBuff.Length,
OutBuff,
pOutBuff,
ref sendLength,
out local_fDone,
pin_serverUserName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,9 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a minor adjustment here to keep the files in alphabetically order.

<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlParameter.cs">
<Link>Microsoft\Data\SqlClient\SqlParameter.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ internal class SNIProxy
/// </summary>
/// <param name="sspiClientContextStatus">SSPI client context status</param>
/// <param name="receivedBuff">Receive buffer</param>
/// <param name="sendBuff">Send buffer</param>
/// <param name="sendWriter">Writer for send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, byte[][] serverName)
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
{
// TODO: this should use ReadOnlyMemory all the way through
byte[] array = null;
Expand All @@ -46,10 +46,10 @@ internal static void GenSspiClientContext(SspiClientContextStatus sspiClientCont
receivedBuff.CopyTo(array);
}

GenSspiClientContext(sspiClientContextStatus, array, ref sendBuff, serverName);
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverName);
}

private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -103,10 +103,9 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
outSecurityBuffer.token = null;
}

sendBuff = outSecurityBuffer.token;
if (sendBuff == null)
if (outSecurityBuffer.token is { } token)
{
sendBuff = Array.Empty<byte>();
sendWriter.Write(token);
}

sspiClientContextStatus.SecurityContext = securityContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8123,8 +8123,7 @@ private void WriteLoginData(SqlLogin rec,
int length,
int featureExOffset,
string clientInterfaceName,
byte[] outSSPIBuff,
uint outSSPILength)
ReadOnlySpan<byte> outSSPI)
{
try
{
Expand Down Expand Up @@ -8292,8 +8291,8 @@ private void WriteLoginData(SqlLogin rec,
WriteShort(offset, _physicalStateObj); // ibSSPI offset
if (rec.useSSPI)
{
WriteShort((int)outSSPILength, _physicalStateObj);
offset += (int)outSSPILength;
WriteShort(outSSPI.Length, _physicalStateObj);
offset += outSSPI.Length;
}
else
{
Expand Down Expand Up @@ -8348,7 +8347,7 @@ private void WriteLoginData(SqlLogin rec,

// send over SSPI data if we are using SSPI
if (rec.useSSPI)
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
_physicalStateObj.WriteByteSpan(outSSPI);

WriteString(rec.attachDBFilename, _physicalStateObj);
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlNotificationRequest.cs">
<Link>Microsoft\Data\Sql\SqlNotificationRequest.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same note here.

<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
<Link>Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf,
private static unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] ReadOnlySpan<byte> pIn,
[In, Out] byte[] pOut,
[In, Out] Span<byte> pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand All @@ -899,15 +899,16 @@ private static unsafe uint SNISecGenClientContextWrapper(
[MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword)
{
fixed (byte* pInPtr = pIn)
fixed (byte* pOutPtr = pOut)
{
switch (s_architecture)
{
case System.Runtime.InteropServices.Architecture.Arm64:
return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
case System.Runtime.InteropServices.Architecture.X64:
return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
case System.Runtime.InteropServices.Architecture.X86:
return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
default:
throw ADP.SNIPlatformNotSupported(s_architecture.ToString());
}
Expand Down Expand Up @@ -1380,15 +1381,17 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
{
sendLength = (uint)outBuff.Length;

fixed (byte* pin_serverUserName = &serverUserName[0])
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
inBuff,
OutBuff,
outBuff,
ref sendLength,
out local_fDone,
pin_serverUserName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8922,8 +8922,7 @@ private void WriteLoginData(SqlLogin rec,
int length,
int featureExOffset,
string clientInterfaceName,
byte[] outSSPIBuff,
uint outSSPILength)
ReadOnlySpan<byte> outSSPI)
{
try
{
Expand Down Expand Up @@ -9094,8 +9093,8 @@ private void WriteLoginData(SqlLogin rec,
WriteShort(offset, _physicalStateObj); // ibSSPI offset
if (rec.useSSPI)
{
WriteShort((int)outSSPILength, _physicalStateObj);
offset += (int)outSSPILength;
WriteShort(outSSPI.Length, _physicalStateObj);
offset += outSSPI.Length;
}
else
{
Expand Down Expand Up @@ -9154,7 +9153,7 @@ private void WriteLoginData(SqlLogin rec,

// send over SSPI data if we are using SSPI
if (rec.useSSPI)
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
_physicalStateObj.WriteByteSpan(outSSPI);

WriteString(rec.attachDBFilename, _physicalStateObj);
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -930,9 +930,28 @@ internal void WriteByte(byte b)
// set byte in buffer and increment the counter for number of bytes used in the out buffer
_outBuff[_outBytesUsed++] = b;
}
internal Task WriteByteSpan(ReadOnlySpan<byte> span, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
{
return WriteBytes(span, span.Length, 0, canAccumulate, completion);
}

internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
{
return WriteBytes(ReadOnlySpan<byte>.Empty, len, offsetBuffer, canAccumulate, completion, b);
}

//
// Takes a span or a byte array and writes it to the buffer
// If you pass in a span and a null array then the span wil be used.
// If you pass in a non-null array then the array will be used and the span is ignored.
// if the span cannot be written into the current packet then the remaining contents of the span are copied to a
// new heap allocated array that will used to callback into the method to continue the write operation.
private Task WriteBytes(ReadOnlySpan<byte> b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null, byte[] array = null)
{
if (array != null)
{
b = new ReadOnlySpan<byte>(array, offsetBuffer, len);
}
try
{
TdsParser.ReliabilitySection.Assert("unreliable call to WriteByteArray"); // you need to setup for a thread abort somewhere before you call this method
Expand All @@ -949,7 +968,7 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu

int offset = offsetBuffer;

Debug.Assert(b.Length >= len, "Invalid length sent to WriteByteArray()!");
Debug.Assert(b.Length >= len, "Invalid length sent to WriteBytes()!");

// loop through and write the entire array
do
Expand All @@ -963,12 +982,17 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
int remainder = _outBuff.Length - _outBytesUsed;

// write the remainder
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, remainder);
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, remainder);
ReadOnlySpan<byte> copyFrom = b.Slice(0, remainder);

Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length{copyFrom.Length:D} should be the same");

copyFrom.CopyTo(copyTo);

// handle counters
offset += remainder;
_outBytesUsed += remainder;
len -= remainder;
b = b.Slice(remainder, len);

Task packetTask = WritePacket(TdsEnums.SOFTFLUSH, canAccumulate);

Expand All @@ -981,18 +1005,35 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
completion = new TaskCompletionSource<object>();
task = completion.Task; // we only care about return from topmost call, so do not access Task property in other cases
}
WriteByteArraySetupContinuation(b, len, completion, offset, packetTask);

if (array == null)
{
byte[] tempArray = new byte[len];
Span<byte> copyTempTo = tempArray.AsSpan();

Debug.Assert(copyTempTo.Length == b.Length, $"copyTempTo.Length:{copyTempTo.Length} and copyTempFrom.Length:{b.Length:D} should be the same");

b.CopyTo(copyTempTo);
array = tempArray;
offset = 0;
}

WriteBytesSetupContinuation(array, len, completion, offset, packetTask);
return task;
}

}
else
{
//((stateObj._outBytesUsed + len) <= stateObj._outBuff.Length )
// Else the remainder of the string will fit into the buffer, so copy it into the
// buffer and then break out of the loop.

Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, len);
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, len);
ReadOnlySpan<byte> copyFrom = b.Slice(0, len);

Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length:{copyFrom.Length:D} should be the same");

copyFrom.CopyTo(copyTo);

// handle out buffer bytes used counter
_outBytesUsed += len;
Expand Down Expand Up @@ -1021,7 +1062,7 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
}

// This is in its own method to avoid always allocating the lambda in WriteByteArray
private void WriteByteArraySetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
private void WriteBytesSetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
{
AsyncHelper.ContinueTask(packetTask, completion,
() => WriteByteArray(b, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion),
Expand Down
Loading
Loading