Skip to content

Commit

Permalink
Reuse s3 transfer manager
Browse files Browse the repository at this point in the history
  • Loading branch information
hiboyang committed Dec 16, 2021
1 parent 8222f38 commit 761fe2a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.starshuffle;

import com.amazonaws.ClientConfiguration;
import com.amazonaws.client.builder.ExecutorFactory;
import com.amazonaws.event.ProgressEvent;
import com.amazonaws.event.ProgressListener;
import com.amazonaws.regions.Regions;
Expand All @@ -38,6 +39,7 @@
import java.io.*;
import java.net.URI;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

/**
Expand All @@ -53,15 +55,34 @@ public class StarS3ShuffleFileManager implements StarShuffleFileManager {
// https://github.com/apache/hadoop/blob/6c6d1b64d4a7cd5288fcded78043acaf23228f96/hadoop-tools/hadoop-aws/src/main/java/org/apache/hadoop/fs/s3a/Constants.java
// Part size handed to the TransferManager builder for multipart uploads and copies.
public static final long DEFAULT_MULTIPART_SIZE = 67108864; // 64 MiB
// Size threshold handed to the TransferManager builder for multipart uploads and copies.
public static final long DEFAULT_MIN_MULTIPART_THRESHOLD = 134217728; // 128 MiB
// Configuration key for the number of transfer worker threads; values below 2 are forced to 2.
public static final String MAX_THREADS = "fs.s3a.threads.max";
// Thread count used when MAX_THREADS is not configured.
public static final int DEFAULT_MAX_THREADS = 10;
// Configuration key for the worker thread keep-alive time; the value is passed to the pool in seconds.
public static final String KEEPALIVE_TIME = "fs.s3a.threads.keepalivetime";
// Keep-alive time (seconds) used when KEEPALIVE_TIME is not configured.
public static final int DEFAULT_KEEPALIVE_TIME = 60;

// Configuration key for the AWS region the S3 client is built for.
public static final String AWS_REGION = "fs.s3a.endpoint.region";
// Region used when AWS_REGION is not configured.
// NOTE(review): the two declarations below differ only in modifier order and look like the
// before/after pair of one diff line - confirm that only one of them exists in the real file.
public final static String DEFAULT_AWS_REGION = Regions.US_WEST_2.getName();
public static final String DEFAULT_AWS_REGION = Regions.US_WEST_2.getName();

// Process-wide TransferManager shared by every instance of this class. It is created lazily
// and reset to null on shutdown; all reads and writes happen while holding transferManagerLock.
private static TransferManager transferManager;
// Monitor guarding transferManager. Declared final so the lock identity can never be swapped
// out from under threads that are synchronizing on it.
private static final Object transferManagerLock = new Object();

// Region the S3 client is built for.
private final String awsRegion;
// Number of worker threads for the transfer thread pool (never less than 2).
private final int maxThreads;
// Keep-alive time of transfer worker threads, in seconds.
private final long keepAliveTime;

/**
 * Creates a file manager whose S3 settings come from the Hadoop configuration derived from
 * the given Spark configuration.
 *
 * <p>All {@code fs.s3a.*} settings (region, transfer thread count, thread keep-alive) are
 * looked up in the Hadoop configuration, so values supplied as {@code spark.hadoop.fs.s3a.*}
 * or through Hadoop site files are honored. For backward compatibility, a thread count or
 * keep-alive set directly on {@code conf} under the bare {@code fs.s3a.*} key is used as the
 * fallback when the Hadoop configuration has no value.
 *
 * @param conf Spark configuration used to build the Hadoop configuration
 */
public StarS3ShuffleFileManager(SparkConf conf) {
    Configuration hadoopConf = SparkHadoopUtil.get().newConfiguration(conf);

    awsRegion = hadoopConf.get(AWS_REGION, DEFAULT_AWS_REGION);

    // Read the pool size from the same source as the region; previously it was read from
    // SparkConf only, where a "fs.s3a.*" key is normally never present.
    int threads = hadoopConf.getInt(MAX_THREADS, conf.getInt(MAX_THREADS, DEFAULT_MAX_THREADS));
    if (threads < 2) {
        logger.warn(MAX_THREADS + " must be at least 2: forcing to 2.");
        threads = 2;
    }
    maxThreads = threads;

    // getTimeDuration accepts both a bare number (interpreted as seconds) and a value with a
    // unit suffix such as "60s", so a suffixed value does not fail number parsing.
    long keepAlive = hadoopConf.getTimeDuration(
        KEEPALIVE_TIME,
        conf.getLong(KEEPALIVE_TIME, DEFAULT_KEEPALIVE_TIME),
        TimeUnit.SECONDS);
    if (keepAlive < 0) {
        // A negative keep-alive would make the thread pool constructor throw later on.
        logger.warn(KEEPALIVE_TIME + " must not be negative: using default " + DEFAULT_KEEPALIVE_TIME + ".");
        keepAlive = DEFAULT_KEEPALIVE_TIME;
    }
    keepAliveTime = keepAlive;
}

