Skip to content

Commit

Permalink
Merge pull request #372 from nats-io/auto-tls
Browse files Browse the repository at this point in the history
Auto TLS
  • Loading branch information
ColinSullivan1 authored Apr 3, 2020
2 parents 063724b + c2895f8 commit a99ae1d
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 26 deletions.
25 changes: 14 additions & 11 deletions src/NATS.Client/Conn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,8 @@ internal Stream getReadBufferedStream()

internal Stream getWriteBufferedStream(int size)
{
BufferedStream bs = null;

BufferedStream bs = null;
if (sslStream != null)
bs = new BufferedStream(sslStream, size);
else
Expand Down Expand Up @@ -1044,7 +1044,7 @@ public string[] DiscoveredServers

// Process a connected connection and initialize properly.
// Caller must lock.
private void processConnectInit()
private void processConnectInit(Srv s)
{
this.status = ConnState.CONNECTING;

Expand All @@ -1053,7 +1053,7 @@ private void processConnectInit()
try
{
conn.ReceiveTimeout = opts.Timeout;
processExpectedInfo();
processExpectedInfo(s);
sendConnect();
}
catch (IOException ex)
Expand Down Expand Up @@ -1099,7 +1099,7 @@ internal bool connect(Srv s, out Exception exToThrow)
if (!createConn(s))
return false;

processConnectInit();
processConnectInit(s);
exToThrow = null;

return true;
Expand Down Expand Up @@ -1162,7 +1162,7 @@ internal void connect()
// This will check to see if the connection should be
// secure. This can be dictated from either end and should
// only be called after the INIT protocol has been received.
private void checkForSecure()
private void checkForSecure(Srv s)
{
// Check to see if we need to engage TLS
// Check for mismatch in setups
Expand All @@ -1172,11 +1172,14 @@ private void checkForSecure()
}
else if (info.tls_required && !Opts.Secure)
{
throw new NATSSecureConnRequiredException();
// If the server asks us to be secure, give it
// a shot.
Opts.Secure = true;
}

// Need to rewrap with bufio
if (Opts.Secure)
// Need to rewrap with bufio if options tell us we need
// a secure connection or the tls url scheme was specified.
if (Opts.Secure || s.Secure)
{
makeTLSConn();
}
Expand All @@ -1185,7 +1188,7 @@ private void checkForSecure()
// processExpectedInfo will look for the expected first INFO message
// sent when a connection is established. The lock should be held entering.
// Caller must lock.
private void processExpectedInfo()
private void processExpectedInfo(Srv s)
{
Control c;

Expand All @@ -1211,7 +1214,7 @@ private void processExpectedInfo()

// do not notify listeners of server changes when we process the first INFO message
processInfo(c.args, false);
checkForSecure();
checkForSecure(s);
}

private void writeString(string format, string a, string b)
Expand Down Expand Up @@ -1597,7 +1600,7 @@ private void doReconnect()
// process our connect logic
try
{
processConnectInit();
processConnectInit(cur);
}
catch (Exception e)
{
Expand Down
7 changes: 7 additions & 0 deletions src/NATS.Client/Srv.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ internal class Srv
internal int reconnects = 0;
internal DateTime lastAttempt = DateTime.Now;
internal bool isImplicit = false;
private bool secure = false;

// never create a srv object without a url.
private Srv() { }
Expand All @@ -36,6 +37,10 @@ internal Srv(string urlString)
{
urlString = defaultScheme + urlString;
}
else
{
secure = urlString.Contains("tls://");
}

var uri = new Uri(urlString);

Expand All @@ -59,6 +64,8 @@ internal TimeSpan TimeSinceLastAttempt
return (DateTime.Now - lastAttempt);
}
}

internal bool Secure => secure;
}
}

26 changes: 13 additions & 13 deletions src/Tests/IntegrationTests/TestAsyncAwaitDeadlocks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public void EnsureDrain()

