diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs index 1c8e173f919ec..a51927e915d48 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/Interop/SafeMsQuicConfigurationHandle.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Net.Security; +using System.Reflection; using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using System.Text; @@ -15,6 +16,9 @@ namespace System.Net.Quic.Implementations.MsQuic.Internal { internal sealed class SafeMsQuicConfigurationHandle : SafeHandle { + private static readonly FieldInfo _contextCertificate = typeof(SslStreamCertificateContext).GetField("Certificate", BindingFlags.NonPublic | BindingFlags.Instance)!; + private static readonly FieldInfo _contextChain= typeof(SslStreamCertificateContext).GetField("IntermediateCertificates", BindingFlags.NonPublic | BindingFlags.Instance)!; + public override bool IsInvalid => handle == IntPtr.Zero; private SafeMsQuicConfigurationHandle() @@ -31,18 +35,18 @@ protected override bool ReleaseHandle() public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options) { // TODO: lots of ClientAuthenticationOptions are not yet supported by MsQuic. - return Create(options, QUIC_CREDENTIAL_FLAGS.CLIENT, certificate: null, options.ClientAuthenticationOptions?.ApplicationProtocols); + return Create(options, QUIC_CREDENTIAL_FLAGS.CLIENT, certificate: null, certificateContext: null, options.ClientAuthenticationOptions?.ApplicationProtocols); } public static unsafe SafeMsQuicConfigurationHandle Create(QuicListenerOptions options) { // TODO: lots of ServerAuthenticationOptions are not yet supported by MsQuic. - return Create(options, QUIC_CREDENTIAL_FLAGS.NONE, options.ServerAuthenticationOptions?.ServerCertificate, options.ServerAuthenticationOptions?.ApplicationProtocols); + return Create(options, QUIC_CREDENTIAL_FLAGS.NONE, options.ServerAuthenticationOptions?.ServerCertificate, options.ServerAuthenticationOptions?.ServerCertificateContext, options.ServerAuthenticationOptions?.ApplicationProtocols); } // TODO: this is called from MsQuicListener and when it fails it wreaks havoc in MsQuicListener finalizer. // Consider moving bigger logic like this outside of constructor call chains. - private static unsafe SafeMsQuicConfigurationHandle Create(QuicOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, List? alpnProtocols) + private static unsafe SafeMsQuicConfigurationHandle Create(QuicOptions options, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, SslStreamCertificateContext? certificateContext, List? alpnProtocols) { // TODO: some of these checks should be done by the QuicOptions type. if (alpnProtocols == null || alpnProtocols.Count == 0) @@ -62,7 +66,7 @@ private static unsafe SafeMsQuicConfigurationHandle Create(QuicOptions options, if ((flags & QUIC_CREDENTIAL_FLAGS.CLIENT) == 0) { - if (certificate == null) + if (certificate == null && certificateContext == null) { throw new Exception("Server must provide certificate"); } @@ -101,6 +105,7 @@ private static unsafe SafeMsQuicConfigurationHandle Create(QuicOptions options, uint status; SafeMsQuicConfigurationHandle? configurationHandle; + X509Certificate2[]? intermediates = null; MemoryHandle[]? handles = null; QuicBuffer[]? buffers = null; @@ -121,6 +126,17 @@ private static unsafe SafeMsQuicConfigurationHandle Create(QuicOptions options, CredentialConfig config = default; config.Flags = flags; // TODO: consider using LOAD_ASYNCHRONOUS with a callback. + if (certificateContext != null) + { + certificate = (X509Certificate2?) _contextCertificate.GetValue(certificateContext); + intermediates = (X509Certificate2[]?) _contextChain.GetValue(certificateContext); + + if (certificate == null || intermediates == null) + { + throw new ArgumentException(nameof(certificateContext)); + } + } + if (certificate != null) { if (OperatingSystem.IsWindows()) @@ -132,7 +148,24 @@ private static unsafe SafeMsQuicConfigurationHandle Create(QuicOptions options, else { CredentialConfigCertificatePkcs12 pkcs12Config; - byte[] asn1 = certificate.Export(X509ContentType.Pkcs12); + byte[] asn1; + + if (intermediates?.Length > 0) + { + X509Certificate2Collection collection = new X509Certificate2Collection(); + collection.Add(certificate); + for (int i= 0; i < intermediates?.Length; i++) + { + collection.Add(intermediates[i]); + } + + asn1 = collection.Export(X509ContentType.Pkcs12)!; + } + else + { + asn1 = certificate.Export(X509ContentType.Pkcs12); + } + fixed (void* ptr = asn1) { pkcs12Config.Asn1Blob = (IntPtr)ptr; diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index 6d6431fc728af..694bfe17aacd3 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -235,7 +235,7 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti additionalCertificates.Import(asn1); if (additionalCertificates.Count > 0) { - certificate = additionalCertificates[0]; + certificate = additionalCertificates[additionalCertificates.Count - 1]; } } } @@ -263,7 +263,7 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti if (additionalCertificates != null && additionalCertificates.Count > 1) { - for (int i = 1; i < additionalCertificates.Count; i++) + for (int i = 0; i < additionalCertificates.Count - 1; i++) { chain.ChainPolicy.ExtraStore.Add(additionalCertificates[i]); } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs index 446daf8021d8b..0fe06e9632830 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs @@ -5,6 +5,8 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Net.Security; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading.Tasks; using Xunit; @@ -53,6 +55,46 @@ public async Task UnidirectionalAndBidirectionalChangeValues() Assert.Equal(20, serverConnection.GetRemoteAvailableUnidirectionalStreamCount()); } + [Fact] + public async Task ConnectWithCertificateChain() + { + (X509Certificate2 certificate, X509Certificate2Collection chain) = System.Net.Security.Tests.TestHelper.GenerateCertificates("localhost", longChain: true); + X509Certificate2 rootCA = chain[chain.Count - 1]; + + var quicOptions = new QuicListenerOptions(); + quicOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0); + quicOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions(); + quicOptions.ServerAuthenticationOptions.ServerCertificateContext = SslStreamCertificateContext.Create(certificate, chain); + quicOptions.ServerAuthenticationOptions.ServerCertificate = null; + + using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, quicOptions); + + QuicClientConnectionOptions options = new QuicClientConnectionOptions() + { + RemoteEndPoint = listener.ListenEndPoint, + ClientAuthenticationOptions = GetSslClientAuthenticationOptions(), + }; + + options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) => + { + Assert.Equal(certificate.Subject, cert.Subject); + Assert.Equal(certificate.Issuer, cert.Issuer); + // We should get full chain without root CA. + // With trusted root, we should be able to build chain. + chain.ChainPolicy.CustomTrustStore.Add(rootCA); + chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust; + Assert.True(chain.Build(certificate)); + + return true; + }; + + using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + ValueTask clientTask = clientConnection.ConnectAsync(); + + using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); + await clientTask; + } + [Fact] [OuterLoop("May take several seconds")] public async Task SetListenerTimeoutWorksWithSmallTimeout() diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/System.Net.Quic.Functional.Tests.csproj b/src/libraries/System.Net.Quic/tests/FunctionalTests/System.Net.Quic.Functional.Tests.csproj index 504bc9d62f9f9..4803a74543198 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/System.Net.Quic.Functional.Tests.csproj +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/System.Net.Quic.Functional.Tests.csproj @@ -8,13 +8,21 @@ - - - + + + + + + + + + + +