@Override
Expand Down Expand Up @@ -92,7 +113,7 @@ private void writeS3(InputStream inputStream, long size, String s3Url) {
String bucket = bucketAndKey.getBucket();
String key = bucketAndKey.getKey();

TransferManager transferManager = createTransferManager();
TransferManager transferManager = getTransferManager();

ObjectMetadata metadata = new ObjectMetadata();
metadata.setContentType("application/octet-stream");
Expand Down Expand Up @@ -142,27 +163,6 @@ public void progressChanged(ProgressEvent progressEvent) {
}
}

/**
 * Builds a brand-new TransferManager on top of a freshly created S3 client.
 *
 * <p>The client targets {@code awsRegion} and uses {@code S3_PUT_TIMEOUT_MILLISEC} for its
 * connection, request, socket and client-execution timeouts. Multipart upload and copy use
 * {@code DEFAULT_MULTIPART_SIZE} as part size and {@code DEFAULT_MIN_MULTIPART_THRESHOLD} as
 * threshold.
 *
 * @return a new TransferManager; every call creates another client and manager
 */
private TransferManager createTransferManager() {
    // One timeout value is applied to every phase of a request.
    ClientConfiguration timeouts = new ClientConfiguration();
    timeouts.setClientExecutionTimeout(S3_PUT_TIMEOUT_MILLISEC);
    timeouts.setSocketTimeout(S3_PUT_TIMEOUT_MILLISEC);
    timeouts.setRequestTimeout(S3_PUT_TIMEOUT_MILLISEC);
    timeouts.setConnectionTimeout(S3_PUT_TIMEOUT_MILLISEC);

    AmazonS3 client = AmazonS3ClientBuilder.standard()
        .withClientConfiguration(timeouts)
        .withRegion(awsRegion)
        .build();

    TransferManagerBuilder managerBuilder = TransferManagerBuilder.standard().withS3Client(client);
    // Upload settings.
    managerBuilder = managerBuilder
        .withMultipartUploadThreshold(DEFAULT_MIN_MULTIPART_THRESHOLD)
        .withMinimumUploadPartSize(DEFAULT_MULTIPART_SIZE);
    // Copy settings.
    managerBuilder = managerBuilder
        .withMultipartCopyThreshold(DEFAULT_MIN_MULTIPART_THRESHOLD)
        .withMultipartCopyPartSize(DEFAULT_MULTIPART_SIZE);
    return managerBuilder.build();
}

private InputStream readS3(String s3Url, long offset, long size) {
logger.info("Downloading shuffle file from s3: {}, size: {}", s3Url, size);

Expand All @@ -175,7 +175,7 @@ private InputStream readS3(String s3Url, long offset, long size) {
throw new RuntimeException("Failed to create temp file for downloading shuffle file");
}

TransferManager transferManager = createTransferManager();
TransferManager transferManager = getTransferManager();

GetObjectRequest getObjectRequest = new GetObjectRequest(bucketAndKey.getBucket(), bucketAndKey.getKey())
.withRange(offset, offset + size);
Expand Down Expand Up @@ -225,6 +225,68 @@ public void progressChanged(ProgressEvent progressEvent) {
}
}

/**
 * Returns the TransferManager shared by all instances, creating it on first use.
 *
 * <p>The shared manager is built from the region, thread count and keep-alive of whichever
 * instance asks first; later callers receive that same manager regardless of their own
 * settings. Access is serialized on {@code transferManagerLock}.
 *
 * @return the shared TransferManager, never {@code null}
 */
private TransferManager getTransferManager() {
    synchronized (transferManagerLock) {
        if (transferManager == null) {
            transferManager = createTransferManager(awsRegion, maxThreads, keepAliveTime);
        }
        return transferManager;
    }
}

/**
 * Builds a TransferManager backed by a new S3 client and a dedicated worker thread pool.
 *
 * <p>The S3 client uses {@code S3_PUT_TIMEOUT_MILLISEC} for its connection, request, socket
 * and client-execution timeouts. Multipart upload and copy use {@code DEFAULT_MULTIPART_SIZE}
 * as part size and {@code DEFAULT_MIN_MULTIPART_THRESHOLD} as threshold.
 *
 * @param region        AWS region the S3 client is built for
 * @param maxThreads    number of worker threads in the transfer pool; must be positive
 * @param keepAliveTime idle time in seconds after which a worker thread is released; when
 *                      zero, idle workers are kept alive
 * @return a new TransferManager
 * @throws IllegalArgumentException if {@code maxThreads} is not positive or
 *                                  {@code keepAliveTime} is negative
 */
