diff --git a/src/main/java/software/amazon/nio/spi/s3/S3ClientProvider.java b/src/main/java/software/amazon/nio/spi/s3/S3ClientProvider.java index 5fd164eb..6fc778b9 100644 --- a/src/main/java/software/amazon/nio/spi/s3/S3ClientProvider.java +++ b/src/main/java/software/amazon/nio/spi/s3/S3ClientProvider.java @@ -139,13 +139,14 @@ S3AsyncClient generateClient(String bucketName, S3AsyncClient locationClient) if (client != null && !client.isClosed()) { return client; } else { + if (client != null && client.isClosed()) { + bucketClientCache.invalidate(bucketName); // remove the closed client from the cache + } String r = Optional.ofNullable(bucketLocation).orElse(configuration.getRegion()); return bucketClientCache.get(bucketName, b -> new CacheableS3Client(configureCrtClientForRegion(r))); } - } - private String getBucketLocation(String bucketName, S3AsyncClient locationClient) throws ExecutionException, InterruptedException { diff --git a/src/test/java/software/amazon/nio/spi/s3/S3ClientProviderTest.java b/src/test/java/software/amazon/nio/spi/s3/S3ClientProviderTest.java index fe2f14ce..c861ba4f 100644 --- a/src/test/java/software/amazon/nio/spi/s3/S3ClientProviderTest.java +++ b/src/test/java/software/amazon/nio/spi/s3/S3ClientProviderTest.java @@ -5,7 +5,9 @@ package software.amazon.nio.spi.s3; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -44,11 +46,11 @@ public void before() { @Test public void initialization() { - final var P = new S3ClientProvider(null); + final var s3ClientProvider = new S3ClientProvider(null); - assertNotNull(P.configuration); + assertNotNull(s3ClientProvider.configuration); - S3AsyncClient t = P.universalClient(); + S3AsyncClient t = s3ClientProvider.universalClient(); assertNotNull(t); var config = new S3NioSpiConfiguration(); @@ -64,6 +66,45 @@ public void testGenerateAsyncClientWithNoErrors() throws ExecutionException, Int assertNotNull(s3Client); } + @Test + public void testGenerateClientIsCacheableClass() throws Exception { + when(mockClient.headBucket(anyConsumer())) + .thenReturn(CompletableFuture.completedFuture( + HeadBucketResponse.builder().bucketRegion("us-west-2").build())); + final var s3Client = provider.generateClient("test-bucket", mockClient); + assertInstanceOf(CacheableS3Client.class, s3Client); + } + + @Test + public void testGenerateClientCachesClients() throws Exception { + when(mockClient.headBucket(anyConsumer())) + .thenReturn(CompletableFuture.completedFuture( + HeadBucketResponse.builder().bucketRegion("us-west-2").build())); + final var s3Client = provider.generateClient("test-bucket", mockClient); + final var s3Client2 = provider.generateClient("test-bucket", mockClient); + assertSame(s3Client, s3Client2); + } + + @Test + public void testClosedClientIsNotReused() throws ExecutionException, InterruptedException { + when(mockClient.headBucket(anyConsumer())) + .thenReturn(CompletableFuture.completedFuture( + HeadBucketResponse.builder().bucketRegion("us-west-2").build())); + + final var s3Client = provider.generateClient("test-bucket", mockClient); + assertNotNull(s3Client); + + // now close the client + s3Client.close(); + + // now generate a new client with the same bucket name + final var s3Client2 = provider.generateClient("test-bucket", mockClient); + assertNotNull(s3Client2); + + // assert it is not the closed client + assertNotSame(s3Client, s3Client2); + } + @Test public void testGenerateAsyncClientWith403Response() throws ExecutionException, InterruptedException { // when you get a forbidden response from HeadBucket