diff --git a/src/NATS.Client/Conn.cs b/src/NATS.Client/Conn.cs index 4ef90b79c..7bbd8912e 100644 --- a/src/NATS.Client/Conn.cs +++ b/src/NATS.Client/Conn.cs @@ -3890,6 +3890,13 @@ private void drain(int timeout) lock (mu) { + if (isClosed()) + throw new NATSConnectionClosedException(); + + // if we're already draining, exit. + if (isDrainingSubs() || isDrainingPubs()) + return; + lsubs = subs.Values; status = ConnState.DRAINING_SUBS; } @@ -3974,11 +3981,6 @@ public void Drain(int timeout) if (timeout <= 0) throw new ArgumentOutOfRangeException(nameof(timeout), "Timeout must be greater than zero."); - lock (mu) - { - status = ConnState.DRAINING_SUBS; - } - drain(timeout); } @@ -4019,11 +4021,6 @@ public Task DrainAsync(int timeout) if (timeout <= 0) throw new ArgumentOutOfRangeException(nameof(timeout), "Timeout must be greater than zero."); - lock (mu) - { - status = ConnState.DRAINING_SUBS; - } - return Task.Run(() => drain(timeout)); } diff --git a/src/Tests/IntegrationTests/TestConnection.cs b/src/Tests/IntegrationTests/TestConnection.cs index 2596bb764..f154951c7 100644 --- a/src/Tests/IntegrationTests/TestConnection.cs +++ b/src/Tests/IntegrationTests/TestConnection.cs @@ -1093,6 +1093,7 @@ public async Task TestDrainStateBehavior() { closed.Set(); }; + using (var c = Context.ConnectionFactory.CreateConnection(opts)) { using (c.SubscribeAsync("foo", (obj, args) => @@ -1109,6 +1110,8 @@ public async Task TestDrainStateBehavior() // give us a long timeout to run our test. var drainTask = c.DrainAsync(10000); + // Sleep a bit to ensure the drain task is running. + Thread.Sleep(100); Assert.True(c.State == ConnState.DRAINING_SUBS); Assert.True(c.IsDraining()); @@ -1124,6 +1127,12 @@ public async Task TestDrainStateBehavior() Assert.True(closed.WaitOne(10000)); } } + + // Now test connection state checking in drain after being closed via API. + var conn = Context.ConnectionFactory.CreateConnection(opts); + conn.Close(); + _ = Assert.Throws(() => conn.Drain()); + await Assert.ThrowsAsync(() => { return conn.DrainAsync(); }); } }