Skip to content

Commit

Permalink
API, AWS: Retry S3InputStream reads (#10433)
Browse files Browse the repository at this point in the history
Co-authored-by: Jack Ye <[email protected]>
Co-authored-by: Xiaoxuan Li <[email protected]>
  • Loading branch information
3 people authored Sep 24, 2024
1 parent 72fd9ab commit c0d73f4
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 20 deletions.
71 changes: 56 additions & 15 deletions aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@
*/
package org.apache.iceberg.aws.s3;

import dev.failsafe.Failsafe;
import dev.failsafe.FailsafeException;
import dev.failsafe.RetryPolicy;
import java.io.IOException;
import java.io.InputStream;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.Arrays;
import javax.net.ssl.SSLException;
import org.apache.iceberg.exceptions.NotFoundException;
import org.apache.iceberg.io.FileIOMetricsContext;
import org.apache.iceberg.io.IOUtil;
Expand All @@ -31,6 +37,7 @@
import org.apache.iceberg.metrics.MetricsContext.Unit;
import org.apache.iceberg.relocated.com.google.common.base.Joiner;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.io.ByteStreams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -57,6 +64,14 @@ class S3InputStream extends SeekableInputStream implements RangeReadable {
private final Counter readOperations;

private int skipSize = 1024 * 1024;
private RetryPolicy<Object> retryPolicy =
RetryPolicy.builder()
.handle(
ImmutableList.of(
SSLException.class, SocketTimeoutException.class, SocketException.class))
.onFailure(failure -> openStream(true))
.withMaxRetries(3)
.build();

S3InputStream(S3Client s3, S3URI location) {
this(s3, location, new S3FileIOProperties(), MetricsContext.nullMetrics());
Expand Down Expand Up @@ -92,27 +107,43 @@ public void seek(long newPos) {
public int read() throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();
try {
int bytesRead = Failsafe.with(retryPolicy).get(() -> stream.read());
pos += 1;
next += 1;
readBytes.increment();
readOperations.increment();

return bytesRead;
} catch (FailsafeException ex) {
if (ex.getCause() instanceof IOException) {
throw (IOException) ex.getCause();
}

pos += 1;
next += 1;
readBytes.increment();
readOperations.increment();

return stream.read();
throw ex;
}
}

@Override
public int read(byte[] b, int off, int len) throws IOException {
Preconditions.checkState(!closed, "Cannot read: already closed");
positionStream();

int bytesRead = stream.read(b, off, len);
pos += bytesRead;
next += bytesRead;
readBytes.increment(bytesRead);
readOperations.increment();
try {
int bytesRead = Failsafe.with(retryPolicy).get(() -> stream.read(b, off, len));
pos += bytesRead;
next += bytesRead;
readBytes.increment(bytesRead);
readOperations.increment();

return bytesRead;
} catch (FailsafeException ex) {
if (ex.getCause() instanceof IOException) {
throw (IOException) ex.getCause();
}

return bytesRead;
throw ex;
}
}

@Override
Expand Down Expand Up @@ -146,7 +177,7 @@ private InputStream readRange(String range) {
public void close() throws IOException {
super.close();
closed = true;
closeStream();
closeStream(false);
}

private void positionStream() throws IOException {
Expand Down Expand Up @@ -178,6 +209,10 @@ private void positionStream() throws IOException {
}

private void openStream() throws IOException {
openStream(false);
}

private void openStream(boolean closeQuietly) throws IOException {
GetObjectRequest.Builder requestBuilder =
GetObjectRequest.builder()
.bucket(location.bucket())
Expand All @@ -186,7 +221,7 @@ private void openStream() throws IOException {

S3RequestUtil.configureEncryption(s3FileIOProperties, requestBuilder);

closeStream();
closeStream(closeQuietly);

try {
stream = s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream());
Expand All @@ -195,14 +230,20 @@ private void openStream() throws IOException {
}
}

private void closeStream() throws IOException {
private void closeStream(boolean closeQuietly) throws IOException {
if (stream != null) {
// if we aren't at the end of the stream, and the stream is abortable, then
// call abort() so we don't read the remaining data with the Apache HTTP client
abortStream();
try {
stream.close();
} catch (IOException e) {
if (closeQuietly) {
stream = null;
LOG.warn("An error occurred while closing the stream", e);
return;
}

// the Apache HTTP client will throw a ConnectionClosedException
// when closing an aborted stream, which is expected
if (!e.getClass().getSimpleName().equals("ConnectionClosedException")) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.aws.s3;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;

import java.io.IOException;
import java.io.InputStream;
import java.net.SocketTimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import javax.net.ssl.SSLException;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.core.sync.ResponseTransformer;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.CreateBucketResponse;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;

public class TestFlakyS3InputStream extends TestS3InputStream {

@ParameterizedTest
@MethodSource("retryableExceptions")
public void testReadWithFlakyStreamRetrySucceed(IOException exception) throws Exception {
testRead(flakyStreamClient(new AtomicInteger(3), exception));
}

@ParameterizedTest
@MethodSource("retryableExceptions")
public void testReadWithFlakyStreamExhaustedRetries(IOException exception) {
assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(5), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
}

@ParameterizedTest
@MethodSource("nonRetryableExceptions")
public void testReadWithFlakyStreamNonRetryableException(IOException exception) {
assertThatThrownBy(() -> testRead(flakyStreamClient(new AtomicInteger(3), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
}

@ParameterizedTest
@MethodSource("retryableExceptions")
public void testSeekWithFlakyStreamRetrySucceed(IOException exception) throws Exception {
testSeek(flakyStreamClient(new AtomicInteger(3), exception));
}

@ParameterizedTest
@MethodSource("retryableExceptions")
public void testSeekWithFlakyStreamExhaustedRetries(IOException exception) {
assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(5), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
}

@ParameterizedTest
@MethodSource("nonRetryableExceptions")
public void testSeekWithFlakyStreamNonRetryableException(IOException exception) {
assertThatThrownBy(() -> testSeek(flakyStreamClient(new AtomicInteger(3), exception)))
.isInstanceOf(exception.getClass())
.hasMessage(exception.getMessage());
}

private static Stream<Arguments> retryableExceptions() {
return Stream.of(
Arguments.of(
new SocketTimeoutException("socket timeout exception"),
new SSLException("some ssl exception")));
}

private static Stream<Arguments> nonRetryableExceptions() {
return Stream.of(Arguments.of(new IOException("some generic non-retryable IO exception")));
}

private S3ClientWrapper flakyStreamClient(AtomicInteger counter, IOException failure) {
S3ClientWrapper flakyClient = spy(new S3ClientWrapper(s3Client()));
doAnswer(invocation -> new FlakyInputStream(invocation.callRealMethod(), counter, failure))
.when(flakyClient)
.getObject(any(GetObjectRequest.class), any(ResponseTransformer.class));
return flakyClient;
}

/** Wrapper for S3 client, used to mock the final class DefaultS3Client */
public static class S3ClientWrapper implements S3Client {

private final S3Client delegate;

public S3ClientWrapper(S3Client delegate) {
this.delegate = delegate;
}

@Override
public String serviceName() {
return delegate.serviceName();
}

@Override
public void close() {
delegate.close();
}

@Override
public <ReturnT> ReturnT getObject(
GetObjectRequest getObjectRequest,
ResponseTransformer<GetObjectResponse, ReturnT> responseTransformer)
throws AwsServiceException, SdkClientException {
return delegate.getObject(getObjectRequest, responseTransformer);
}

@Override
public HeadObjectResponse headObject(HeadObjectRequest headObjectRequest)
throws AwsServiceException, SdkClientException {
return delegate.headObject(headObjectRequest);
}

@Override
public PutObjectResponse putObject(PutObjectRequest putObjectRequest, RequestBody requestBody)
throws AwsServiceException, SdkClientException {
return delegate.putObject(putObjectRequest, requestBody);
}

@Override
public CreateBucketResponse createBucket(CreateBucketRequest createBucketRequest)
throws AwsServiceException, SdkClientException {
return delegate.createBucket(createBucketRequest);
}
}

static class FlakyInputStream extends InputStream {
private final ResponseInputStream<GetObjectResponse> delegate;
private final AtomicInteger counter;
private final int round;
private final IOException exception;

FlakyInputStream(Object invocationResponse, AtomicInteger counter, IOException exception) {
this.delegate = (ResponseInputStream<GetObjectResponse>) invocationResponse;
this.counter = counter;
this.round = counter.get();
this.exception = exception;
}

private void checkCounter() throws IOException {
// for every round of n invocations, only the last call succeeds
if (counter.decrementAndGet() == 0) {
counter.set(round);
} else {
throw exception;
}
}

@Override
public int read() throws IOException {
checkCounter();
return delegate.read();
}

@Override
public int read(byte[] b) throws IOException {
checkCounter();
return delegate.read(b);
}

@Override
public int read(byte[] b, int off, int len) throws IOException {
checkCounter();
return delegate.read(b, off, len);
}

@Override
public void close() throws IOException {
delegate.close();
}
}
}
Loading

0 comments on commit c0d73f4

Please sign in to comment.