Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate SSPI context generation to single abstraction #2255

Merged
merged 4 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,18 @@
<Compile Include="$(CommonSourceRoot)System\Diagnostics\CodeAnalysis.cs">
<Link>Common\System\Diagnostics\CodeAnalysis.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SSPI\ManagedSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\ManagedSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SSPI\NegotiateSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\NegotiateSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SSPI\SSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\SSPIContextProvider.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\TdsParser.cs">
<Link>Microsoft\Data\SqlClient\TdsParser.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)\Microsoft\Data\ProviderBase\DbReferenceCollection.cs">
<Link>Microsoft\Data\ProviderBase\DbReferenceCollection.cs</Link>
</Compile>
Expand Down Expand Up @@ -771,6 +783,9 @@
<Compile Include="$(CommonPath)\Interop\Windows\kernel32\Interop.LoadLibraryEx.cs">
<Link>Common\Interop\Windows\kernel32\Interop.LoadLibraryEx.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SSPI\NativeSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\NativeSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="Interop\SNINativeMethodWrapper.Windows.cs" />
<Compile Include="Microsoft\Data\ProviderBase\DbConnectionPoolIdentity.Windows.cs" />
<Compile Include="Microsoft\Data\SqlClient\LocalDBAPI.Windows.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ namespace Microsoft.Data.SqlClient
{
internal sealed partial class TdsParser
{
private static volatile bool s_fSSPILoaded = false; // bool to indicate whether library has been loaded

internal void PostReadAsyncForMars()
{
if (TdsParserStateObjectFactory.UseManagedSNI)
Expand Down Expand Up @@ -43,37 +41,7 @@ internal void PostReadAsyncForMars()
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
ThrowExceptionAndWarning(_physicalStateObj);
}
}

private void LoadSSPILibrary()
{
if (TdsParserStateObjectFactory.UseManagedSNI)
return;
// Outer check so we don't acquire lock once it's loaded.
if (!s_fSSPILoaded)
{
lock (s_tdsParserLock)
{
// re-check inside lock
if (!s_fSSPILoaded)
{
// use local for ref param to defer setting s_maxSSPILength until we know the call succeeded.
uint maxLength = 0;

if (0 != SNINativeMethodWrapper.SNISecInitPackage(ref maxLength))
SSPIError(SQLMessage.SSPIInitializeError(), TdsEnums.INIT_SSPI_PACKAGE);

s_maxSSPILength = maxLength;
s_fSSPILoaded = true;
}
}
}

if (s_maxSSPILength > int.MaxValue)
{
throw SQL.InvalidSSPIPacketSize(); // SqlBu 332503
}
}
}

private void WaitForSSLHandShakeToComplete(ref uint error, ref int protocolVersion)
{
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ internal TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCon
AddError(parser.ProcessSNIError(this));
ThrowExceptionAndWarning();
}

// we post a callback that represents the call to dispose; once the
// object is disposed, the next callback will cause the GC Handle to
// be released.
Expand All @@ -75,6 +75,7 @@ internal TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCon
////////////////
internal abstract uint DisableSsl();

internal abstract SSPIContextProvider CreateSSPIContextProvider();

internal abstract uint EnableMars(ref uint info);

Expand All @@ -83,6 +84,8 @@ internal abstract uint Status
get;
}

internal abstract Guid? SessionId { get; }

