Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve efficiency and correctness of SSL handshake #1949

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50
return context;
}

internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> input, out byte[] sendBuf, out int sendCount)
internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> input, ref byte[] sendBuf, out int sendCount)
{
sendBuf = null;
sendCount = 0;
Expand Down Expand Up @@ -286,7 +286,10 @@ internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> in
sendCount = Crypto.BioCtrlPending(context.OutputBio);
if (sendCount > 0)
{
sendBuf = new byte[sendCount];
if (sendBuf == null || sendBuf.Length < sendCount)
{
sendBuf = new byte[sendCount];
}

try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ internal interface ISSPIInterface
int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.CredentialUse usage, ref SafeSspiAuthDataHandle authdata, out SafeFreeCredentials outCredential);
int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.CredentialUse usage, ref Interop.SspiCli.SCHANNEL_CRED authdata, out SafeFreeCredentials outCredential);
int AcquireDefaultCredential(string moduleName, Interop.SspiCli.CredentialUse usage, out SafeFreeCredentials outCredential);
int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, ReadOnlySpan<SecurityBuffer> inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags);
int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ReadOnlySpan<SecurityBuffer> inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags);
int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags);
int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags);
int EncryptMessage(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint sequenceNumber);
int DecryptMessage(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint sequenceNumber);
int MakeSignature(SafeDeleteContext context, ref Interop.SspiCli.SecBufferDesc inputOutput, uint sequenceNumber);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ public int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.Credentia
return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential);
}

public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, ReadOnlySpan<SecurityBuffer> inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
{
return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags);
}

public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ReadOnlySpan<SecurityBuffer> inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
{
return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ public int AcquireCredentialsHandle(string moduleName, Interop.SspiCli.Credentia
return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential);
}

public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, ReadOnlySpan<SecurityBuffer> inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteSslContext context, InputSecurityBuffers inputBuffers, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
{
return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags);
}

public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, ReadOnlySpan<SecurityBuffer> inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness endianness, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
{
return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffers, ref outputBuffer, ref outFlags);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,24 +140,24 @@ public static SafeFreeCredentials AcquireCredentialsHandle(ISSPIInterface secMod
return outCredential;
}

internal static int InitializeSecurityContext(ISSPIInterface secModule, ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, ReadOnlySpan<SecurityBuffer> inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
internal static int InitializeSecurityContext(ISSPIInterface secModule, ref SafeFreeCredentials credential, ref SafeDeleteSslContext context, string targetName, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
{
if (NetEventSource.IsEnabled) NetEventSource.Log.InitializeSecurityContext(credential, context, targetName, inFlags);

int errorCode = secModule.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, datarep, inputBuffers, ref outputBuffer, ref outFlags);

if (NetEventSource.IsEnabled) NetEventSource.Log.SecurityContextInputBuffers(nameof(InitializeSecurityContext), inputBuffers.Length, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode);
if (NetEventSource.IsEnabled) NetEventSource.Log.SecurityContextInputBuffers(nameof(InitializeSecurityContext), inputBuffers.Count, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode);

return errorCode;
}

internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCredentials credential, ref SafeDeleteSslContext context, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, ReadOnlySpan<SecurityBuffer> inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCredentials credential, ref SafeDeleteSslContext context, Interop.SspiCli.ContextFlags inFlags, Interop.SspiCli.Endianness datarep, InputSecurityBuffers inputBuffers, ref SecurityBuffer outputBuffer, ref Interop.SspiCli.ContextFlags outFlags)
{
if (NetEventSource.IsEnabled) NetEventSource.Log.AcceptSecurityContext(credential, context, inFlags);

int errorCode = secModule.AcceptSecurityContext(credential, ref context, inputBuffers, inFlags, datarep, ref outputBuffer, ref outFlags);

if (NetEventSource.IsEnabled) NetEventSource.Log.SecurityContextInputBuffers(nameof(AcceptSecurityContext), inputBuffers.Length, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode);
if (NetEventSource.IsEnabled) NetEventSource.Log.SecurityContextInputBuffers(nameof(AcceptSecurityContext), inputBuffers.Count, outputBuffer.size, (Interop.SECURITY_STATUS)errorCode);

return errorCode;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ internal static unsafe int InitializeSecurityContext(
string targetName,
Interop.SspiCli.ContextFlags inFlags,
Interop.SspiCli.Endianness endianness,
ReadOnlySpan<SecurityBuffer> inSecBuffers,
InputSecurityBuffers inSecBuffers,
ref SecurityBuffer outSecBuffer,
ref Interop.SspiCli.ContextFlags outFlags)
{
Expand All @@ -413,7 +413,7 @@ internal static unsafe int InitializeSecurityContext(
throw new ArgumentNullException(nameof(inCredentials));
}

Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Length);
Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Count);
Interop.SspiCli.SecBufferDesc outSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(1);