AsyncContext.Run(() =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var sync = TestSync.SingleActor())
{
Expand Down Expand Up @@ -68,7 +68,7 @@ public void EnsureDrainAsync()

AsyncContext.Run(async () =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var sync = TestSync.SingleActor())
{
Expand Down Expand Up @@ -102,7 +102,7 @@ public void EnsureRequestResponder()

AsyncContext.Run(() =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var cn = Context.ConnectionFactory.CreateConnection(Context.Server1.Url))
{
Expand All @@ -126,7 +126,7 @@ public void EnsureRequestAsyncResponder()

AsyncContext.Run(async () =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var cn = Context.ConnectionFactory.CreateConnection(Context.Server1.Url))
{
Expand All @@ -150,7 +150,7 @@ public void EnsureRequestResponderWithFlush()

AsyncContext.Run(() =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var cn = Context.ConnectionFactory.CreateConnection(Context.Server1.Url))
{
Expand Down Expand Up @@ -179,7 +179,7 @@ public void EnsureRequestAsyncResponderWithFlush()

AsyncContext.Run(async () =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var cn = Context.ConnectionFactory.CreateConnection(Context.Server1.Url))
{
Expand Down Expand Up @@ -208,7 +208,7 @@ public void EnsurePubSub()

AsyncContext.Run(() =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var sync = TestSync.SingleActor())
{
Expand Down Expand Up @@ -242,7 +242,7 @@ public void EnsurePubSubWithFlush()

AsyncContext.Run(() =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var sync = TestSync.SingleActor())
{
Expand Down Expand Up @@ -280,7 +280,7 @@ public void EnsurePubSubWithAsyncHandler()

AsyncContext.Run(() =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var sync = TestSync.SingleActor())
{
Expand Down Expand Up @@ -314,7 +314,7 @@ public void EnsureAutoUnsubscribeForSyncSub()

AsyncContext.Run(() =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var cn = Context.ConnectionFactory.CreateConnection(Context.Server1.Url))
{
Expand Down Expand Up @@ -345,7 +345,7 @@ public void EnsureAutoUnsubscribeForAsyncSub()

AsyncContext.Run(async () =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var sync = TestSync.SingleActor())
{
Expand Down Expand Up @@ -383,7 +383,7 @@ public void EnsureUnsubscribeForSyncSub()

AsyncContext.Run(() =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var cn = Context.ConnectionFactory.CreateConnection(Context.Server1.Url))
{
Expand Down Expand Up @@ -414,7 +414,7 @@ public void EnsureUnsubscribeForAsyncSub()

AsyncContext.Run(async () =>
{
using (NATSServer.CreateFast(Context.Server1.Port))
using (NATSServer.CreateFastAndVerify(Context.Server1.Port))
{
using (var sync = TestSync.SingleActor())
{
Expand Down
38 changes: 36 additions & 2 deletions src/Tests/IntegrationTests/TestTLS.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,49 @@ public void TestTlsFailWithBadAuth()
}
}

private void TestTLSSecureConnect(bool setSecure)
{
using (NATSServer srv = NATSServer.CreateWithConfig(Context.Server1.Port, "tls.conf"))
{
// we can't call create secure connection w/ the certs setup as they are
// so we'll override the validation callback
Options opts = Context.GetTestOptions(Context.Server1.Port);
opts.Secure = setSecure;
opts.TLSRemoteCertificationValidationCallback = verifyServerCert;

using (IConnection c = Context.ConnectionFactory.CreateConnection(opts))
{
using (ISyncSubscription s = c.SubscribeSync("foo"))
{
c.Publish("foo", null);
c.Flush();
Msg m = s.NextMessage();
}
}
}
}

[Fact]
public void TestTlsSuccessSecureConnect()
{
TestTLSSecureConnect(true);
}

[Fact]
public void TestTlsSuccessSecureConnectFromServerInfo()
{
TestTLSSecureConnect(false);
}

[Fact]
public void TestTlsScheme()
{
using (NATSServer srv = NATSServer.CreateWithConfig(Context.Server1.Port, "tls.conf"))
{
// we can't call create secure connection w/ the certs setup as they are
// so we'll override the
// so we'll override the validation callback
Options opts = Context.GetTestOptions(Context.Server1.Port);
opts.Secure = true;
opts.Url = $"tls://127.0.0.1:{Context.Server1.Port}";
opts.TLSRemoteCertificationValidationCallback = verifyServerCert;

using (IConnection c = Context.ConnectionFactory.CreateConnection(opts))
Expand Down

0 comments on commit a99ae1d

Please sign in to comment.