diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 6b1d879d06183..757b8f7b545b2 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -70,9 +70,6 @@ public class ExternalShuffleBlockResolver {
   private static final String APP_KEY_PREFIX = "AppExecShuffleInfo";
   private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0);
 
-  // TODO: Dont necessarily write to local
-  private final File shuffleDir;
-
   private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}");
 
   // Map containing all registered executors' metadata.
@@ -96,8 +93,8 @@ public class ExternalShuffleBlockResolver {
   final DB db;
 
   private final List<String> knownManagers = Arrays.asList(
-    "org.apache.spark.shuffle.sort.SortShuffleManager",
-    "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager");
+      "org.apache.spark.shuffle.sort.SortShuffleManager",
+      "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager");
 
   public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile)
       throws IOException {
@@ -136,9 +133,6 @@ public int weigh(File file, ShuffleIndexInformation indexInfo) {
       executors = Maps.newConcurrentMap();
     }
 
-    // TODO: Remove local writes
-    this.shuffleDir = Files.createTempDirectory("spark-shuffle-dir").toFile();
-
     this.directoryCleaner = directoryCleaner;
   }
 
@@ -146,7 +140,6 @@ public int getRegisteredExecutorsSize() {
     return executors.size();
   }
 
-
   /** Registers a new Executor with all the configuration we need to find its shuffle files. */
   public void registerExecutor(
       String appId,
@@ -313,8 +306,8 @@ private ManagedBuffer getSortBasedShuffleBlockData(
    * Hashes a filename into the corresponding local directory, in a manner consistent with
    * Spark's DiskBlockManager.getFile().
    */
-
-  public static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) {
+  @VisibleForTesting
+  static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) {
    int hash = JavaUtils.nonNegativeHash(filename);
    String localDir = localDirs[hash % localDirs.length];
    int subDirId = (hash / localDirs.length) % subDirsPerLocalDir;
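The getFile() hunk above keeps the shuffle service's directory hashing aligned with what DiskBlockManager.getFile() does on the executor side. A minimal standalone sketch of that layout logic follows; the directory roots, sub-dir count, and filename are illustrative, and the local nonNegativeHash helper mirrors JavaUtils.nonNegativeHash:

    import java.io.File;

    public class ShuffleFileLayoutSketch {
      // Mirrors JavaUtils.nonNegativeHash: hashCode clamped to a non-negative int.
      static int nonNegativeHash(Object obj) {
        int hash = obj.hashCode();
        return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0;
      }

      public static void main(String[] args) {
        String[] localDirs = {"/tmp/dir0", "/tmp/dir1"};  // illustrative local roots
        int subDirsPerLocalDir = 64;
        String filename = "shuffle_0_1_0.index";          // illustrative block file

        int hash = nonNegativeHash(filename);
        String localDir = localDirs[hash % localDirs.length];
        int subDirId = (hash / localDirs.length) % subDirsPerLocalDir;
        // Same two-level layout DiskBlockManager uses: <localDir>/<2-hex-digit subdir>/<file>
        System.out.println(new File(new File(localDir, String.format("%02x", subDirId)), filename));
      }
    }

Because both sides compute the same (localDir, subDir) pair from the filename alone, the service can locate index and data files without the executor shipping paths over the wire.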
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index b2b0f3f9796cb..f5196638f9140 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -41,7 +41,7 @@ public abstract class BlockTransferMessage implements Encodable {
   public enum Type {
     OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
     HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_PARTITION_STREAM(7),
-    UPLOAD_SHUFFLE_INDEX_STREAM(8), OPEN_SHUFFLE_PARTITION(9);
+    REGISTER_SHUFFLE_INDEX(8), OPEN_SHUFFLE_PARTITION(9), UPLOAD_SHUFFLE_INDEX(10);
 
     private final byte id;
 
@@ -68,8 +68,9 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
       case 5: return ShuffleServiceHeartbeat.decode(buf);
       case 6: return UploadBlockStream.decode(buf);
       case 7: return UploadShufflePartitionStream.decode(buf);
-      case 8: return UploadShuffleIndexStream.decode(buf);
+      case 8: return RegisterShuffleIndex.decode(buf);
       case 9: return OpenShufflePartition.decode(buf);
+      case 10: return UploadShuffleIndex.decode(buf);
       default: throw new IllegalArgumentException("Unknown message type: " + type);
     }
   }
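Note that REGISTER_SHUFFLE_INDEX takes over wire ID 8 (previously UPLOAD_SHUFFLE_INDEX_STREAM) and UPLOAD_SHUFFLE_INDEX is appended as ID 10, so client and shuffle service have to be upgraded together; an old client's ID 8 would now decode as a different message. A quick round-trip check through the dispatch table above (a sketch; the appId is made up):

    import java.nio.ByteBuffer;
    import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
    import org.apache.spark.network.shuffle.protocol.RegisterShuffleIndex;

    public class ProtocolRoundTrip {
      public static void main(String[] args) {
        // Encode to [type byte][payload], then decode back through fromByteBuffer's switch.
        ByteBuffer buf = new RegisterShuffleIndex("app-20190401-0001", 0, 7).toByteBuffer();
        BlockTransferMessage decoded = BlockTransferMessage.fromByteBuffer(buf);
        assert decoded instanceof RegisterShuffleIndex;
      }
    }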
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java
new file mode 100644
index 0000000000000..27f101171834b
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/**
+ * Registers a map task's shuffle index with the External Shuffle Service.
+ */
+public class RegisterShuffleIndex extends BlockTransferMessage {
+  public final String appId;
+  public final int shuffleId;
+  public final int mapId;
+
+  public RegisterShuffleIndex(
+      String appId,
+      int shuffleId,
+      int mapId) {
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.mapId = mapId;
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other != null && other instanceof RegisterShuffleIndex) {
+      RegisterShuffleIndex o = (RegisterShuffleIndex) other;
+      return Objects.equal(appId, o.appId)
+        && shuffleId == o.shuffleId
+        && mapId == o.mapId;
+    }
+    return false;
+  }
+
+  @Override
+  protected Type type() {
+    return Type.REGISTER_SHUFFLE_INDEX;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(appId, shuffleId, mapId);
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("appId", appId)
+      .add("shuffleId", shuffleId)
+      .add("mapId", mapId)
+      .toString();
+  }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.Strings.encodedLength(appId) + 4 + 4;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, appId);
+    buf.writeInt(shuffleId);
+    buf.writeInt(mapId);
+  }
+
+  public static RegisterShuffleIndex decode(ByteBuf buf) {
+    String appId = Encoders.Strings.decode(buf);
+    int shuffleId = buf.readInt();
+    int mapId = buf.readInt();
+    return new RegisterShuffleIndex(appId, shuffleId, mapId);
+  }
+}
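One pitfall worth a unit test for messages like this: encodedLength() must report exactly the number of bytes encode() writes, because the transport layer sizes frames from it. A sketch of that contract check (values are arbitrary):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import org.apache.spark.network.shuffle.protocol.RegisterShuffleIndex;

    public class EncodeContractCheck {
      public static void main(String[] args) {
        RegisterShuffleIndex msg = new RegisterShuffleIndex("app-20190401-0001", 2, 3);
        ByteBuf buf = Unpooled.buffer(msg.encodedLength());
        msg.encode(buf);
        // encode() must fill exactly encodedLength() bytes...
        assert buf.readableBytes() == msg.encodedLength();
        // ...and decode() must be its inverse.
        assert RegisterShuffleIndex.decode(buf).equals(msg);
      }
    }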
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java
similarity index 90%
rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java
rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java
index ffa7ee36881c8..374b399621aae 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java
@@ -27,12 +27,12 @@
 /**
  * Upload shuffle index request to the External Shuffle Service.
  */
-public class UploadShuffleIndexStream extends BlockTransferMessage {
+public class UploadShuffleIndex extends BlockTransferMessage {
   public final String appId;
   public final int shuffleId;
   public final int mapId;
 
-  public UploadShuffleIndexStream(
+  public UploadShuffleIndex(
       String appId,
       int shuffleId,
       int mapId) {
@@ -54,7 +54,7 @@ public boolean equals(Object other) {
 
   @Override
   protected Type type() {
-    return Type.UPLOAD_SHUFFLE_INDEX_STREAM;
+    return Type.UPLOAD_SHUFFLE_INDEX;
   }
 
   @Override
@@ -83,10 +83,10 @@ public void encode(ByteBuf buf) {
     buf.writeInt(mapId);
   }
 
-  public static UploadShuffleIndexStream decode(ByteBuf buf) {
+  public static UploadShuffleIndex decode(ByteBuf buf) {
     String appId = Encoders.Strings.decode(buf);
     int shuffleId = buf.readInt();
     int mapId = buf.readInt();
-    return new UploadShuffleIndexStream(appId, shuffleId, mapId);
+    return new UploadShuffleIndex(appId, shuffleId, mapId);
   }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java
index f0506cc08feb7..ad8f5405192fc 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java
@@ -32,16 +32,19 @@ public class UploadShufflePartitionStream extends BlockTransferMessage {
   public final int shuffleId;
   public final int mapId;
   public final int partitionId;
+  public final int partitionLength;
 
   public UploadShufflePartitionStream(
       String appId,
       int shuffleId,
       int mapId,
-      int partitionId) {
+      int partitionId,
+      int partitionLength) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.mapId = mapId;
     this.partitionId = partitionId;
+    this.partitionLength = partitionLength;
   }
 
   @Override
@@ -51,7 +54,8 @@ public boolean equals(Object other) {
       return Objects.equal(appId, o.appId)
         && shuffleId == o.shuffleId
         && mapId == o.mapId
-        && partitionId == o.partitionId;
+        && partitionId == o.partitionId
+        && partitionLength == o.partitionLength;
     }
     return false;
   }
@@ -63,7 +67,7 @@ protected Type type() {
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(appId, shuffleId, mapId, partitionId);
+    return Objects.hashCode(appId, shuffleId, mapId, partitionId, partitionLength);
   }
 
   @Override
@@ -72,12 +76,14 @@ public String toString() {
       .add("appId", appId)
       .add("shuffleId", shuffleId)
       .add("mapId", mapId)
+      .add("partitionId", partitionId)
+      .add("partitionLength", partitionLength)
       .toString();
   }
 
   @Override
   public int encodedLength() {
-    return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4;
+    return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4;
   }
 
   @Override
@@ -86,6 +92,7 @@ public void encode(ByteBuf buf) {
     buf.writeInt(shuffleId);
     buf.writeInt(mapId);
     buf.writeInt(partitionId);
+    buf.writeInt(partitionLength);
   }
 
   public static UploadShufflePartitionStream decode(ByteBuf buf) {
@@ -93,6 +100,8 @@ public static UploadShufflePartitionStream decode(ByteBuf buf) {
     int shuffleId = buf.readInt();
     int mapId = buf.readInt();
     int partitionId = buf.readInt();
-    return new UploadShufflePartitionStream(appId, shuffleId, mapId, partitionId);
+    int partitionLength = buf.readInt();
+    return new UploadShufflePartitionStream(
+      appId, shuffleId, mapId, partitionId, partitionLength);
   }
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
index f0f7d5ade6024..06415dba72d34 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -21,7 +21,7 @@ public interface ShuffleMapOutputWriter {
 
   ShufflePartitionWriter newPartitionWriter(int partitionId);
 
-  void commitAllPartitions(long[] partitionLengths);
+  void commitAllPartitions();
 
   void abort(Exception exception);
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
index da35ac76f343f..22a1d3336615c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java
@@ -2,7 +2,9 @@
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkEnv;
+import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.netty.SparkTransportConf;
+import org.apache.spark.network.server.NoOpRpcHandler;
 import org.apache.spark.network.util.TransportConf;
 import org.apache.spark.shuffle.api.ShuffleDataIO;
 import org.apache.spark.shuffle.api.ShuffleReadSupport;
@@ -12,44 +14,41 @@
 
 public class ExternalShuffleDataIO implements ShuffleDataIO {
 
-  private static final String SHUFFLE_SERVICE_PORT_CONFIG = "spark.shuffle.service.port";
-  private static final String DEFAULT_SHUFFLE_PORT = "7337";
-
-  private static final SparkEnv sparkEnv = SparkEnv.get();
-  private static final BlockManager blockManager = sparkEnv.blockManager();
-
-  private final SparkConf sparkConf;
   private final TransportConf conf;
-  private final SecurityManager securityManager;
-  private final String hostname;
-  private final int port;
+  private final TransportContext context;
+  private static BlockManager blockManager;
+  private static SecurityManager securityManager;
+  private static String hostname;
+  private static int port;
 
   public ExternalShuffleDataIO(
       SparkConf sparkConf) {
-    this.sparkConf = sparkConf;
     this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 1);
-
-    this.securityManager = sparkEnv.securityManager();
-    this.hostname = blockManager.getRandomShuffleHost();
-    this.port = blockManager.getRandomShufflePort();
+    // Close idle connections
+    this.context = new TransportContext(conf, new NoOpRpcHandler(), true, true);
   }
 
   @Override
   public void initialize() {
-    // TODO: move registerDriver and registerExecutor here
+    SparkEnv env = SparkEnv.get();
+    blockManager = env.blockManager();
+    securityManager = env.securityManager();
+    hostname = blockManager.getRandomShuffleHost();
+    port = blockManager.getRandomShufflePort();
+    // TODO: Register Driver and Executor
   }
 
   @Override
   public ShuffleReadSupport readSupport() {
     return new ExternalShuffleReadSupport(
-      conf, securityManager.isAuthenticationEnabled(),
-      securityManager, hostname, port);
+      conf, context, securityManager.isAuthenticationEnabled(),
+      securityManager, hostname, port);
   }
 
   @Override
   public ShuffleWriteSupport writeSupport() {
     return new ExternalShuffleWriteSupport(
-      conf, securityManager.isAuthenticationEnabled(),
-      securityManager, hostname, port);
+      conf, context, securityManager.isAuthenticationEnabled(),
+      securityManager, hostname, port);
   }
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java
deleted file mode 100644
index fece52b05fce2..0000000000000
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java
+++ /dev/null
@@ -1,78 +0,0 @@
-package org.apache.spark.shuffle.external;
-
-import org.apache.spark.network.buffer.NioManagedBuffer;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.shuffle.protocol.UploadShuffleIndexStream;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.nio.ByteBuffer;
-import java.nio.LongBuffer;
-
-public class ExternalShuffleIndexWriter {
-
-  private final TransportClientFactory clientFactory;
-  private final String hostName;
-  private final int port;
-  private final String appId;
-  private final int shuffleId;
-  private final int mapId;
-
-  public ExternalShuffleIndexWriter(
-      TransportClientFactory clientFactory,
-      String hostName,
-      int port,
-      String appId,
-      int shuffleId,
-      int mapId){
-    this.clientFactory = clientFactory;
-    this.hostName = hostName;
-    this.port = port;
-    this.appId = appId;
-    this.shuffleId = shuffleId;
-    this.mapId = mapId;
-  }
-
-  private static final Logger logger =
-    LoggerFactory.getLogger(ExternalShuffleIndexWriter.class);
-
-  public void write(long[] partitionLengths) {
-    RpcResponseCallback callback = new RpcResponseCallback() {
-      @Override
-      public void onSuccess(ByteBuffer response) {
-        logger.info("Successfully uploaded index");
-      }
-
-      @Override
-      public void onFailure(Throwable e) {
-        logger.error("Encountered an error uploading index", e);
-      }
-    };
-    TransportClient client = null;
-    try {
-      logger.info("Committing all partitions with a creation of an index file");
-      logger.info("Partition Lengths: " + partitionLengths.length);
-      ByteBuffer streamHeader = new UploadShuffleIndexStream(
-        appId, shuffleId, mapId).toByteBuffer();
-      // Size includes first 0L offset
-      ByteBuffer byteBuffer = ByteBuffer.allocate(8 + (partitionLengths.length * 8));
-      LongBuffer longBuffer = byteBuffer.asLongBuffer();
-      Long offset = 0L;
-      longBuffer.put(offset);
-      for (Long length: partitionLengths) {
-        offset += length;
-        longBuffer.put(offset);
-      }
-      client = clientFactory.createUnmanagedClient(hostName, port);
-      client.setClientId(String.format("index-%s-%d-%d", appId, shuffleId, mapId));
-      logger.info("clientid: " + client.getClientId() + " " + client.isActive());
-      client.uploadStream(new NioManagedBuffer(streamHeader),
-        new NioManagedBuffer(byteBuffer), callback);
-    } catch (Exception e) {
-      client.close();
-      logger.error("Encountered error while creating transport client", e);
-    }
-  }
-}
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java
index 58c917bdffdbb..34a11ce2b2a32 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java
@@ -1,11 +1,16 @@
 package org.apache.spark.shuffle.external;
 
+import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.shuffle.protocol.RegisterShuffleIndex;
+import org.apache.spark.network.shuffle.protocol.UploadShuffleIndex;
 import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.nio.ByteBuffer;
+
 public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter {
 
   private final TransportClientFactory clientFactory;
@@ -28,6 +33,22 @@ public ExternalShuffleMapOutputWriter(
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.mapId = mapId;
+
+    TransportClient client = null;
+    try {
+      client = clientFactory.createUnmanagedClient(hostName, port);
+      ByteBuffer registerShuffleIndex = new RegisterShuffleIndex(
+        appId, shuffleId, mapId).toByteBuffer();
+      String requestID = String.format(
+        "index-register-%s-%d-%d", appId, shuffleId, mapId);
+      client.setClientId(requestID);
+      logger.info("clientid: " + client.getClientId() + " " + client.isActive());
+      client.sendRpcSync(registerShuffleIndex, 60000);
+    } catch (Exception e) {
+      if (client != null) { client.close(); }
+      logger.error("Encountered error while registering shuffle index", e);
+      throw new RuntimeException(e);
+    }
   }
 
   private static final Logger logger =
@@ -46,16 +67,21 @@ public ShufflePartitionWriter newPartitionWriter(int partitionId) {
   }
 
   @Override
-  public void commitAllPartitions(long[] partitionLengths) {
+  public void commitAllPartitions() {
+    TransportClient client = null;
     try {
-      ExternalShuffleIndexWriter externalShuffleIndexWriter =
-        new ExternalShuffleIndexWriter(clientFactory,
-          hostName, port, appId, shuffleId, mapId);
-      externalShuffleIndexWriter.write(partitionLengths);
+      client = clientFactory.createUnmanagedClient(hostName, port);
+      ByteBuffer uploadShuffleIndex = new UploadShuffleIndex(
+        appId, shuffleId, mapId).toByteBuffer();
+      String requestID = String.format(
+        "index-upload-%s-%d-%d", appId, shuffleId, mapId);
+      client.setClientId(requestID);
+      logger.info("clientid: " + client.getClientId() + " " + client.isActive());
+      client.sendRpcSync(uploadShuffleIndex, 60000);
     } catch (Exception e) {
-      clientFactory.close();
-      logger.error("Encountered error writing index file", e);
-      throw new RuntimeException(e); // what is standard practice here?
+      if (client != null) { client.close(); }
+      logger.error("Encountered error while uploading shuffle index", e);
+      throw new RuntimeException(e);
     }
   }
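With ExternalShuffleIndexWriter folded into the map-output writer, the writer itself now drives the whole RPC sequence: RegisterShuffleIndex when it is constructed, one UploadShufflePartitionStream carrying each partition's bytes and length, and UploadShuffleIndex on commit. A sketch of the calling side (the loop body is elided; writeSupport, numPartitions, and writeMapOutput are illustrative):

    import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
    import org.apache.spark.shuffle.api.ShufflePartitionWriter;
    import org.apache.spark.shuffle.api.ShuffleWriteSupport;

    public class MapOutputWriterLifecycle {
      static void writeMapOutput(
          ShuffleWriteSupport writeSupport,
          String appId, int shuffleId, int mapId, int numPartitions) throws Exception {
        // Construction sends RegisterShuffleIndex so the service starts tracking lengths.
        ShuffleMapOutputWriter writer = writeSupport.newMapOutputWriter(appId, shuffleId, mapId);
        try {
          for (int p = 0; p < numPartitions; p++) {
            ShufflePartitionWriter partitionWriter = writer.newPartitionWriter(p);
            // ... stream this partition's bytes; the partition length rides in the
            // UploadShufflePartitionStream header so the service can build the index ...
          }
          writer.commitAllPartitions();  // sends UploadShuffleIndex; service writes the index
        } catch (Exception e) {
          writer.abort(e);
          throw e;
        }
      }
    }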
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java
index 1c78f186225f4..d9b7d7ac515df 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java
@@ -65,10 +65,10 @@ public void onFailure(Throwable e) {
     };
     TransportClient client = null;
     try {
-      ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId,
-        partitionId).toByteBuffer();
       byte[] buf = partitionBuffer.toByteArray();
       int size = buf.length;
+      ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId,
+        partitionId, size).toByteBuffer();
       ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf));
       client = clientFactory.createUnmanagedClient(hostName, port);
       client.setClientId(String.format("data-%s-%d-%d-%d",
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
index ddff937d47c25..2687c2a4e2379 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java
@@ -6,7 +6,6 @@
 import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.crypto.AuthClientBootstrap;
 import org.apache.spark.network.sasl.SecretKeyHolder;
-import org.apache.spark.network.server.NoOpRpcHandler;
 import org.apache.spark.network.util.TransportConf;
 import org.apache.spark.shuffle.api.ShufflePartitionReader;
 import org.apache.spark.shuffle.api.ShuffleReadSupport;
@@ -20,6 +19,7 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport {
   private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleReadSupport.class);
 
   private final TransportConf conf;
+  private final TransportContext context;
   private final boolean authEnabled;
   private final SecretKeyHolder secretKeyHolder;
   private final String hostName;
@@ -27,11 +27,13 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport {
 
   public ExternalShuffleReadSupport(
       TransportConf conf,
+      TransportContext context,
       boolean authEnabled,
       SecretKeyHolder secretKeyHolder,
       String hostName,
       int port) {
     this.conf = conf;
+    this.context = context;
     this.authEnabled = authEnabled;
     this.secretKeyHolder = secretKeyHolder;
     this.hostName = hostName;
@@ -41,7 +43,6 @@ public ExternalShuffleReadSupport(
   @Override
   public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId) {
     // TODO combine this into a function with ExternalShuffleWriteSupport
-    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), false);
     List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
     if (authEnabled) {
       bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java
index 4754c58f136bc..413c2fd63f20a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java
+++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java
@@ -17,34 +16,39 @@
 
 public class ExternalShuffleWriteSupport implements ShuffleWriteSupport {
 
-    private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleWriteSupport.class);
+  private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleWriteSupport.class);
 
-    private final TransportConf conf;
-    private final boolean authEnabled;
-    private final SecretKeyHolder secretKeyHolder;
-    private final String hostname;
-    private final int port;
+  private final TransportConf conf;
+  private final TransportContext context;
+  private final boolean authEnabled;
+  private final SecretKeyHolder secretKeyHolder;
+  private final String hostname;
+  private final int port;
 
-    public ExternalShuffleWriteSupport(
-        TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder,
-        String hostname, int port) {
-      this.conf = conf;
-      this.authEnabled = authEnabled;
-      this.secretKeyHolder = secretKeyHolder;
-      this.hostname = hostname;
-      this.port = port;
-    }
+  public ExternalShuffleWriteSupport(
+      TransportConf conf,
+      TransportContext context,
+      boolean authEnabled,
+      SecretKeyHolder secretKeyHolder,
+      String hostname,
+      int port) {
+    this.conf = conf;
+    this.context = context;
+    this.authEnabled = authEnabled;
+    this.secretKeyHolder = secretKeyHolder;
+    this.hostname = hostname;
+    this.port = port;
+  }
 
-    @Override
-    public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) {
-      TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), false);
-      List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
-      if (authEnabled) {
-        bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
-      }
-      TransportClientFactory clientFactory = context.createClientFactory(bootstraps);
-      logger.info("Clientfactory: " + clientFactory.toString());
-      return new ExternalShuffleMapOutputWriter(
-        clientFactory, hostname, port, appId, shuffleId, mapId);
+  @Override
+  public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) {
+    List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
+    if (authEnabled) {
+      bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
+    }
+    TransportClientFactory clientFactory = context.createClientFactory(bootstraps);
+    logger.info("Clientfactory: " + clientFactory.toString());
+    return new ExternalShuffleMapOutputWriter(
+      clientFactory, hostname, port, appId, shuffleId, mapId);
   }
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 2cdf0c4600aef..823c36d051ddf 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -267,7 +267,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio
           }
         }
       }
-      mapOutputWriter.commitAllPartitions(lengths);
+      mapOutputWriter.commitAllPartitions();
     } catch (Exception e) {
       try {
         mapOutputWriter.abort(e);
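The commitAllPartitions() calls above and below no longer pass partitionLengths because the service now learns each partition's length from the UploadShufflePartitionStream header and assembles the index itself. The index format is unchanged from what the deleted ExternalShuffleIndexWriter produced: a leading 0L followed by cumulative offsets, so partition p spans [offsets[p], offsets[p + 1]). A small sketch of that prefix sum (the lengths are illustrative):

    public class IndexOffsetsSketch {
      public static void main(String[] args) {
        long[] lengths = {10L, 0L, 25L};                // illustrative partition lengths
        long[] offsets = new long[lengths.length + 1];  // leading 0L, then running totals
        for (int i = 0; i < lengths.length; i++) {
          offsets[i + 1] = offsets[i] + lengths[i];
        }
        // offsets == {0, 10, 10, 35}: empty partition 1 collapses to a zero-width range
      }
    }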
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 4e299034a8934..32be620095110 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -563,7 +563,7 @@ private long[] mergeSpillsWithPluggableWriter(
           throw e;
         }
       }
-      mapOutputWriter.commitAllPartitions(partitionLengths);
+      mapOutputWriter.commitAllPartitions();
       threwException = false;
     } catch (Exception e) {
       try {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 845a3d5f6d6f9..247016584d1f2 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -576,6 +576,10 @@ class SparkContext(config: SparkConf) extends Logging {
       _env.metricsSystem.registerSource(e.executorAllocationManagerSource)
     }
     appStatusSource.foreach(_env.metricsSystem.registerSource(_))
+
+    // Initialize the ShuffleDataIO plugin, if one is configured
+    _env.shuffleDataIO.foreach(_.initialize())
+
     // Make sure the context is stopped if the user forgets about it. This avoids leaving
     // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM
     // is killed, though.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 45aabc05f49b4..c2b56864bf36d 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -19,7 +19,7 @@ package org.apache.spark
 
 import java.io.File
 import java.net.Socket
-import java.util.{Locale, ServiceLoader}
+import java.util.Locale
 
 import com.google.common.collect.MapMaker
 import scala.collection.mutable
@@ -39,6 +39,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
 import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
 import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleManager, ShuffleServiceAddressProviderFactory}
+import org.apache.spark.shuffle.api.ShuffleDataIO
 import org.apache.spark.storage._
 import org.apache.spark.util.{RpcUtils, Utils}
 
@@ -65,6 +66,7 @@ class SparkEnv (
     val blockManager: BlockManager,
     val securityManager: SecurityManager,
     val metricsSystem: MetricsSystem,
+    val shuffleDataIO: Option[ShuffleDataIO],
    val memoryManager: MemoryManager,
    val outputCommitCoordinator: OutputCommitCoordinator,
    val conf: SparkConf) extends Logging {
@@ -383,6 +385,9 @@ object SparkEnv extends Logging {
       ms
     }
 
+    val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
+      .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head)
+
     val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
       new OutputCommitCoordinator(conf, isDriver)
     }
@@ -402,6 +407,7 @@ object SparkEnv extends Logging {
       blockManager,
       securityManager,
       metricsSystem,
+      shuffleIoPlugin,
      memoryManager,
      outputCommitCoordinator,
      conf)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index a30a501e5d4a1..ae5b1a3c6946a 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -118,6 +118,8 @@ private[spark] class Executor(
     env.blockManager.initialize(conf.getAppId)
     env.metricsSystem.registerSource(executorSource)
     env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource)
+    // Initialize the ShuffleDataIO plugin, if one is configured
+    env.shuffleDataIO.foreach(_.initialize())
   }
 
   // Whether to load classes in user jars before those in Spark jars
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index ba56da9089a76..eb7ae313918ed 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -119,9 +119,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
-    val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
-      .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head)
-    shuffleIoPlugin.foreach(_.initialize())
     new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
       conf.getAppId,
@@ -129,7 +126,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       endPartition,
       context,
       metrics,
-      shuffleIoPlugin.map(_.readSupport()))
+      SparkEnv.get.shuffleDataIO.map(_.readSupport()))
   }
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
@@ -141,9 +138,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
     numMapsForShuffle.putIfAbsent(
       handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
     val env = SparkEnv.get
-    val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
-      .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head)
-    shuffleIoPlugin.foreach(_.initialize())
     handle match {
       case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
         new UnsafeShuffleWriter(
@@ -155,7 +149,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           context,
           env.conf,
           metrics,
-          shuffleIoPlugin.map(_.writeSupport()).orNull)
+          env.shuffleDataIO.map(_.writeSupport()).orNull)
       case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
         new BypassMergeSortShuffleWriter(
           env.blockManager,
@@ -164,10 +158,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           mapId,
           env.conf,
           metrics,
-          shuffleIoPlugin.map(_.writeSupport()).orNull)
+          env.shuffleDataIO.map(_.writeSupport()).orNull)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
         new SortShuffleWriter(
-          shuffleBlockResolver, other, mapId, context, shuffleIoPlugin.map(_.writeSupport()))
+          shuffleBlockResolver, other, mapId, context, env.shuffleDataIO.map(_.writeSupport()))
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 01cc838474d8e..569c8bd092f37 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -778,7 +778,7 @@ private[spark] class ExternalSorter[K, V, C](
           }
         }
       }
-      mapOutputWriter.commitAllPartitions(lengths)
+      mapOutputWriter.commitAllPartitions()
     } catch {
       case e: Exception =>
         util.Utils.tryLogNonFatalError {
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 0b18aceef92d3..539336cd4fd89 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -696,7 +696,7 @@ public void abort(Exception failureReason) {
       }
 
       @Override
-      public void commitAllPartitions(long[] partitionlegnths) {
+      public void commitAllPartitions() {
 
       }
diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala
index f6ac1fcc05a1d..3a68fded945b3 100644
--- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala
+++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala
@@ -56,7 +56,7 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO {
       }
     }
 
-    override def commitAllPartitions(partitionLengths: Array[Long]): Unit = {}
+    override def commitAllPartitions(): Unit = {}
 
     override def abort(exception: Exception): Unit = {}
   }
diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala
index 0ce0d3bec6cdb..883ac10718dfd 100644
--- a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala
@@ -18,12 +18,10 @@
 // scalastyle:off println
 package org.apache.spark.examples
 
-import java.util.Random
-
 import org.apache.spark.sql.SparkSession
 
 /**
- * Usage: GroupByShuffleTest [numMappers] [numKVPairs] [KeySize] [numReducers]
+ * Usage: GroupByShuffleTest
  */
 object GroupByShuffleTest {
   def main(args: Array[String]) {
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala
index e1eac07558831..b9d69f1bc69fb 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark.deploy.k8s
 
-import java.io.File
+import java.io.{DataOutputStream, File, FileOutputStream}
 import java.nio.ByteBuffer
 import java.nio.file.Paths
+import java.util
 import java.util.concurrent.{ConcurrentHashMap, ExecutionException, TimeUnit}
+import java.util.function.BiFunction
 
+import com.codahale.metrics._
 import com.google.common.cache.{CacheBuilder, CacheLoader, Weigher}
 import scala.collection.JavaConverters._
+import scala.collection.immutable.TreeMap
 
 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.deploy.ExternalShuffleService
@@ -46,7 +50,6 @@ private[spark] class KubernetesExternalShuffleBlockHandler(
     indexCacheSize: String)
   extends ExternalShuffleBlockHandler(transportConf, null) with Logging {
 
-
   ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher")
     .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervals, TimeUnit.SECONDS)
 
@@ -63,14 +66,23 @@ private[spark] class KubernetesExternalShuffleBlockHandler(
       })
       .build(indexCacheLoader)
 
+  // TODO: Investigate cleanup if appId is terminated
+  private val globalPartitionLengths = new ConcurrentHashMap[(String, Int, Int), TreeMap[Int, Long]]
+  private final val shuffleDir = Utils.createDirectory("/tmp", "spark-shuffle-dir")
+  private final val metricSet: RemoteShuffleMetrics = new RemoteShuffleMetrics()
+
+  private def scanLeft[a, b](xs: Iterable[a])(s: b)(f: (b, a) => b) =
+    xs.foldLeft(List(s))((acc, x) => f(acc.head, x) :: acc).reverse
+
   protected override def handleMessage(
       message: BlockTransferMessage,
       client: TransportClient,
       callback: RpcResponseCallback): Unit = {
     message match {
       case RegisterDriverParam(appId, appState) =>
+        val responseDelayContext = metricSet.registerDriverRequestLatencyMillis.time()
         val address = client.getSocketAddress
         val timeout = appState.heartbeatTimeout
         logInfo(s"Received registration request from app $appId (remote address $address, " +
@@ -84,6 +96,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler(
           throw new RuntimeException(s"Failed to create dir ${driverDir.getAbsolutePath}")
         }
         connectedApps.put(appId, appState)
+        responseDelayContext.stop()
         callback.onSuccess(ByteBuffer.allocate(0))
 
       case Heartbeat(appId) =>
@@ -97,9 +110,34 @@ private[spark] class KubernetesExternalShuffleBlockHandler(
           logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " +
             s"address $address, appId '$appId').")
         }
+
+      case RegisterIndexParam(appId, shuffleId, mapId) =>
+        logInfo(s"Received register index param from app $appId")
+        globalPartitionLengths.putIfAbsent(
+          (appId, shuffleId, mapId), TreeMap.empty[Int, Long])
+        callback.onSuccess(ByteBuffer.allocate(0))
+
+      case UploadIndexParam(appId, shuffleId, mapId) =>
+        val responseDelayContext = metricSet.writeIndexRequestLatencyMillis.time()
+        try {
+          logInfo(s"Received upload index param from app $appId")
+          val partitionMap = globalPartitionLengths.get((appId, shuffleId, mapId))
+          val out = new DataOutputStream(
+            new FileOutputStream(getFile(appId, shuffleId, mapId, "index")))
+          scanLeft(partitionMap.values)(0L)(_ + _).foreach(l => out.writeLong(l))
+          out.close()
+          callback.onSuccess(ByteBuffer.allocate(0))
+        } finally {
+          responseDelayContext.stop()
+        }
+
       case OpenParam(appId, shuffleId, mapId, partitionId) =>
         logInfo(s"Received open param from app $appId")
+        val responseDelayContext = metricSet.openBlockRequestLatencyMillis.time()
         val indexFile = getFile(appId, shuffleId, mapId, "index")
+        logInfo(s"Map: " +
+          s"${globalPartitionLengths.get((appId, shuffleId, mapId)).toString()} " +
+          s"for partitionId: $partitionId")
         try {
           val shuffleIndexInformation = shuffleIndexCache.get(indexFile)
           val shuffleIndexRecord = shuffleIndexInformation.getIndex(partitionId)
@@ -111,6 +149,8 @@ private[spark] class KubernetesExternalShuffleBlockHandler(
             callback.onSuccess(managedBuffer.nioByteBuffer())
         } catch {
           case e: ExecutionException => logError(s"Unable to write index file $indexFile", e)
+        } finally {
+          responseDelayContext.stop()
         }
       case _ => super.handleMessage(message, client, callback)
     }
@@ -122,20 +162,30 @@ private[spark] class KubernetesExternalShuffleBlockHandler(
       callback: RpcResponseCallback): StreamCallbackWithID = {
     header match {
       case UploadParam(
-          appId, shuffleId, mapId, partitionId) =>
-        // TODO: Investigate whether we should use the partitionId for Index File creation
-        logInfo(s"Received upload param from app $appId")
-        getFileWriterStreamCallback(
-          appId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA)
-      case UploadIndexParam(appId, shuffleId, mapId) =>
-        logInfo(s"Received upload index param from app $appId")
-        getFileWriterStreamCallback(
-          appId, shuffleId, mapId, "index", FileWriterStreamCallback.FileType.INDEX)
+          appId, shuffleId, mapId, partitionId, partitionLength) =>
+        val responseDelayContext = metricSet.writeBlockRequestLatencyMillis.time()
+        try {
+          logInfo(s"Received upload param from app $appId")
+          val lengthMap = TreeMap(partitionId -> partitionLength.toLong)
+          globalPartitionLengths.merge((appId, shuffleId, mapId), lengthMap,
+            new BiFunction[TreeMap[Int, Long], TreeMap[Int, Long], TreeMap[Int, Long]]() {
+              override def apply(t: TreeMap[Int, Long], u: TreeMap[Int, Long]):
+                  TreeMap[Int, Long] = {
+                t ++ u
+              }
+            })
+          getFileWriterStreamCallback(
+            appId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA)
+        } finally {
+          responseDelayContext.stop()
+        }
       case _ =>
         super.handleStream(header, client, callback)
     }
   }
 
+  protected override def getAllMetrics: MetricSet = metricSet
+
   private def getFileWriterStreamCallback(
       appId: String,
       shuffleId: Int,
@@ -169,12 +219,17 @@ private[spark] class KubernetesExternalShuffleBlockHandler(
   }
 
   private object UploadParam {
-    def unapply(u: UploadShufflePartitionStream): Option[(String, Int, Int, Int)] =
-      Some((u.appId, u.shuffleId, u.mapId, u.partitionId))
+    def unapply(u: UploadShufflePartitionStream): Option[(String, Int, Int, Int, Int)] =
+      Some((u.appId, u.shuffleId, u.mapId, u.partitionId, u.partitionLength))
   }
 
   private object UploadIndexParam {
-    def unapply(u: UploadShuffleIndexStream): Option[(String, Int, Int)] =
+    def unapply(u: UploadShuffleIndex): Option[(String, Int, Int)] =
+      Some((u.appId, u.shuffleId, u.mapId))
+  }
+
+  private object RegisterIndexParam {
+    def unapply(u: RegisterShuffleIndex): Option[(String, Int, Int)] =
       Some((u.appId, u.shuffleId, u.mapId))
   }
 
@@ -204,6 +259,32 @@ private[spark] class KubernetesExternalShuffleBlockHandler(
       }
     }
   }
+
+  private class RemoteShuffleMetrics extends MetricSet {
+    private val allMetrics = new util.HashMap[String, Metric]()
+    // Latency of block write requests, in ms
+    private val _writeBlockRequestLatencyMillis = new Timer()
+    def writeBlockRequestLatencyMillis: Timer = _writeBlockRequestLatencyMillis
+    // Latency of index file writes, in ms
+    private val _writeIndexRequestLatencyMillis = new Timer()
+    def writeIndexRequestLatencyMillis: Timer = _writeIndexRequestLatencyMillis
+    // Latency of open-block (read) requests, in ms
+    private val _openBlockRequestLatencyMillis = new Timer()
+    def openBlockRequestLatencyMillis: Timer = _openBlockRequestLatencyMillis
+    // Latency of driver registration, in ms
+    private val _registerDriverRequestLatencyMillis = new Timer()
+    def registerDriverRequestLatencyMillis: Timer = _registerDriverRequestLatencyMillis
+    // Block transfer rate, in bytes per second
+    private val _blockTransferRateBytes = new Meter()
+    def blockTransferRateBytes: Meter = _blockTransferRateBytes
+
+    allMetrics.put("writeBlockRequestLatencyMillis", _writeBlockRequestLatencyMillis)
+    allMetrics.put("writeIndexRequestLatencyMillis", _writeIndexRequestLatencyMillis)
+    allMetrics.put("openBlockRequestLatencyMillis", _openBlockRequestLatencyMillis)
+    allMetrics.put("registerDriverRequestLatencyMillis", _registerDriverRequestLatencyMillis)
+    allMetrics.put("blockTransferRateBytes", _blockTransferRateBytes)
+
+    override def getMetrics: util.Map[String, Metric] = allMetrics
+  }
 }
 
 /**
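Partition uploads can arrive in any order and on separate connections, which is why the handler keys globalPartitionLengths by (appId, shuffleId, mapId) and merges TreeMaps: the TreeMap keeps partitionIds sorted for the index write, and ConcurrentHashMap.merge makes each read-modify-write atomic per key. A Java analogue of that bookkeeping (the String key stands in for the tuple key):

    import java.util.TreeMap;
    import java.util.concurrent.ConcurrentHashMap;

    public class PartitionLengthBookkeeping {
      private final ConcurrentHashMap<String, TreeMap<Integer, Long>> globalPartitionLengths =
          new ConcurrentHashMap<>();

      public void record(String appShuffleMap, int partitionId, long partitionLength) {
        TreeMap<Integer, Long> update = new TreeMap<>();
        update.put(partitionId, partitionLength);
        // Non-destructive merge, like the handler's `t ++ u`: build a new map so a
        // retried remapping function never observes a half-mutated value.
        globalPartitionLengths.merge(appShuffleMap, update, (oldMap, newMap) -> {
          TreeMap<Integer, Long> merged = new TreeMap<>(oldMap);
          merged.putAll(newMap);
          return merged;
        });
      }
    }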