private static TransferManager createTransferManager(String region, int maxThreads, long keepAliveTime) {
    ClientConfiguration clientConfiguration = new ClientConfiguration();
    clientConfiguration.setConnectionTimeout(S3_PUT_TIMEOUT_MILLISEC);
    clientConfiguration.setRequestTimeout(S3_PUT_TIMEOUT_MILLISEC);
    clientConfiguration.setSocketTimeout(S3_PUT_TIMEOUT_MILLISEC);
    clientConfiguration.setClientExecutionTimeout(S3_PUT_TIMEOUT_MILLISEC);

    ThreadFactory threadFactory = new ThreadFactory() {
        // The pool may ask for new threads from several submitting threads at once, so the
        // counter has to be atomic to keep thread names unique.
        private final AtomicLong threadCount = new AtomicLong(1);

        @Override
        public Thread newThread(Runnable r) {
            Thread thread = new Thread(r);
            thread.setName("s3-shuffle-transfer-manager-worker-" + threadCount.getAndIncrement());
            // Daemon workers cannot keep the JVM alive if the shared manager is never shut down.
            thread.setDaemon(true);
            return thread;
        }
    };
    // With an unbounded queue the pool never grows past its core size, so the maximum is set
    // to the same value instead of a misleading Integer.MAX_VALUE.
    ThreadPoolExecutor threadPoolExecutor = new ThreadPoolExecutor(
        maxThreads, maxThreads,
        keepAliveTime, TimeUnit.SECONDS,
        new LinkedBlockingQueue<>(),
        threadFactory);
    // Keep-alive applies only to threads above the core size unless core timeout is enabled;
    // without this the configured keep-alive would have no effect at all. Enabling it
    // requires a strictly positive keep-alive.
    if (keepAliveTime > 0) {
        threadPoolExecutor.allowCoreThreadTimeOut(true);
    }
    ExecutorFactory executorFactory = new ExecutorFactory() {
        @Override
        public ExecutorService newExecutor() {
            return threadPoolExecutor;
        }
    };

    AmazonS3 s3Client = AmazonS3ClientBuilder.standard()
        .withRegion(region)
        .withClientConfiguration(clientConfiguration)
        .build();

    return TransferManagerBuilder.standard()
        .withS3Client(s3Client)
        .withMinimumUploadPartSize(DEFAULT_MULTIPART_SIZE)
        .withMultipartUploadThreshold(DEFAULT_MIN_MULTIPART_THRESHOLD)
        .withMultipartCopyPartSize(DEFAULT_MULTIPART_SIZE)
        .withMultipartCopyThreshold(DEFAULT_MIN_MULTIPART_THRESHOLD)
        .withExecutorFactory(executorFactory)
        .build();
}

/**
 * Shuts down the shared TransferManager, if one has been created, and forgets it.
 *
 * <p>Does nothing when no manager exists. The manager is shut down with
 * {@code shutdownNow(true)} and the static reference is cleared afterwards, so a later
 * transfer creates a fresh manager. Access is serialized on {@code transferManagerLock}.
 */
public static void shutdownTransferManager() {
    synchronized (transferManagerLock) {
        if (transferManager != null) {
            transferManager.shutdownNow(true);
            transferManager = null;
        }
    }
}

public static class S3BucketAndKey {
private String bucket;
private String key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@
package org.apache.spark.shuffle

import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.internal.{Logging, config}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.starshuffle.{StarBlockStoreClient, StarBypassMergeSortShuffleWriter}
import org.apache.spark.starshuffle.{StarBlockStoreClient, StarBypassMergeSortShuffleWriter, StarS3ShuffleFileManager}
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, StarShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.{ExternalSorter, OpenHashSet}
Expand Down Expand Up @@ -116,6 +114,9 @@ class StarShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
/**
 * Shut down this ShuffleManager.
 *
 * Stops the shuffle block resolver first, then shuts down the TransferManager that
 * StarS3ShuffleFileManager holds in a static field.
 */
override def stop(): Unit = {
shuffleBlockResolver.stop()

// The TransferManager lives in a static field of StarS3ShuffleFileManager, so it is shut down
// through a static call here rather than through an instance owned by this manager.
// TODO: find a better way to shut down the TransferManager in StarS3ShuffleFileManager.
StarS3ShuffleFileManager.shutdownTransferManager()
}
}

Expand Down

0 comments on commit 761fe2a

Please sign in to comment.