// Actually, this is returned in outFlags.
Expand All @@ -429,36 +429,31 @@ internal static unsafe int InitializeSecurityContext(

// Optional output buffer that may need to be freed.
SafeFreeContextBuffer outFreeContextBuffer = null;
Span<IntPtr> ptr = stackalloc IntPtr[3];
wfurt marked this conversation as resolved.
Show resolved Hide resolved
wfurt marked this conversation as resolved.
Show resolved Hide resolved
try
{
Span<Interop.SspiCli.SecBuffer> inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[inSecurityBufferDescriptor.cBuffers];
Span<Interop.SspiCli.SecBuffer> inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[inSecBuffers.Count];
inUnmanagedBuffer.Clear();

fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer)
fixed (void* pinnedToken0 = inSecBuffers.Length > 0 ? inSecBuffers[0].token : null)
fixed (void* pinnedToken1 = inSecBuffers.Length > 1 ? inSecBuffers[1].token : null)
fixed (void* pinnedToken2 = inSecBuffers.Length > 2 ? inSecBuffers[2].token : null) // pin all buffers, even if null or not used, to avoid needing to allocate GCHandles
fixed (void* pinnedToken0 = &inSecBuffers.GetBuffer(0).Token.GetPinnableReference())
fixed (void* pinnedToken1 = &inSecBuffers.GetBuffer(1).Token.GetPinnableReference())
fixed (void* pinnedToken2 = &inSecBuffers.GetBuffer(2).Token.GetPinnableReference())
{
Debug.Assert(inSecBuffers.Length <= 3);
ptr[0] = (IntPtr)pinnedToken0;
ptr[1] = (IntPtr)pinnedToken1;
ptr[2] = (IntPtr)pinnedToken2;
wfurt marked this conversation as resolved.
Show resolved Hide resolved

// Fix Descriptor pointer that points to unmanaged SecurityBuffers.
inSecurityBufferDescriptor.pBuffers = inUnmanagedBufferPtr;
for (int index = 0; index < inSecurityBufferDescriptor.cBuffers; ++index)
{
ref readonly SecurityBuffer securityBuffer = ref inSecBuffers[index];

// Copy the SecurityBuffer content into unmanaged place holder.
inUnmanagedBuffer[index].cbBuffer = securityBuffer.size;
inUnmanagedBuffer[index].BufferType = securityBuffer.type;

// Use the unmanaged token if it's not null; otherwise use the managed buffer.
inUnmanagedBuffer[index].pvBuffer =
securityBuffer.unmanagedToken != null ? securityBuffer.unmanagedToken.DangerousGetHandle() :
securityBuffer.token == null || securityBuffer.token.Length == 0 ? IntPtr.Zero :
Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset);
#if TRACE_VERBOSE
if (NetEventSource.IsEnabled) NetEventSource.Info(null, $"SecBuffer: cbBuffer:{securityBuffer.size} BufferType:{securityBuffer.type}");
#endif
for (int index = 0; index < inSecBuffers.Count; ++index)
{
inUnmanagedBuffer[index].cbBuffer = inSecBuffers.GetBuffer(index).Token.Length;
inUnmanagedBuffer[index].BufferType = inSecBuffers.GetBuffer(index).Type;
inUnmanagedBuffer[index].pvBuffer = inSecBuffers.GetBuffer(index).UnmanagedToken != null ?
inSecBuffers.GetBuffer(index).UnmanagedToken.DangerousGetHandle() :
ptr[index];
}

fixed (byte* pinnedOutBytes = outSecBuffer.token)
Expand Down Expand Up @@ -626,7 +621,7 @@ internal static unsafe int AcceptSecurityContext(
ref SafeDeleteSslContext refContext,
Interop.SspiCli.ContextFlags inFlags,
Interop.SspiCli.Endianness endianness,
ReadOnlySpan<SecurityBuffer> inSecBuffers,
InputSecurityBuffers inSecBuffers,
ref SecurityBuffer outSecBuffer,
ref Interop.SspiCli.ContextFlags outFlags)
{
Expand All @@ -643,7 +638,8 @@ internal static unsafe int AcceptSecurityContext(
throw new ArgumentNullException(nameof(inCredentials));
}

Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Length);
Debug.Assert(inSecBuffers.Count <= 3);
Interop.SspiCli.SecBufferDesc inSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(inSecBuffers.Count);
Interop.SspiCli.SecBufferDesc outSecurityBufferDescriptor = new Interop.SspiCli.SecBufferDesc(count: 2);

// Actually, this is returned in outFlags.
Expand All @@ -661,37 +657,33 @@ internal static unsafe int AcceptSecurityContext(
SafeFreeContextBuffer outFreeContextBuffer = null;
Span<Interop.SspiCli.SecBuffer> outUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[2];
outUnmanagedBuffer[1].pvBuffer = IntPtr.Zero;
Span<IntPtr> ptr = stackalloc IntPtr[3];
try
{
Span<Interop.SspiCli.SecBuffer> inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[inSecurityBufferDescriptor.cBuffers];
Span<Interop.SspiCli.SecBuffer> inUnmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[inSecBuffers.Count];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the intention is for the stackalloc to be 1..3 entries long but there's no assertion or check that i can see and since count is unconditionally increased in SetNextBuffer it could in theory any value if someone found a way to abuse it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you share an example of where/when it wouldn't be a small value?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing obvious, as i said the intention is clearly that it should be 1..3 which causes no issue. I just thought it might be worth an assert or check.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Sure, adding an assert would be fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is Assert in SetNextBuffer so it won't go more than 3. I can also make it consistent and always allocate 3 since the rest of the logic is fixed anyway. I did not find a good way how to make the equivalent of fixed() in a more flexible way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if hot path but having a constant-sized stack via (stackalloc Interop.SspiCli.SecBuffer[3])[0..inSecBuffers.Count] may use fewer instructions, if you do have this upper bound.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, if the max is actually 3, just use 3. The JIT is better at optimizing const-sized stackallocs. There's also no reason to slice it beyond that unless you actually use the length later. If you do use the length later, just stick with the variable-sized stackalloc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated code to use const-sized allocation since the difference between 0 and 3 is really small. (and will be 2 in most cases)
I was wondering if '3' should be some named constant instead of magic number but I not sure what the name should be and where it should live.


inUnmanagedBuffer.Clear();

fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer)
fixed (void* outUnmanagedBufferPtr = outUnmanagedBuffer)
fixed (void* pinnedToken0 = inSecBuffers.Length > 0 ? inSecBuffers[0].token : null)
fixed (void* pinnedToken1 = inSecBuffers.Length > 1 ? inSecBuffers[1].token : null)
fixed (void* pinnedToken2 = inSecBuffers.Length > 2 ? inSecBuffers[2].token : null) // pin all buffers, even if null or not used, to avoid needing to allocate GCHandles
fixed (void* pinnedToken0 = &inSecBuffers.GetBuffer(0).Token.GetPinnableReference())
fixed (void* pinnedToken1 = &inSecBuffers.GetBuffer(1).Token.GetPinnableReference())
fixed (void* pinnedToken2 = &inSecBuffers.GetBuffer(2).Token.GetPinnableReference())
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
Debug.Assert(inSecBuffers.Length <= 3);

// Fix Descriptor pointer that points to unmanaged SecurityBuffers.
ptr[0] = (IntPtr)pinnedToken0;
ptr[1] = (IntPtr)pinnedToken1;
ptr[2] = (IntPtr)pinnedToken2;
wfurt marked this conversation as resolved.
Show resolved Hide resolved

inSecurityBufferDescriptor.pBuffers = inUnmanagedBufferPtr;
for (int index = 0; index < inSecurityBufferDescriptor.cBuffers; ++index)
{
ref readonly SecurityBuffer securityBuffer = ref inSecBuffers[index];

// Copy the SecurityBuffer content into unmanaged place holder.
inUnmanagedBuffer[index].cbBuffer = securityBuffer.size;
inUnmanagedBuffer[index].BufferType = securityBuffer.type;

// Use the unmanaged token if it's not null; otherwise use the managed buffer.
inUnmanagedBuffer[index].pvBuffer =
securityBuffer.unmanagedToken != null ? securityBuffer.unmanagedToken.DangerousGetHandle() :
securityBuffer.token == null || securityBuffer.token.Length == 0 ? IntPtr.Zero :
Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset);
#if TRACE_VERBOSE
if (NetEventSource.IsEnabled) NetEventSource.Info(null, $"SecBuffer: cbBuffer:{securityBuffer.size} BufferType:{securityBuffer.type}");
#endif
for (int index = 0; index < inSecBuffers.Count; ++index)
{
inUnmanagedBuffer[index].cbBuffer = inSecBuffers.GetBuffer(index).Token.Length;
inUnmanagedBuffer[index].BufferType = inSecBuffers.GetBuffer(index).Type;
inUnmanagedBuffer[index].pvBuffer = inSecBuffers.GetBuffer(index).UnmanagedToken != null ?
inSecBuffers.GetBuffer(index).UnmanagedToken.DangerousGetHandle() :
ptr[index];
}

fixed (byte* pinnedOutBytes = outSecBuffer.token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace System.Net.Http
namespace System.Net
{
// Warning: Mutable struct!
// The purpose of this struct is to simplify buffer management.
Expand Down Expand Up @@ -54,7 +54,7 @@ public void Dispose()
}

public int ActiveLength => _availableStart - _activeStart;
public Span<byte> ActiveSpan => new Span<byte>(_bytes, _activeStart, _availableStart - _activeStart);
public ReadOnlySpan<byte> ActiveSpan => new Span<byte>(_bytes, _activeStart, _availableStart - _activeStart);
wfurt marked this conversation as resolved.
Show resolved Hide resolved
public int AvailableLength => _bytes.Length - _availableStart;
public Span<byte> AvailableSpan => new Span<byte>(_bytes, _availableStart, AvailableLength);
public Memory<byte> ActiveMemory => new Memory<byte>(_bytes, _activeStart, _availableStart - _activeStart);
Expand Down
Loading