internal abstract SessionHandle SessionHandle
{
get;
Expand Down Expand Up @@ -236,8 +239,6 @@ internal abstract void CreatePhysicalSNIHandle(

protected abstract void RemovePacketFromPendingList(PacketHandle pointer);

internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer);

internal int DecrementPendingCallbacks(bool release)
{
int remaining = Interlocked.Decrement(ref _pendingCallbacks);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ internal sealed class TdsParserStateObjectManaged : TdsParserStateObject
{
private SNIMarsConnection? _marsConnection;
private SNIHandle? _sessionHandle;
#if NET7_0_OR_GREATER
private NegotiateAuthentication? _negotiateAuth = null;
#else
private SspiClientContextStatus? _sspiClientContextStatus;
#endif

public TdsParserStateObjectManaged(TdsParser parser) : base(parser) { }

internal TdsParserStateObjectManaged(TdsParser parser, TdsParserStateObject physicalConnection, bool async) :
Expand Down Expand Up @@ -232,6 +228,8 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint

protected override PacketHandle EmptyReadPacket => PacketHandle.FromManagedPacket(null);

internal override Guid? SessionId => _sessionHandle?.ConnectionId;

internal override bool IsPacketEmpty(PacketHandle packet) => packet.ManagedPacket == null;

internal override void ReleasePacket(PacketHandle syncReadPacket)
Expand Down Expand Up @@ -389,30 +387,6 @@ internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
return TdsEnums.SNI_SUCCESS;
}

internal override uint GenerateSspiClientContext(byte[] receivedBuff,
uint receivedLength,
ref byte[] sendBuff,
ref uint sendLength,
byte[][] _sniSpnBuffer)
{
#if NET7_0_OR_GREATER
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = Encoding.Unicode.GetString(_sniSpnBuffer[0]) });
sendBuff = _negotiateAuth.GetOutgoingBlob(receivedBuff, out NegotiateAuthenticationStatusCode statusCode)!;
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}, StatusCode={1}", _sessionHandle?.ConnectionId, statusCode);
if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded)
{
throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
}
#else
_sspiClientContextStatus ??= new SspiClientContextStatus();

SNIProxy.GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer);
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _sessionHandle?.ConnectionId);
#endif
sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0);
return 0;
}

internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
protocolVersion = GetSessionSNIHandleHandleOrThrow().ProtocolVersion;
Expand All @@ -432,5 +406,12 @@ private SNIHandle GetSessionSNIHandleHandleOrThrow()
[DoesNotReturn]
[MethodImpl(MethodImplOptions.NoInlining)] // this forces the exception throwing code not to be inlined for performance
private void ThrowClosedConnection() => throw ADP.ClosedConnectionError();

internal override SSPIContextProvider CreateSSPIContextProvider()
#if NET7_0_OR_GREATER
=> new NegotiateSSPIContextProvider();
#else
=> new ManagedSSPIContextProvider();
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ internal override void CreatePhysicalSNIHandle(
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
spnBuffer[0] = srvSPN;
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.",nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
}
else
{
Expand Down Expand Up @@ -272,6 +272,8 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint

protected override PacketHandle EmptyReadPacket => PacketHandle.FromNativePointer(default);

internal override Guid? SessionId => default;

internal override bool IsPacketEmpty(PacketHandle readPacket)
{
Debug.Assert(readPacket.Type == PacketHandle.NativePointerType || readPacket.Type == 0, "unexpected packet type when requiring NativePointer");
Expand Down Expand Up @@ -398,9 +400,6 @@ internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCert
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
=> SNINativeMethodWrapper.SNISetInfo(Handle, SNINativeMethodWrapper.QTypes.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);

internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
=> SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0]);

internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
uint returnValue = SNINativeMethodWrapper.SNIWaitForSSLHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion);
Expand Down Expand Up @@ -452,6 +451,8 @@ internal override void DisposePacketCache()
}
}

internal override SSPIContextProvider CreateSSPIContextProvider() => new NativeSSPIContextProvider();

internal sealed class WritePacketCache : IDisposable
{
private bool _disposed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,21 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\ProviderBase\TimeoutTimer.cs">
<Link>Microsoft\Data\ProviderBase\TimeoutTimer.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SSPI\ManagedSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\ManagedSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SSPI\NativeSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\NativeSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SSPI\NegotiateSSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\NegotiateSSPIContextProvider.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SSPI\SSPIContextProvider.cs">
<Link>Microsoft\Data\SqlClient\SSPI\SSPIContextProvider.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\TdsParser.cs">
<Link>Microsoft\Data\SqlClient\TdsParser.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
</Compile>
Expand Down Expand Up @@ -751,4 +766,4 @@
<Import Project="$(NetFxSource)tools\targets\GenerateThisAssemblyCs.targets" />
<Import Project="$(NetFxSource)tools\targets\GenerateAssemblyRef.targets" />
<Import Project="$(NetFxSource)tools\targets\GenerateAssemblyInfo.targets" />
</Project>
</Project>
Loading
Loading