From 21b52de93f0ded611582cf3a86b02ec1ce4e77dd Mon Sep 17 00:00:00 2001 From: Christoph Briem Date: Tue, 14 May 2024 12:24:37 +0200 Subject: [PATCH] fix: close() call on already closed S3WritableByteChannel (awslabs#453) --- .../nio/spi/s3/S3WritableByteChannel.java | 4 ++++ .../nio/spi/s3/S3WritableByteChannelTest.java | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/main/java/software/amazon/nio/spi/s3/S3WritableByteChannel.java b/src/main/java/software/amazon/nio/spi/s3/S3WritableByteChannel.java index 6cc74b52..21ffc41f 100644 --- a/src/main/java/software/amazon/nio/spi/s3/S3WritableByteChannel.java +++ b/src/main/java/software/amazon/nio/spi/s3/S3WritableByteChannel.java @@ -87,6 +87,10 @@ public boolean isOpen() { @Override public void close() throws IOException { channel.close(); + if (!open) { + // channel has already been closed -> close() should have no effect + return; + } s3TransferUtil.uploadLocalFile(path, tempFile); Files.deleteIfExists(tempFile); diff --git a/src/test/java/software/amazon/nio/spi/s3/S3WritableByteChannelTest.java b/src/test/java/software/amazon/nio/spi/s3/S3WritableByteChannelTest.java index ffac2afd..fe920600 100644 --- a/src/test/java/software/amazon/nio/spi/s3/S3WritableByteChannelTest.java +++ b/src/test/java/software/amazon/nio/spi/s3/S3WritableByteChannelTest.java @@ -30,6 +30,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import software.amazon.awssdk.services.s3.S3AsyncClient; @@ -125,6 +127,24 @@ void tmpFileIsCleanedUpAfterClose(@TempDir Path tempDir) throws InterruptedExcep assertThat(countAfterClosing).isLessThan(countAfterOpening); } + @Test + @DisplayName("second close() call should be a no-op") + void secondCloseIsNoOp() throws InterruptedException, TimeoutException, IOException { + S3FileSystemProvider provider = mock(); + when(provider.exists(any(S3AsyncClient.class), any())).thenReturn(false); + S3FileSystem fs = mock(); + when(fs.provider()).thenReturn(provider); + var file = S3Path.getPath(fs, "somefile"); + + S3TransferUtil utilMock = mock(); + var channel = new S3WritableByteChannel(file, mock(), utilMock, Set.of(CREATE)); + channel.close(); + // this close() call should be a no-op + channel.close(); + + verify(utilMock, times(1)).uploadLocalFile(any(), any()); + } + private long countTemporaryFiles(Path tempDir) throws IOException { try (var list = Files.list(tempDir.getParent())) { return list