diff --git a/R/README.md b/R/README.md
index bb3464ba9955d..810bfc14e977e 100644
--- a/R/README.md
+++ b/R/README.md
@@ -40,7 +40,7 @@ To set other options like driver memory, executor memory etc. you can pass in th
 If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example
 ```
 # Set this to where Spark is installed
-Sys.setenv(SPARK_HOME="/Users/shivaram/spark")
+Sys.setenv(SPARK_HOME="/Users/username/spark")
 # This line loads SparkR from the installed directory
 .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths()))
 library(SparkR)
@@ -51,7 +51,7 @@ sc <- sparkR.init(master="local")
 
 The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes.
-Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below.
+Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below.
 
 #### Generating documentation
 
@@ -60,9 +60,9 @@ The SparkR documentation (Rd files and HTML files) are not a part of the source
 ### Examples, Unit tests
 
 SparkR comes with several sample programs in the `examples/src/main/r` directory.
-To run one of them, use `./bin/sparkR <filename> <args>`. For example:
+To run one of them, use `./bin/spark-submit <filename> <args>`. For example:
 
-    ./bin/sparkR examples/src/main/r/dataframe.R
+    ./bin/spark-submit examples/src/main/r/dataframe.R
 
 You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first):
 
@@ -70,7 +70,7 @@ You can also run the unit-tests for SparkR by running (you need to install the [
     ./R/run-tests.sh
 
 ### Running on YARN
-The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run
+The `./bin/spark-submit` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run
 ```
 export YARN_CONF_DIR=/etc/hadoop/conf
 ./bin/spark-submit --master yarn examples/src/main/r/dataframe.R
diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index 238710d17249a..5320b28bc054c 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -43,7 +43,8 @@
 
 /**
  * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
- * setup Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}.
+ * setup Netty Channel pipelines with a
+ * {@link org.apache.spark.network.server.TransportChannelHandler}.
  *
  * There are two communication protocols that the TransportClient provides, control-plane RPCs and
  * data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
index 4c8802af7ae67..acc49d968c186 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -28,7 +28,7 @@
 /**
  * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
  */
-public final class NettyManagedBuffer extends ManagedBuffer {
+public class NettyManagedBuffer extends ManagedBuffer {
   private final ByteBuf buf;
 
   public NettyManagedBuffer(ByteBuf buf) {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
index 29e6a30dc1f67..d322aec28793e 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java
@@ -21,9 +21,9 @@
 import java.nio.ByteBuffer;
 
 /**
- * Callback for streaming data. Stream data will be offered to the {@link #onData(String, ByteBuffer)}
- * method as it arrives. Once all the stream data is received, {@link #onComplete(String)} will be
- * called.
+ * Callback for streaming data. Stream data will be offered to the
+ * {@link #onData(String, ByteBuffer)} method as it arrives. Once all the stream data is received,
+ * {@link #onComplete(String)} will be called.
  * <p>
  * The network library guarantees that a single thread will call these methods at a time, but
  * different call may be made by different threads.
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 1008c67de3491..f179bad1f4b15 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -64,7 +64,7 @@ private static class ClientPool {
     TransportClient[] clients;
     Object[] locks;
 
-    public ClientPool(int size) {
+    ClientPool(int size) {
      clients = new TransportClient[size];
      locks = new Object[size];
      for (int i = 0; i < size; i++) {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
index 66f5b8b3a59c8..434935a8ef2ad 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -33,7 +33,7 @@ public interface Message extends Encodable {
   boolean isBodyInFrame();
 
   /** Preceding every serialized Message is its type, which allows us to deserialize it. */
-  public static enum Type implements Encodable {
+  enum Type implements Encodable {
     ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3),
     RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8),
@@ -41,7 +41,7 @@ public static enum Type implements Encodable {
 
     private final byte id;
 
-    private Type(int id) {
+    Type(int id) {
       assert id < 128 : "Cannot have more than 128 message types";
       this.id = (byte) id;
     }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
index 31b15bb17a327..b85171ed6f3d1 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
@@ -17,8 +17,6 @@
 
 package org.apache.spark.network.protocol;
 
-import org.apache.spark.network.protocol.Message;
-
 /** Messages from the client to the server. */
 public interface RequestMessage extends Message {
   // token interface
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
index 6edffd11cf1e2..194e6d9aa2bd4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
@@ -17,8 +17,6 @@
 
 package org.apache.spark.network.protocol;
 
-import org.apache.spark.network.protocol.Message;
-
 /** Messages from the server to the client. */
 public interface ResponseMessage extends Message {
   // token interface
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
index e52b526f09c77..7331c2b481fb1 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -36,11 +36,11 @@ class SaslMessage extends AbstractMessage {
 
   public final String appId;
 
-  public SaslMessage(String appId, byte[] message) {
+  SaslMessage(String appId, byte[] message) {
     this(appId, Unpooled.wrappedBuffer(message));
   }
 
-  public SaslMessage(String appId, ByteBuf message) {
+  SaslMessage(String appId, ByteBuf message) {
     super(new NettyManagedBuffer(message), true);
     this.appId = appId;
   }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index ea9e735e0a173..e2222ae08534b 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -32,8 +32,8 @@
 import org.apache.spark.network.client.TransportClient;
 
 /**
- * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
- * fetched as chunks by the client. Each registered buffer is one chunk.
+ * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are
+ * individually fetched as chunks by the client. Each registered buffer is one chunk.
  */
 public class OneForOneStreamManager extends StreamManager {
   private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index 18a9b7887ec28..f2223379a9d24 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -141,8 +141,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
       if (responseHandler.numOutstandingRequests() > 0) {
         String address = NettyUtils.getRemoteAddress(ctx.channel());
         logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
-          "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
-          "is wrong.", address, requestTimeoutNs / 1000 / 1000);
+          "requests. Assuming connection is dead; please adjust spark.network.timeout if " +
+          "this is wrong.", address, requestTimeoutNs / 1000 / 1000);
         client.timeOut();
         ctx.close();
       } else if (closeIdleConnections) {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java
index a2f018373f2a4..e097714bbc6de 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java
@@ -24,7 +24,7 @@ public enum ByteUnit {
   TiB ((long) Math.pow(1024L, 4L)),
   PiB ((long) Math.pow(1024L, 5L));
 
-  private ByteUnit(long multiplier) {
+  ByteUnit(long multiplier) {
    this.multiplier = multiplier;
  }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
index 5f20b70678d1e..f15ec8d294258 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java
@@ -19,8 +19,6 @@
 
 import java.util.NoSuchElementException;
 
-import org.apache.spark.network.util.ConfigProvider;
-
 /** Uses System properties to obtain config values. */
 public class SystemPropertyConfigProvider extends ConfigProvider {
   @Override
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 3f7024a6aa260..bd1830e6abc86 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -205,7 +205,7 @@ private boolean feedInterceptor(ByteBuf buf) throws Exception {
     return interceptor != null;
   }
 
-  public static interface Interceptor {
+  public interface Interceptor {
 
     /**
      * Handles data received from the remote end.
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
index dd0171d1d1c17..959396bb8c268 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -44,7 +44,7 @@
 * Suite which ensures that requests that go without a response for the network timeout period are
 * failed, and the connection closed.
 *
- * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests,
+ * In this suite, we use 10 seconds as the connection timeout, with some slack given in the tests,
 * to ensure stability in different test environments.
 */
 public class RequestTimeoutIntegrationSuite {
@@ -61,7 +61,7 @@ public class RequestTimeoutIntegrationSuite {
   @Before
   public void setUp() throws Exception {
     Map<String, String> configMap = Maps.newHashMap();
-    configMap.put("spark.shuffle.io.connectionTimeout", "2s");
+    configMap.put("spark.shuffle.io.connectionTimeout", "10s");
     conf = new TransportConf("shuffle", new MapConfigProvider(configMap));
 
     defaultManager = new StreamManager() {
@@ -118,10 +118,10 @@ public StreamManager getStreamManager() {
     callback0.latch.await();
     assertEquals(responseSize, callback0.successLength);
 
-    // Second times out after 2 seconds, with slack. Must be IOException.
+    // Second times out after 10 seconds, with slack. Must be IOException.
     TestCallback callback1 = new TestCallback();
     client.sendRpc(ByteBuffer.allocate(0), callback1);
-    callback1.latch.await(4, TimeUnit.SECONDS);
+    callback1.latch.await(60, TimeUnit.SECONDS);
     assertNotNull(callback1.failure);
     assertTrue(callback1.failure instanceof IOException);
 
@@ -223,7 +223,7 @@ public StreamManager getStreamManager() {
     // not complete yet, but should complete soon
     assertEquals(-1, callback0.successLength);
     assertNull(callback0.failure);
-    callback0.latch.await(2, TimeUnit.SECONDS);
+    callback0.latch.await(60, TimeUnit.SECONDS);
     assertTrue(callback0.failure instanceof IOException);
 
     // failed at same time as previous
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java
index cdce297233f4f..268cb40121754 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.network.sasl;
 
-import java.lang.Override;
 import java.nio.ByteBuffer;
 import java.util.concurrent.ConcurrentHashMap;
 
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index f22187a01db02..f8d03b3b9433a 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -52,7 +52,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
   final ExternalShuffleBlockResolver blockManager;
   private final OneForOneStreamManager streamManager;
 
-  public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException {
+  public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile)
+      throws IOException {
     this(new OneForOneStreamManager(),
       new ExternalShuffleBlockResolver(conf, registeredExecutorFile));
   }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 460110d78f15b..ce5c68e85375e 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -423,7 +423,9 @@ public static class StoreVersion {
     public final int major;
     public final int minor;
 
-    @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) {
+    @JsonCreator public StoreVersion(
+        @JsonProperty("major") int major,
+        @JsonProperty("minor") int minor) {
      this.major = major;
      this.minor = minor;
    }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
index 4bb0498e5d5aa..d81cf869ddb9e 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
@@ -46,7 +46,7 @@ public class RetryingBlockFetcher {
   * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any
   * remaining blocks.
   */
-  public static interface BlockFetchStarter {
+  public interface BlockFetchStarter {
    /**
     * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous
     * bootstrapping followed by fully asynchronous block fetching.
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index 21c0ff4136aa8..9af6759f5d5f3 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -40,13 +40,13 @@ public abstract class BlockTransferMessage implements Encodable {
   protected abstract Type type();
 
   /** Preceding every serialized message is its type, which allows us to deserialize it. */
-  public static enum Type {
+  public enum Type {
     OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
     HEARTBEAT(5);
 
     private final byte id;
 
-    private Type(int id) {
+    Type(int id) {
       assert id < 128 : "Cannot have more than 128 message types";
       this.id = (byte) id;
     }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 5322fcd7813a7..5bf99241851e7 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -212,7 +212,8 @@ public void onBlockFetchFailure(String blockId, Throwable t) {
       };
 
     String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" };
-    OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener);
+    OneForOneBlockFetcher fetcher =
+      new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener);
     fetcher.start();
     blockFetchLatch.await();
     checkSecurityException(exception.get());
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 9379412155e88..c2e0b7447fb8b 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -113,7 +113,8 @@ public void testBadMessages() {
       // pass
     }
 
-    ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer();
+    ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1],
+      new byte[2]).toByteBuffer();
     try {
       handler.receive(client, unexpectedMsg, callback);
       fail("Should have thrown");
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
index 3d1f28bcb911e..a61ce4fb7241d 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
@@ -28,7 +28,7 @@ final class Murmur3_x86_32 {
 
   private final int seed;
 
-  public Murmur3_x86_32(int seed) {
+  Murmur3_x86_32(int seed) {
    this.seed = seed;
  }
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
index 7857bf66a72ad..c8c57381f332f 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
@@ -87,7 +87,8 @@ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidt
    * To iterate over the true bits in a BitSet, use the following loop:
    * <pre>
    * <code>
-   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
+   *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0;
+   *    i = bs.nextSetBit(i + 1, sizeInWords)) {
    *    // operate on index i here
    *  }
    * </code>
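
For comparison, the JDK's java.util.BitSet exposes the same iteration idiom that the javadoc above documents for the off-heap bit set. A minimal, runnable sketch (plain Java, not the Spark class):

```
import java.util.BitSet;

class BitSetLoopExample {
  public static void main(String[] args) {
    BitSet bs = new BitSet();
    bs.set(3);
    bs.set(7);
    bs.set(42);
    // Same shape as the documented loop: nextSetBit returns -1 once no set bit remains.
    for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
      System.out.println(i); // prints 3, 7, 42
    }
  }
}
```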
diff --git a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
index d4c42b38ac224..0dd8fafbf2c82 100644
--- a/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
+++ b/core/src/main/java/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
@@ -62,5 +62,6 @@ public final <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V>... rdds) {
   // These methods take separate "first" and "rest" elements to avoid having the same type erasure
   public abstract <T> JavaRDD<T> union(JavaRDD<T> first, List<JavaRDD<T>> rest);
   public abstract JavaDoubleRDD union(JavaDoubleRDD first, List<JavaDoubleRDD> rest);
-  public abstract <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V> first, List<JavaPairRDD<K, V>> rest);
+  public abstract <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V> first, List<JavaPairRDD<K, V>>
+    rest);
 }
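
The comment in the hunk above alludes to a type-erasure clash. A small sketch of what goes wrong otherwise, using toy interfaces rather than the Spark classes:

```
import java.util.List;

interface ToyRdd<T> {}
interface ToyPairRdd<K, V> {}

interface UnionApi {
  // The obvious overloads
  //   <T> ToyRdd<T> union(List<ToyRdd<T>> rdds);
  //   <K, V> ToyPairRdd<K, V> union(List<ToyPairRdd<K, V>> rdds);
  // both erase to union(List) and would not compile side by side. Pulling the first element
  // out as a separately typed parameter gives each overload a distinct erased signature.
  <T> ToyRdd<T> union(ToyRdd<T> first, List<ToyRdd<T>> rest);
  <K, V> ToyPairRdd<K, V> union(ToyPairRdd<K, V> first, List<ToyPairRdd<K, V>> rest);
}
```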
diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java
index 150144e0e418c..bf16f791f906a 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java
@@ -23,5 +23,5 @@
  *  A function that returns Doubles, and can be used to construct DoubleRDDs.
  */
 public interface DoubleFunction<T> extends Serializable {
-  public double call(T t) throws Exception;
+  double call(T t) throws Exception;
 }
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function2.java b/core/src/main/java/org/apache/spark/api/java/function/Function2.java
index 793caaa61ac5a..a975ce3c68192 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/Function2.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function2.java
@@ -23,5 +23,5 @@
  * A two-argument function that takes arguments of type T1 and T2 and returns an R.
  */
 public interface Function2<T1, T2, R> extends Serializable {
-  public R call(T1 v1, T2 v2) throws Exception;
+  R call(T1 v1, T2 v2) throws Exception;
 }
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function3.java b/core/src/main/java/org/apache/spark/api/java/function/Function3.java
index b4151c3417df4..6eecfb645a663 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/Function3.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function3.java
@@ -23,5 +23,5 @@
  * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
  */
 public interface Function3<T1, T2, T3, R> extends Serializable {
-  public R call(T1 v1, T2 v2, T3 v3) throws Exception;
+  R call(T1 v1, T2 v2, T3 v3) throws Exception;
 }
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
index 99bf240a17225..2fdfa7184a3bd 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
@@ -26,5 +26,5 @@
  * construct PairRDDs.
  */
 public interface PairFunction<T, K, V> extends Serializable {
-  public Tuple2<K, V> call(T t) throws Exception;
+  Tuple2<K, V> call(T t) throws Exception;
 }
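
These single-method interfaces are functional interfaces, so with Java 8+ they are usually written as lambdas. A minimal usage sketch (assumes the Spark Java API and the Scala library on the classpath; the class name below is illustrative):

```
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;

class FunctionInterfaceExamples {
  static final PairFunction<String, String, Integer> TO_PAIR = s -> new Tuple2<>(s, s.length());
  static final Function2<Integer, Integer, Integer> SUM = (a, b) -> a + b;

  public static void main(String[] args) throws Exception {
    Tuple2<String, Integer> pair = TO_PAIR.call("spark"); // (spark,5)
    System.out.println(pair + " " + SUM.call(2, 3));      // prints (spark,5) 5
  }
}
```

In application code these are normally passed to RDD operations such as mapToPair and reduceByKey rather than called directly.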
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 8757dff36f159..9044bb4f4a44b 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -67,9 +67,9 @@ public class TaskMemoryManager {
 
   /**
    * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
-   * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page
-   * size is limited by the maximum amount of data that can be stored in a  long[] array, which is
-   * (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes.
+   * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's
+   * maximum page size is limited by the maximum amount of data that can be stored in a long[]
+   * array, which is (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes.
    */
   public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
 
@@ -268,8 +268,8 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
       logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
       // there is no enough memory actually, it means the actual free memory is smaller than
       // MemoryManager thought, we should keep the acquired memory.
-      acquiredButNotUsed += acquired;
       synchronized (this) {
+        acquiredButNotUsed += acquired;
         allocatedPages.clear(pageNumber);
       }
       // this could trigger spilling to free some pages.
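
The reflowed comment in the first hunk of this file spells out why the page size is capped at 16 gigabytes; the arithmetic can be checked directly from the constant's own expression. A small sketch (the OFFSET_BITS value below is an assumption for illustration, not taken from this diff):

```
public class PageSizeBoundsSketch {
  public static void main(String[] args) {
    // On-heap pages live in long[] arrays, so the practical ceiling is the constant above:
    long maxOnHeapPageBytes = ((1L << 31) - 1) * 8L;
    System.out.println(maxOnHeapPageBytes);                         // 17179869176 (~16 GiB)

    // The addressable ceiling is much larger; assuming 51 offset bits (hypothetical here,
    // see OFFSET_BITS in TaskMemoryManager), a page could in principle span 2 PiB:
    long addressablePageBytes = 1L << 51;
    System.out.println(addressablePageBytes / (1L << 40) + " TiB"); // 2048 TiB
  }
}
```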
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 052be54d8c3f9..7a60c3eb35740 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -98,7 +98,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
    */
   private boolean stopping = false;
 
-  public BypassMergeSortShuffleWriter(
+  BypassMergeSortShuffleWriter(
       BlockManager blockManager,
       IndexShuffleBlockResolver shuffleBlockResolver,
      BypassMergeSortShuffleHandle<K, V> handle,
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 7a114df2d6857..81ee7ab58ab5b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -96,7 +96,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
   @Nullable private MemoryBlock currentPage = null;
   private long pageCursor = -1;
 
-  public ShuffleExternalSorter(
+  ShuffleExternalSorter(
       TaskMemoryManager memoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
@@ -320,7 +320,18 @@ private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       long used = inMemSorter.getMemoryUsage();
-      LongArray array = allocateArray(used / 8 * 2);
+      LongArray array;
+      try {
+        // could trigger spilling
+        array = allocateArray(used / 8 * 2);
+      } catch (OutOfMemoryError e) {
+        // should have trigger spilling
+        if (!inMemSorter.hasSpaceForAnotherRecord()) {
+          logger.error("Unable to grow the pointer array");
+          throw e;
+        }
+        return;
+      }
       // check if spilling is triggered or not
       if (inMemSorter.hasSpaceForAnotherRecord()) {
         freeArray(array);
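
The try/catch added above tolerates an OutOfMemoryError from allocateArray when the failed allocation has itself forced a spill that freed enough room in the existing pointer array. A self-contained sketch of that control flow, using generic stand-in interfaces rather than Spark's classes:

```
final class GrowOnDemandSketch {
  interface Sorter {
    boolean hasSpaceForAnotherRecord();
    long getMemoryUsage();
  }
  interface Allocator {
    long[] allocate(long words); // may spill internally, and may still throw OutOfMemoryError
    void free(long[] array);
  }

  static void growIfNecessary(Sorter sorter, Allocator allocator) {
    if (sorter.hasSpaceForAnotherRecord()) {
      return;
    }
    long used = sorter.getMemoryUsage();
    long[] bigger;
    try {
      bigger = allocator.allocate(used / 8 * 2); // could trigger spilling
    } catch (OutOfMemoryError e) {
      // If the attempt forced a spill, the existing array now has room and we can retry later;
      // otherwise there really is no memory left and the error must propagate.
      if (!sorter.hasSpaceForAnotherRecord()) {
        throw e;
      }
      return;
    }
    if (sorter.hasSpaceForAnotherRecord()) {
      allocator.free(bigger); // a spill during allocation already made room; discard the array
    } else {
      // the real sorter would expand its pointer array with `bigger` here
    }
  }
}
```

The same pattern is applied to UnsafeExternalSorter.growPointerArrayIfNecessary further down in this diff.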
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 2381cff61f069..fe79ff0e3052b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -51,7 +51,7 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) {
    */
   private int pos = 0;
 
-  public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
+  ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
     this.consumer = consumer;
     assert (initialSize > 0);
     this.array = consumer.allocateArray(initialSize);
@@ -122,7 +122,7 @@ public static final class ShuffleSorterIterator {
     final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
     private int position = 0;
 
-    public ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
+    ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
       this.numRecords = numRecords;
       this.pointerArray = pointerArray;
     }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
index df9f7b7abe028..865def6b83c53 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
@@ -29,7 +29,7 @@ final class SpillInfo {
   final File file;
   final TempShuffleBlockId blockId;
 
-  public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
+  SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
     this.partitionLengths = new long[numPartitions];
     this.file = file;
     this.blockId = blockId;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index cd06ce9fb911e..0c5fb883a8326 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -45,7 +45,6 @@
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleWriter;
@@ -82,7 +81,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
   private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
-    public MyByteArrayOutputStream(int size) { super(size); }
+    MyByteArrayOutputStream(int size) { super(size); }
     public byte[] getBuf() { return buf; }
   }
 
@@ -108,7 +107,8 @@ public UnsafeShuffleWriter(
     if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
       throw new IllegalArgumentException(
         "UnsafeShuffleWriter can only be used for shuffles with at most " +
-          SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions");
+        SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() +
+        " reduce partitions");
     }
     this.blockManager = blockManager;
     this.shuffleBlockResolver = shuffleBlockResolver;
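
MyByteArrayOutputStream in the first hunk of this file relies on java.io.ByteArrayOutputStream keeping its data in a protected byte[] buf field. A standalone version of the same idiom (hypothetical class name):

```
import java.io.ByteArrayOutputStream;

final class ExposedByteArrayOutputStream extends ByteArrayOutputStream {
  ExposedByteArrayOutputStream(int size) {
    super(size);
  }

  byte[] getBuf() {
    return buf; // backing array; only the first size() bytes are valid
  }
}
```

toByteArray() would return a trimmed copy instead, which is presumably the copy the shuffle writer wants to avoid.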
diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java
index 0cf84d5f9b716..9307eb93a5b20 100644
--- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java
+++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java
@@ -28,7 +28,7 @@ public enum TaskSorting {
   DECREASING_RUNTIME("-runtime");
 
  private final Set<String> alternateNames;
-  private TaskSorting(String... names) {
+  TaskSorting(String... names) {
     alternateNames = new HashSet<>();
     for (String n: names) {
       alternateNames.add(n);
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index b55a322a1b413..de36814ecca15 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -689,7 +689,7 @@ public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
       offset += keyLength;
       Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
 
-      // --- Update bookkeeping data structures -----------------------------------------------------
+      // --- Update bookkeeping data structures ----------------------------------------------------
       offset = currentPage.getBaseOffset();
       Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
       pageCursor += recordLength;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 9236bd2c04fd9..927b19c4e8038 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -293,7 +293,18 @@ private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       long used = inMemSorter.getMemoryUsage();
-      LongArray array = allocateArray(used / 8 * 2);
+      LongArray array;
+      try {
+        // could trigger spilling
+        array = allocateArray(used / 8 * 2);
+      } catch (OutOfMemoryError e) {
+        // should have trigger spilling
+        if (!inMemSorter.hasSpaceForAnotherRecord()) {
+          logger.error("Unable to grow the pointer array");
+          throw e;
+        }
+        return;
+      }
       // check if spilling is triggered or not
       if (inMemSorter.hasSpaceForAnotherRecord()) {
         freeArray(array);
@@ -421,7 +432,7 @@ class SpillableIterator extends UnsafeSorterIterator {
     private boolean loaded = false;
     private int numRecords = 0;
 
-    public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+    SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
       this.upstream = inMemIterator;
       this.numRecords = inMemIterator.getNumRecords();
     }
@@ -556,7 +567,7 @@ static class ChainedIterator extends UnsafeSorterIterator {
     private UnsafeSorterIterator current;
     private int numRecords;
 
-    public ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
+    ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
       assert iterators.size() > 0;
       this.numRecords = 0;
       for (UnsafeSorterIterator iter: iterators) {
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
index d3137f5f31c25..12fb62fb77f0f 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -47,7 +47,8 @@ public RecordPointerAndKeyPrefix newKey() {
   }
 
   @Override
-  public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix reuse) {
+  public RecordPointerAndKeyPrefix getKey(LongArray data, int pos,
+                                          RecordPointerAndKeyPrefix reuse) {
     reuse.recordPointer = data.get(pos * 2);
     reuse.keyPrefix = data.get(pos * 2 + 1);
     return reuse;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index ceb59352af64b..2b1c860e55952 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -26,7 +26,7 @@ final class UnsafeSorterSpillMerger {
   private int numRecords = 0;
  private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
 
-  public UnsafeSorterSpillMerger(
+  UnsafeSorterSpillMerger(
       final RecordComparator recordComparator,
       final PrefixComparator prefixComparator,
       final int numSpills) {
@@ -57,7 +57,7 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOExcept
       // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator
       // does not return wrong result because hasNext will returns true
       // at least priorityQueue.size() times. If we allow n spillReaders in the
-      // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator.
+      // priorityQueue, we will have n extra empty records in the result of UnsafeSorterIterator.
       spillReader.loadNext();
       priorityQueue.add(spillReader);
       numRecords += spillReader.getNumRecords();
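
The reworded comment above is about an invariant of the merge: every iterator sitting in the priority queue must already be positioned on a record, which is why addSpillIfNotEmpty calls loadNext() before inserting and never enqueues an empty reader. A generic illustration of the same pattern in plain Java (not the Spark classes):

```
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

final class KWayMergeSketch {
  private static final class Source {
    final Iterator<Integer> it;
    int current;
    Source(Iterator<Integer> it) { this.it = it; this.current = it.next(); } // load first record
    boolean advance() {
      if (!it.hasNext()) return false;
      current = it.next();
      return true;
    }
  }

  static List<Integer> merge(List<Iterator<Integer>> sources) {
    PriorityQueue<Source> queue =
        new PriorityQueue<>(Comparator.comparingInt((Source s) -> s.current));
    for (Iterator<Integer> it : sources) {
      if (it.hasNext()) {          // an empty source is never enqueued
        queue.add(new Source(it)); // constructor advances to the first record before insertion
      }
    }
    List<Integer> out = new ArrayList<>();
    while (!queue.isEmpty()) {
      Source smallest = queue.poll();
      out.add(smallest.current);
      if (smallest.advance()) {
        queue.add(smallest);
      }
    }
    return out;
  }
}
```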
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
index 5a7a252231053..a2b3826dd324b 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html
@@ -64,10 +64,10 @@
   
   {{#applications}}
     
-      {{id}}
+      {{id}}
       {{name}}
       {{#attempts}}
-      {{attemptId}}
+      {{attemptId}}
       {{startTime}}
       {{endTime}}
       {{duration}}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
index 609651315405c..ef89a9a86f093 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
@@ -123,28 +123,13 @@ $(document).ready(function() {
         if (app["attempts"].length > 1) {
             hasMultipleAttempts = true;
         }
-
-        var maxAttemptId = null
+        var num = app["attempts"].length;
         for (j in app["attempts"]) {
           var attempt = app["attempts"][j];
-          if (attempt['attemptId'] != null) {
-            if (maxAttemptId == null || attempt['attemptId'] > maxAttemptId) {
-              maxAttemptId = attempt['attemptId']
-            }
-          }
-
           attempt["startTime"] = formatDate(attempt["startTime"]);
           attempt["endTime"] = formatDate(attempt["endTime"]);
           attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]);
-
-          var url = null
-          if (maxAttemptId == null) {
-            url = "history/" + id + "/"
-          } else {
-            url = "history/" + id + "/" + maxAttemptId + "/"
-          }
-
-          var app_clone = {"id" : id, "name" : name, "url" : url, "attempts" : [attempt]};
+          var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]};
           array.push(app_clone);
         }
       }
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 459fab88ce1de..e2c47ceda2e6f 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -331,7 +331,7 @@ object SparkEnv extends Logging {
 
     // NB: blockManager is not valid until initialize() is called later.
     val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
-      serializer, conf, memoryManager, mapOutputTracker, shuffleManager,
+      serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
       blockTransferService, securityManager, numUsableCores)
 
     val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 2634d88367669..e5e6a9e4a816c 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -30,7 +30,7 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.{ByteBufferInputStream, Utils}
-import org.apache.spark.util.io.ByteArrayChunkOutputStream
+import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
 
 /**
  * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
@@ -107,7 +107,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
       TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
     blocks.zipWithIndex.foreach { case (block, i) =>
       val pieceId = BroadcastBlockId(id, "piece" + i)
-      if (!blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) {
+      val bytes = new ChunkedByteBuffer(block.duplicate())
+      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
         throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
       }
     }
@@ -115,10 +116,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   }
 
   /** Fetch torrent blocks from the driver and/or other executors. */
-  private def readBlocks(): Array[ByteBuffer] = {
+  private def readBlocks(): Array[ChunkedByteBuffer] = {
     // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
     // to the driver, so other executors can pull these chunks from this executor as well.
-    val blocks = new Array[ByteBuffer](numBlocks)
+    val blocks = new Array[ChunkedByteBuffer](numBlocks)
     val bm = SparkEnv.get.blockManager
 
     for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
@@ -182,7 +183,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
         case None =>
           logInfo("Started reading broadcast variable " + id)
           val startTimeMs = System.currentTimeMillis()
-          val blocks = readBlocks()
+          val blocks = readBlocks().flatMap(_.getChunks())
           logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
 
           val obj = TorrentBroadcast.unBlockifyObject[T](
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index a62096d771724..ec6d48485f110 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -524,9 +524,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
         |  --proxy-user NAME           User to impersonate when submitting the application.
         |                              This argument does not work with --principal / --keytab.
         |
-        |  --help, -h                  Show this help message and exit
-        |  --verbose, -v               Print additional debug output
-        |  --version,                  Print the version of current Spark
+        |  --help, -h                  Show this help message and exit.
+        |  --verbose, -v               Print additional debug output.
+        |  --version,                  Print the version of current Spark.
         |
         | Spark standalone with cluster deploy mode only:
         |  --driver-cores NUM          Cores for driver (Default: 1).
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 6327d55fe75c2..3201463b8cc97 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -36,6 +36,7 @@ import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTa
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util._
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 /**
  * Spark executor, backed by a threadpool to run tasks.
@@ -297,7 +298,9 @@ private[spark] class Executor(
           } else if (resultSize > maxDirectResultSize) {
             val blockId = TaskResultBlockId(taskId)
             env.blockManager.putBytes(
-              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
+              blockId,
+              new ChunkedByteBuffer(serializedDirectResult.duplicate()),
+              StorageLevel.MEMORY_AND_DISK_SER)
             logInfo(
               s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
             ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index cc5e851c29b32..8f83668d79029 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network
 
+import scala.reflect.ClassTag
+
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 
@@ -35,7 +37,11 @@ trait BlockDataManager {
    * Returns true if the block was stored and false if the put operation failed or the block
    * already existed.
    */
-  def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Boolean
+  def putBlockData(
+      blockId: BlockId,
+      data: ManagedBuffer,
+      level: StorageLevel,
+      classTag: ClassTag[_]): Boolean
 
   /**
    * Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 2de0f2033f2ed..e43e3a2de2566 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
 
 import scala.concurrent.{Await, Future, Promise}
 import scala.concurrent.duration.Duration
+import scala.reflect.ClassTag
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
@@ -76,7 +77,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
-      level: StorageLevel): Future[Unit]
+      level: StorageLevel,
+      classTag: ClassTag[_]): Future[Unit]
 
   /**
    * A special case of [[fetchBlocks]], as it fetches only one block and is blocking.
@@ -114,7 +116,9 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
-      level: StorageLevel): Unit = {
-    Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
+      level: StorageLevel,
+      classTag: ClassTag[_]): Unit = {
+    val future = uploadBlock(hostname, port, execId, blockId, blockData, level, classTag)
+    Await.result(future, Duration.Inf)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index c1dbca5db2007..2ed8a00df7023 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -20,6 +20,8 @@ package org.apache.spark.network.netty
 import java.nio.ByteBuffer
 
 import scala.collection.JavaConverters._
+import scala.language.existentials
+import scala.reflect.ClassTag
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.BlockDataManager
@@ -61,12 +63,16 @@ class NettyBlockRpcServer(
         responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)
 
       case uploadBlock: UploadBlock =>
-        // StorageLevel is serialized as bytes using our JavaSerializer.
-        val level: StorageLevel =
-          serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+        // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
+        val (level: StorageLevel, classTag: ClassTag[_]) = {
+          serializer
+            .newInstance()
+            .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+            .asInstanceOf[(StorageLevel, ClassTag[_])]
+        }
         val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
         val blockId = BlockId(uploadBlock.blockId)
-        blockManager.putBlockData(blockId, data, level)
+        blockManager.putBlockData(blockId, data, level, classTag)
         responseContext.onSuccess(ByteBuffer.allocate(0))
     }
   }
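
The change above piggybacks the block's ClassTag next to its StorageLevel by serializing the pair into the message's existing metadata bytes (the matching encoder change is in NettyBlockTransferService further down). A rough, self-contained sketch of that kind of round trip using plain JDK serialization — Spark uses its own JavaSerializer, and the class and method names here are illustrative:

```
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

final class MetadataRoundTripSketch {
  // Sender side: pack both metadata values into one byte[] for the RPC message.
  static byte[] encode(Serializable level, Serializable classTag) throws IOException {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
      out.writeObject(new Serializable[] { level, classTag });
    }
    return bytes.toByteArray();
  }

  // Receiver side: recover the pair and dispatch on both values.
  static Object[] decode(byte[] metadata) throws IOException, ClassNotFoundException {
    try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(metadata))) {
      return (Object[]) in.readObject();
    }
  }
}
```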
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index f588a28eed28d..5f3d4532dd866 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
 
 import scala.collection.JavaConverters._
 import scala.concurrent.{Future, Promise}
+import scala.reflect.ClassTag
 
 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._
@@ -118,18 +119,19 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       execId: String,
       blockId: BlockId,
       blockData: ManagedBuffer,
-      level: StorageLevel): Future[Unit] = {
+      level: StorageLevel,
+      classTag: ClassTag[_]): Future[Unit] = {
     val result = Promise[Unit]()
     val client = clientFactory.createClient(hostname, port)
 
-    // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
-    // using our binary protocol.
-    val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level))
+    // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
+    // Everything else is encoded using our binary protocol.
+    val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))
 
     // Convert or copy nio buffer into array in order to serialize it.
     val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())
 
-    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
+    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer,
       new RpcResponseCallback {
         override def onSuccess(response: ByteBuffer): Unit = {
           logTrace(s"Successfully uploaded block $blockId")
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 8a577c83e10db..f96551c793a14 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -326,7 +326,7 @@ abstract class RDD[T: ClassTag](
     val blockId = RDDBlockId(id, partition.index)
     var readCachedBlock = true
     // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
-    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => {
+    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
       readCachedBlock = false
       computeOrReadCheckpoint(partition, context)
     }) match {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
index 4cd6cbe189aab..4a304a078d658 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
@@ -29,5 +29,4 @@ sealed trait JobResult
 @DeveloperApi
 case object JobSucceeded extends JobResult
 
-@DeveloperApi
 private[spark] case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 7eb6d53c10950..873f1b56bd18b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -83,7 +83,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
                 return
               }
               val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
-                serializedTaskResult.get)
+                serializedTaskResult.get.toByteBuffer)
               sparkEnv.blockManager.master.removeBlock(blockId)
               (deserializedResult, size)
           }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
index 46fab7a899633..94d11c5be5a49 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -21,6 +21,7 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.reflect.ClassTag
 
 import com.google.common.collect.ConcurrentHashMultiset
 
@@ -37,10 +38,14 @@ import org.apache.spark.internal.Logging
  * @param level the block's storage level. This is the requested persistence level, not the
  *              effective storage level of the block (i.e. if this is MEMORY_AND_DISK, then this
  *              does not imply that the block is actually resident in memory).
+ * @param classTag the block's [[ClassTag]], used to select the serializer
  * @param tellMaster whether state changes for this block should be reported to the master. This
  *                   is true for most blocks, but is false for broadcast blocks.
  */
-private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+private[storage] class BlockInfo(
+    val level: StorageLevel,
+    val classTag: ClassTag[_],
+    val tellMaster: Boolean) {
 
   /**
    * The size of the block (in bytes)
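
BlockInfo now records the element type alongside the requested storage level. A minimal construction sketch (hypothetical values; `BlockInfo` is `private[storage]`, so this only compiles from inside the `org.apache.spark.storage` package):

```scala
import scala.reflect.ClassTag

import org.apache.spark.storage.{BlockInfo, StorageLevel}

// Hypothetical example: the recorded ClassTag is what later drives serializer
// selection for this block (see the BlockManager changes below).
val info = new BlockInfo(
  level = StorageLevel.MEMORY_AND_DISK,
  classTag = ClassTag(classOf[String]),
  tellMaster = true)
assert(info.classTag == ClassTag(classOf[String]))
```
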
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 3bbdf48104c91..83f8c5c37d136 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -18,31 +18,31 @@
 package org.apache.spark.storage
 
 import java.io._
-import java.nio.{ByteBuffer, MappedByteBuffer}
+import java.nio.ByteBuffer
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.concurrent.duration._
+import scala.reflect.ClassTag
 import scala.util.Random
 import scala.util.control.NonFatal
 
-import sun.nio.ch.DirectBuffer
-
 import org.apache.spark._
 import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.memory.MemoryManager
 import org.apache.spark.network._
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{Serializer, SerializerInstance}
+import org.apache.spark.serializer.{Serializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage.memory._
 import org.apache.spark.util._
+import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
 
 /* Class for returning a fetched block and associated metrics. */
 private[spark] class BlockResult(
@@ -60,7 +60,7 @@ private[spark] class BlockManager(
     executorId: String,
     rpcEnv: RpcEnv,
     val master: BlockManagerMaster,
-    defaultSerializer: Serializer,
+    serializerManager: SerializerManager,
     val conf: SparkConf,
     memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
@@ -295,8 +295,12 @@ private[spark] class BlockManager(
   /**
    * Put the block locally, using the given storage level.
    */
-  override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Boolean = {
-    putBytes(blockId, data.nioByteBuffer(), level)
+  override def putBlockData(
+      blockId: BlockId,
+      data: ManagedBuffer,
+      level: StorageLevel,
+      classTag: ClassTag[_]): Boolean = {
+    putBytes(blockId, new ChunkedByteBuffer(data.nioByteBuffer()), level)(classTag)
   }
 
   /**
@@ -418,7 +422,7 @@ private[spark] class BlockManager(
           val iter: Iterator[Any] = if (level.deserialized) {
             memoryStore.getValues(blockId).get
           } else {
-            dataDeserialize(blockId, memoryStore.getBytes(blockId).get)
+            dataDeserialize(blockId, memoryStore.getBytes(blockId).get)(info.classTag)
           }
           val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
           Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
@@ -426,10 +430,11 @@ private[spark] class BlockManager(
           val iterToReturn: Iterator[Any] = {
             val diskBytes = diskStore.getBytes(blockId)
             if (level.deserialized) {
-              val diskValues = dataDeserialize(blockId, diskBytes)
+              val diskValues = dataDeserialize(blockId, diskBytes)(info.classTag)
               maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
             } else {
-              dataDeserialize(blockId, maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes))
+              val bytes = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
+              dataDeserialize(blockId, bytes)(info.classTag)
             }
           }
           val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
@@ -444,7 +449,7 @@ private[spark] class BlockManager(
   /**
    * Get block from the local block manager as serialized bytes.
    */
-  def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
+  def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
     logDebug(s"Getting local block $blockId as bytes")
     // As an optimization for map output fetches, if the block is for a shuffle, return it
     // without acquiring a lock; the disk store never deletes (recent) items so this should work
@@ -453,7 +458,8 @@ private[spark] class BlockManager(
       // TODO: This should gracefully handle case where local block is not available. Currently
       // downstream code will throw an exception.
       Option(
-        shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+        new ChunkedByteBuffer(
+          shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()))
     } else {
       blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) }
     }
@@ -465,7 +471,7 @@ private[spark] class BlockManager(
    * Must be called while holding a read lock on the block.
    * Releases the read lock upon exception; keeps the read lock upon successful return.
    */
-  private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ByteBuffer = {
+  private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = {
     val level = info.level
     logDebug(s"Level for block $blockId is $level")
     // In order, try to read the serialized bytes from memory, then from disk, then fall back to
@@ -502,9 +508,9 @@ private[spark] class BlockManager(
    *
    * This does not acquire a lock on this block in this JVM.
    */
-  def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
+  private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
     getRemoteBytes(blockId).map { data =>
-      new BlockResult(dataDeserialize(blockId, data), DataReadMethod.Network, data.limit())
+      new BlockResult(dataDeserialize(blockId, data), DataReadMethod.Network, data.size)
     }
   }
 
@@ -521,7 +527,7 @@ private[spark] class BlockManager(
   /**
    * Get block from remote block managers as serialized bytes.
    */
-  def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+  def getRemoteBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
     logDebug(s"Getting remote block $blockId")
     require(blockId != null, "BlockId is null")
     var runningFailureCount = 0
@@ -567,7 +573,7 @@ private[spark] class BlockManager(
       }
 
       if (data != null) {
-        return Some(data)
+        return Some(new ChunkedByteBuffer(data))
       }
       logDebug(s"The value of block $blockId is null")
     }
@@ -633,12 +639,13 @@ private[spark] class BlockManager(
    * @return either a BlockResult if the block was successfully cached, or an iterator if the block
    *         could not be cached.
    */
-  def getOrElseUpdate(
+  def getOrElseUpdate[T](
       blockId: BlockId,
       level: StorageLevel,
-      makeIterator: () => Iterator[Any]): Either[BlockResult, Iterator[Any]] = {
+      classTag: ClassTag[T],
+      makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
     // Initially we hold no locks on this block.
-    doPutIterator(blockId, makeIterator, level, keepReadLock = true) match {
+    doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
       case None =>
         // doPut() didn't hand work back to us, so the block already existed or was successfully
         // stored. Therefore, we now hold a read lock on the block.
@@ -664,13 +671,13 @@ private[spark] class BlockManager(
   /**
    * @return true if the block was stored or false if an error occurred.
    */
-  def putIterator(
+  def putIterator[T: ClassTag](
       blockId: BlockId,
-      values: Iterator[Any],
+      values: Iterator[T],
       level: StorageLevel,
       tellMaster: Boolean = true): Boolean = {
     require(values != null, "Values is null")
-    doPutIterator(blockId, () => values, level, tellMaster) match {
+    doPutIterator(blockId, () => values, level, implicitly[ClassTag[T]], tellMaster) match {
       case None =>
         true
       case Some(iter) =>
@@ -703,13 +710,13 @@ private[spark] class BlockManager(
    *
    * @return true if the block was stored or false if an error occurred.
    */
-  def putBytes(
+  def putBytes[T: ClassTag](
       blockId: BlockId,
-      bytes: ByteBuffer,
+      bytes: ChunkedByteBuffer,
       level: StorageLevel,
       tellMaster: Boolean = true): Boolean = {
     require(bytes != null, "Bytes is null")
-    doPutBytes(blockId, bytes, level, tellMaster)
+    doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster)
   }
 
   /**
@@ -723,37 +730,35 @@ private[spark] class BlockManager(
    *                     returns.
    * @return true if the block was already present or if the put succeeded, false otherwise.
    */
-  private def doPutBytes(
+  private def doPutBytes[T](
       blockId: BlockId,
-      bytes: ByteBuffer,
+      bytes: ChunkedByteBuffer,
       level: StorageLevel,
+      classTag: ClassTag[T],
       tellMaster: Boolean = true,
       keepReadLock: Boolean = false): Boolean = {
-    doPut(blockId, level, tellMaster = tellMaster, keepReadLock = keepReadLock) { putBlockInfo =>
+    doPut(blockId, level, classTag, tellMaster = tellMaster, keepReadLock = keepReadLock) { info =>
       val startTimeMs = System.currentTimeMillis
       // Since we're storing bytes, initiate the replication before storing them locally.
       // This is faster as data is already serialized and ready to send.
       val replicationFuture = if (level.replication > 1) {
-        // Duplicate doesn't copy the bytes, but just creates a wrapper
-        val bufferView = bytes.duplicate()
         Future {
           // This is a blocking action and should run in futureExecutionContext which is a cached
           // thread pool
-          replicate(blockId, bufferView, level)
+          replicate(blockId, bytes, level, classTag)
         }(futureExecutionContext)
       } else {
         null
       }
 
-      bytes.rewind()
-      val size = bytes.limit()
+      val size = bytes.size
 
       if (level.useMemory) {
         // Put it in memory first, even if it also has useDisk set to true;
         // We will drop it to disk later if the memory store can't hold it.
         val putSucceeded = if (level.deserialized) {
-          val values = dataDeserialize(blockId, bytes.duplicate())
-          memoryStore.putIterator(blockId, values, level) match {
+          val values = dataDeserialize(blockId, bytes)(classTag)
+          memoryStore.putIterator(blockId, values, level, classTag) match {
             case Right(_) => true
             case Left(iter) =>
               // If putting deserialized values in memory failed, we will put the bytes directly to
@@ -772,14 +777,14 @@ private[spark] class BlockManager(
         diskStore.putBytes(blockId, bytes)
       }
 
-      val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo)
+      val putBlockStatus = getCurrentBlockStatus(blockId, info)
       val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid
       if (blockWasSuccessfullyStored) {
         // Now that the block is in either the memory, externalBlockStore, or disk store,
         // tell the master about it.
-        putBlockInfo.size = size
+        info.size = size
         if (tellMaster) {
-          reportBlockStatus(blockId, putBlockInfo, putBlockStatus)
+          reportBlockStatus(blockId, info, putBlockStatus)
         }
         Option(TaskContext.get()).foreach { c =>
           c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
@@ -807,6 +812,7 @@ private[spark] class BlockManager(
   private def doPut[T](
       blockId: BlockId,
       level: StorageLevel,
+      classTag: ClassTag[_],
       tellMaster: Boolean,
       keepReadLock: Boolean)(putBody: BlockInfo => Option[T]): Option[T] = {
 
@@ -814,7 +820,7 @@ private[spark] class BlockManager(
     require(level != null && level.isValid, "StorageLevel is null or invalid")
 
     val putBlockInfo = {
-      val newInfo = new BlockInfo(level, tellMaster)
+      val newInfo = new BlockInfo(level, classTag, tellMaster)
       if (blockInfoManager.lockNewBlockForWriting(blockId, newInfo)) {
         newInfo
       } else {
@@ -867,21 +873,22 @@ private[spark] class BlockManager(
    * @return None if the block was already present or if the put succeeded, or Some(iterator)
    *         if the put failed.
    */
-  private def doPutIterator(
+  private def doPutIterator[T](
       blockId: BlockId,
-      iterator: () => Iterator[Any],
+      iterator: () => Iterator[T],
       level: StorageLevel,
+      classTag: ClassTag[T],
       tellMaster: Boolean = true,
-      keepReadLock: Boolean = false): Option[PartiallyUnrolledIterator] = {
-    doPut(blockId, level, tellMaster = tellMaster, keepReadLock = keepReadLock) { putBlockInfo =>
+      keepReadLock: Boolean = false): Option[PartiallyUnrolledIterator[T]] = {
+    doPut(blockId, level, classTag, tellMaster = tellMaster, keepReadLock = keepReadLock) { info =>
       val startTimeMs = System.currentTimeMillis
-      var iteratorFromFailedMemoryStorePut: Option[PartiallyUnrolledIterator] = None
+      var iteratorFromFailedMemoryStorePut: Option[PartiallyUnrolledIterator[T]] = None
       // Size of the block in bytes
       var size = 0L
       if (level.useMemory) {
         // Put it in memory first, even if it also has useDisk set to true;
         // We will drop it to disk later if the memory store can't hold it.
-        memoryStore.putIterator(blockId, iterator(), level) match {
+        memoryStore.putIterator(blockId, iterator(), level, classTag) match {
           case Right(s) =>
             size = s
           case Left(iter) =>
@@ -889,7 +896,7 @@ private[spark] class BlockManager(
             if (level.useDisk) {
               logWarning(s"Persisting block $blockId to disk instead.")
               diskStore.put(blockId) { fileOutputStream =>
-                dataSerializeStream(blockId, fileOutputStream, iter)
+                dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
               }
               size = diskStore.getSize(blockId)
             } else {
@@ -898,19 +905,19 @@ private[spark] class BlockManager(
         }
       } else if (level.useDisk) {
         diskStore.put(blockId) { fileOutputStream =>
-          dataSerializeStream(blockId, fileOutputStream, iterator())
+          dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
         }
         size = diskStore.getSize(blockId)
       }
 
-      val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo)
+      val putBlockStatus = getCurrentBlockStatus(blockId, info)
       val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid
       if (blockWasSuccessfullyStored) {
         // Now that the block is in either the memory, externalBlockStore, or disk store,
         // tell the master about it.
-        putBlockInfo.size = size
+        info.size = size
         if (tellMaster) {
-          reportBlockStatus(blockId, putBlockInfo, putBlockStatus)
+          reportBlockStatus(blockId, info, putBlockStatus)
         }
         Option(TaskContext.get()).foreach { c =>
           c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
@@ -918,11 +925,11 @@ private[spark] class BlockManager(
         logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
         if (level.replication > 1) {
           val remoteStartTime = System.currentTimeMillis
-          val bytesToReplicate = doGetLocalBytes(blockId, putBlockInfo)
+          val bytesToReplicate = doGetLocalBytes(blockId, info)
           try {
-            replicate(blockId, bytesToReplicate, level)
+            replicate(blockId, bytesToReplicate, level, classTag)
           } finally {
-            BlockManager.dispose(bytesToReplicate)
+            bytesToReplicate.dispose()
           }
           logDebug("Put block %s remotely took %s"
             .format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
@@ -944,29 +951,27 @@ private[spark] class BlockManager(
       blockInfo: BlockInfo,
       blockId: BlockId,
       level: StorageLevel,
-      diskBytes: ByteBuffer): ByteBuffer = {
+      diskBytes: ChunkedByteBuffer): ChunkedByteBuffer = {
     require(!level.deserialized)
     if (level.useMemory) {
       // Synchronize on blockInfo to guard against a race condition where two readers both try to
       // put values read from disk into the MemoryStore.
       blockInfo.synchronized {
         if (memoryStore.contains(blockId)) {
-          BlockManager.dispose(diskBytes)
+          diskBytes.dispose()
           memoryStore.getBytes(blockId).get
         } else {
-          val putSucceeded = memoryStore.putBytes(blockId, diskBytes.limit(), () => {
+          val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, () => {
             // https://issues.apache.org/jira/browse/SPARK-6076
             // If the file size is bigger than the free memory, OOM will happen. So if we
             // cannot put it into MemoryStore, the copy should not be created. That's why
-            // this action is put into a `() => ByteBuffer` and created lazily.
-            val copyForMemory = ByteBuffer.allocate(diskBytes.limit)
-            copyForMemory.put(diskBytes)
+            // this action is put into a `() => ChunkedByteBuffer` and created lazily.
+            diskBytes.copy()
           })
           if (putSucceeded) {
-            BlockManager.dispose(diskBytes)
+            diskBytes.dispose()
             memoryStore.getBytes(blockId).get
           } else {
-            diskBytes.rewind()
             diskBytes
           }
         }
@@ -983,12 +988,13 @@ private[spark] class BlockManager(
   * @return a copy of the iterator. The original iterator passed to this method should no longer
    *         be used after this method returns.
    */
-  private def maybeCacheDiskValuesInMemory(
+  private def maybeCacheDiskValuesInMemory[T](
       blockInfo: BlockInfo,
       blockId: BlockId,
       level: StorageLevel,
-      diskIterator: Iterator[Any]): Iterator[Any] = {
+      diskIterator: Iterator[T]): Iterator[T] = {
     require(level.deserialized)
+    val classTag = blockInfo.classTag.asInstanceOf[ClassTag[T]]
     if (level.useMemory) {
       // Synchronize on blockInfo to guard against a race condition where two readers both try to
       // put values read from disk into the MemoryStore.
@@ -997,7 +1003,7 @@ private[spark] class BlockManager(
           // Note: if we had a means to discard the disk iterator, we would do that here.
           memoryStore.getValues(blockId).get
         } else {
-          memoryStore.putIterator(blockId, diskIterator, level) match {
+          memoryStore.putIterator(blockId, diskIterator, level, classTag) match {
             case Left(iter) =>
               // The memory store put() failed, so it returned the iterator back to us:
               iter
@@ -1006,7 +1012,7 @@ private[spark] class BlockManager(
               memoryStore.getValues(blockId).get
           }
         }
-      }
+      }.asInstanceOf[Iterator[T]]
     } else {
       diskIterator
     }
@@ -1032,7 +1038,11 @@ private[spark] class BlockManager(
   * Replicate block to another node. Note that this is a blocking call that returns after
    * the block has been replicated.
    */
-  private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = {
+  private def replicate(
+      blockId: BlockId,
+      data: ChunkedByteBuffer,
+      level: StorageLevel,
+      classTag: ClassTag[_]): Unit = {
     val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1)
     val numPeersToReplicateTo = level.replication - 1
     val peersForReplication = new ArrayBuffer[BlockManagerId]
@@ -1085,11 +1095,16 @@ private[spark] class BlockManager(
         case Some(peer) =>
           try {
             val onePeerStartTime = System.currentTimeMillis
-            data.rewind()
-            logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
+            logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
             blockTransferService.uploadBlockSync(
-              peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
-            logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
+              peer.host,
+              peer.port,
+              peer.executorId,
+              blockId,
+              new NettyManagedBuffer(data.toNetty),
+              tLevel,
+              classTag)
+            logTrace(s"Replicated $blockId of ${data.size} bytes to $peer in %s ms"
               .format(System.currentTimeMillis - onePeerStartTime))
             peersReplicatedTo += peer
             peersForReplication -= peer
@@ -1112,7 +1127,7 @@ private[spark] class BlockManager(
       }
     }
     val timeTakeMs = (System.currentTimeMillis - startTime)
-    logDebug(s"Replicating $blockId of ${data.limit()} bytes to " +
+    logDebug(s"Replicating $blockId of ${data.size} bytes to " +
       s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms")
     if (peersReplicatedTo.size < numPeersToReplicateTo) {
       logWarning(s"Block $blockId replicated to only " +
@@ -1133,9 +1148,9 @@ private[spark] class BlockManager(
    * @return true if the block was stored or false if the block was already stored or an
    *         error occurred.
    */
-  def putSingle(
+  def putSingle[T: ClassTag](
       blockId: BlockId,
-      value: Any,
+      value: T,
       level: StorageLevel,
       tellMaster: Boolean = true): Boolean = {
     putIterator(blockId, Iterator(value), level, tellMaster)
@@ -1152,9 +1167,9 @@ private[spark] class BlockManager(
    *
    * @return the block's new effective StorageLevel.
    */
-  def dropFromMemory(
+  private[storage] def dropFromMemory[T: ClassTag](
       blockId: BlockId,
-      data: () => Either[Array[Any], ByteBuffer]): StorageLevel = {
+      data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
     logInfo(s"Dropping block $blockId from memory")
     val info = blockInfoManager.assertBlockIsLockedForWriting(blockId)
     var blockIsUpdated = false
@@ -1166,7 +1181,10 @@ private[spark] class BlockManager(
       data() match {
         case Left(elements) =>
           diskStore.put(blockId) { fileOutputStream =>
-            dataSerializeStream(blockId, fileOutputStream, elements.toIterator)
+            dataSerializeStream(
+              blockId,
+              fileOutputStream,
+              elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]])
           }
         case Right(bytes) =>
           diskStore.putBytes(blockId, bytes)
@@ -1272,41 +1290,42 @@ private[spark] class BlockManager(
   }
 
   /** Serializes into a stream. */
-  def dataSerializeStream(
+  def dataSerializeStream[T: ClassTag](
       blockId: BlockId,
       outputStream: OutputStream,
-      values: Iterator[Any]): Unit = {
+      values: Iterator[T]): Unit = {
     val byteStream = new BufferedOutputStream(outputStream)
-    val ser = defaultSerializer.newInstance()
+    val ser = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
     ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
   }
 
-  /** Serializes into a byte buffer. */
-  def dataSerialize(blockId: BlockId, values: Iterator[Any]): ByteBuffer = {
-    val byteStream = new ByteBufferOutputStream(4096)
-    dataSerializeStream(blockId, byteStream, values)
-    byteStream.toByteBuffer
+  /** Serializes into a chunked byte buffer. */
+  def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
+    val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4)
+    dataSerializeStream(blockId, byteArrayChunkOutputStream, values)
+    new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap))
   }
 
   /**
   * Deserializes a ChunkedByteBuffer into an iterator of values and disposes of it when the end of
    * the iterator is reached.
    */
-  def dataDeserialize(blockId: BlockId, bytes: ByteBuffer): Iterator[Any] = {
-    bytes.rewind()
-    dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true))
+  def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = {
+    dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true))
   }
 
   /**
   * Deserializes an InputStream into an iterator of values and disposes of it when the end of
    * the iterator is reached.
    */
-  def dataDeserializeStream(blockId: BlockId, inputStream: InputStream): Iterator[Any] = {
+  def dataDeserializeStream[T: ClassTag](
+      blockId: BlockId,
+      inputStream: InputStream): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
-    defaultSerializer
+    serializerManager.getSerializer(implicitly[ClassTag[T]])
       .newInstance()
       .deserializeStream(wrapForCompression(blockId, stream))
-      .asIterator
+      .asIterator.asInstanceOf[Iterator[T]]
   }
 
   def stop(): Unit = {
@@ -1325,24 +1344,9 @@ private[spark] class BlockManager(
 }
 
 
-private[spark] object BlockManager extends Logging {
+private[spark] object BlockManager {
   private val ID_GENERATOR = new IdGenerator
 
-  /**
-   * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
-   * might cause errors if one attempts to read from the unmapped buffer, but it's better than
-   * waiting for the GC to find it because that could lead to huge numbers of open files. There's
-   * unfortunately no standard API to do this.
-   */
-  def dispose(buffer: ByteBuffer): Unit = {
-    if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
-      logTrace(s"Unmapping $buffer")
-      if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
-        buffer.asInstanceOf[DirectBuffer].cleaner().clean()
-      }
-    }
-  }
-
   def blockIdsToHosts(
       blockIds: Array[BlockId],
       env: SparkEnv,
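
With these changes, the serialization helpers are keyed by a `ClassTag` and operate on `ChunkedByteBuffer`s. A hedged round-trip sketch (the `roundTrip` helper, the `bm` instance, and the `"round-trip"` block id are illustrative assumptions, not part of the patch):

```scala
import org.apache.spark.storage.{BlockManager, TestBlockId}

// Serialize an iterator of Strings into a ChunkedByteBuffer and read it back.
// The implicit ClassTag[String] is what the SerializerManager uses to pick a
// serializer, so the same element type must be supplied on both sides.
def roundTrip(bm: BlockManager): Seq[String] = {
  val blockId = TestBlockId("round-trip")   // hypothetical block id
  val bytes = bm.dataSerialize(blockId, Iterator("a", "b", "c"))
  bm.dataDeserialize[String](blockId, bytes).toSeq
}
```
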
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
index 5886b9c00b557..12594e6a2bc0c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.storage
 
-import java.nio.ByteBuffer
-
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer}
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 /**
- * This [[ManagedBuffer]] wraps a [[ByteBuffer]] which was retrieved from the [[BlockManager]]
+ * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]]
  * so that the corresponding block's read lock can be released once this buffer's references
  * are released.
  *
@@ -32,7 +31,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 private[storage] class BlockManagerManagedBuffer(
     blockManager: BlockManager,
     blockId: BlockId,
-    buf: ByteBuffer) extends NioManagedBuffer(buf) {
+    chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) {
 
   override def retain(): ManagedBuffer = {
     super.retain()
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 5c28357ded6d6..ca23e2391ed02 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -26,6 +26,7 @@ import com.google.common.io.Closeables
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 /**
  * Stores BlockManager blocks on disk.
@@ -71,23 +72,18 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e
       finishTime - startTime))
   }
 
-  def putBytes(blockId: BlockId, _bytes: ByteBuffer): Unit = {
-    // So that we do not modify the input offsets !
-    // duplicate does not copy buffer, so inexpensive
-    val bytes = _bytes.duplicate()
+  def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = {
     put(blockId) { fileOutputStream =>
       val channel = fileOutputStream.getChannel
       Utils.tryWithSafeFinally {
-        while (bytes.remaining > 0) {
-          channel.write(bytes)
-        }
+        bytes.writeFully(channel)
       } {
         channel.close()
       }
     }
   }
 
-  def getBytes(blockId: BlockId): ByteBuffer = {
+  def getBytes(blockId: BlockId): ChunkedByteBuffer = {
     val file = diskManager.getFile(blockId.name)
     val channel = new RandomAccessFile(file, "r").getChannel
     Utils.tryWithSafeFinally {
@@ -102,9 +98,9 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e
           }
         }
         buf.flip()
-        buf
+        new ChunkedByteBuffer(buf)
       } else {
-        channel.map(MapMode.READ_ONLY, 0, file.length)
+        new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length))
       }
     } {
       channel.close()
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 43cd15921cc97..199a5fc270a41 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -17,10 +17,15 @@
 
 package org.apache.spark.storage
 
+import java.nio.{ByteBuffer, MappedByteBuffer}
+
 import scala.collection.Map
 import scala.collection.mutable
 
+import sun.nio.ch.DirectBuffer
+
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
 
 /**
  * :: DeveloperApi ::
@@ -222,7 +227,22 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
 }
 
 /** Helper methods for storage-related objects. */
-private[spark] object StorageUtils {
+private[spark] object StorageUtils extends Logging {
+
+  /**
+   * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
+   * might cause errors if one attempts to read from the unmapped buffer, but it's better than
+   * waiting for the GC to find it because that could lead to huge numbers of open files. There's
+   * unfortunately no standard API to do this.
+   */
+  def dispose(buffer: ByteBuffer): Unit = {
+    if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
+      logTrace(s"Unmapping $buffer")
+      if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
+        buffer.asInstanceOf[DirectBuffer].cleaner().clean()
+      }
+    }
+  }
 
   /**
    * Update the given list of RDDInfo with the given list of storage statuses.
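
With `dispose` relocated here, callers outside the block manager (such as `ByteBufferInputStream` below) go through `StorageUtils`. A small self-contained sketch of eagerly unmapping a mapped file (the temp file and byte values are hypothetical; `StorageUtils` is `private[spark]`, so this only compiles inside Spark):

```scala
import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

import org.apache.spark.storage.StorageUtils

// Map a small file read-only, then unmap it eagerly rather than waiting for GC,
// which could otherwise keep many file handles open.
val file = java.io.File.createTempFile("dispose-demo", ".bin")
val channel = new RandomAccessFile(file, "rw").getChannel
try {
  channel.write(ByteBuffer.wrap(Array[Byte](1, 2, 3)))
  val mapped = channel.map(MapMode.READ_ONLY, 0, file.length())
  // ... read from `mapped` ...
  StorageUtils.dispose(mapped)   // no-op for null or non-mapped buffers
} finally {
  channel.close()
  file.delete()
}
```
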
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index a7c1854a41ff7..d370ee912ab31 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.storage.memory
 
-import java.nio.ByteBuffer
 import java.util.LinkedHashMap
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
 
 import org.apache.spark.{SparkConf, TaskContext}
 import org.apache.spark.internal.Logging
@@ -29,8 +29,20 @@ import org.apache.spark.memory.MemoryManager
 import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel}
 import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
+import org.apache.spark.util.io.ChunkedByteBuffer
 
-private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean)
+private sealed trait MemoryEntry[T] {
+  def size: Long
+  def classTag: ClassTag[T]
+}
+private case class DeserializedMemoryEntry[T](
+    value: Array[T],
+    size: Long,
+    classTag: ClassTag[T]) extends MemoryEntry[T]
+private case class SerializedMemoryEntry[T](
+    buffer: ChunkedByteBuffer,
+    size: Long,
+    classTag: ClassTag[T]) extends MemoryEntry[T]
 
 /**
  * Stores blocks in memory, either as Arrays of deserialized Java objects or as
@@ -45,7 +57,7 @@ private[spark] class MemoryStore(
   // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and
   // acquiring or releasing unroll memory, must be synchronized on `memoryManager`!
 
-  private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true)
+  private val entries = new LinkedHashMap[BlockId, MemoryEntry[_]](32, 0.75f, true)
 
   // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
   // All accesses of this map are assumed to have manually synchronized on `memoryManager`
@@ -91,14 +103,16 @@ private[spark] class MemoryStore(
    *
    * @return true if the put() succeeded, false otherwise.
    */
-  def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): Boolean = {
+  def putBytes[T: ClassTag](
+      blockId: BlockId,
+      size: Long,
+      _bytes: () => ChunkedByteBuffer): Boolean = {
     require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
     if (memoryManager.acquireStorageMemory(blockId, size)) {
       // We acquired enough memory for the block, so go ahead and put it
-      // Work on a duplicate - since the original input might be used elsewhere.
-      val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer]
-      assert(bytes.limit == size)
-      val entry = new MemoryEntry(bytes, size, deserialized = false)
+      val bytes = _bytes()
+      assert(bytes.size == size)
+      val entry = new SerializedMemoryEntry[T](bytes, size, implicitly[ClassTag[T]])
       entries.synchronized {
         entries.put(blockId, entry)
       }
@@ -126,10 +140,11 @@ private[spark] class MemoryStore(
    *         iterator or call `close()` on it in order to free the storage memory consumed by the
    *         partially-unrolled block.
    */
-  private[storage] def putIterator(
+  private[storage] def putIterator[T](
       blockId: BlockId,
-      values: Iterator[Any],
-      level: StorageLevel): Either[PartiallyUnrolledIterator, Long] = {
+      values: Iterator[T],
+      level: StorageLevel,
+      classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
 
     require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
 
@@ -148,7 +163,7 @@ private[spark] class MemoryStore(
     // Keep track of unroll memory used by this particular block / putIterator() operation
     var unrollMemoryUsedByThisBlock = 0L
     // Underlying vector for unrolling the block
-    var vector = new SizeTrackingVector[Any]
+    var vector = new SizeTrackingVector[T]()(classTag)
 
     // Request enough memory to begin unrolling
     keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold)
@@ -184,10 +199,10 @@ private[spark] class MemoryStore(
       val arrayValues = vector.toArray
       vector = null
       val entry = if (level.deserialized) {
-        new MemoryEntry(arrayValues, SizeEstimator.estimate(arrayValues), deserialized = true)
+        new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
       } else {
-        val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator)
-        new MemoryEntry(bytes, bytes.limit, deserialized = false)
+        val bytes = blockManager.dataSerialize(blockId, arrayValues.iterator)(classTag)
+        new SerializedMemoryEntry[T](bytes, bytes.size, classTag)
       }
       val size = entry.size
       def transferUnrollToStorage(amount: Long): Unit = {
@@ -241,27 +256,25 @@ private[spark] class MemoryStore(
     }
   }
 
-  def getBytes(blockId: BlockId): Option[ByteBuffer] = {
-    val entry = entries.synchronized {
-      entries.get(blockId)
-    }
-    if (entry == null) {
-      None
-    } else {
-      require(!entry.deserialized, "should only call getBytes on blocks stored in serialized form")
-      Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data
+  def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
+    val entry = entries.synchronized { entries.get(blockId) }
+    entry match {
+      case null => None
+      case e: DeserializedMemoryEntry[_] =>
+        throw new IllegalArgumentException("should only call getBytes on serialized blocks")
+      case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
     }
   }
 
-  def getValues(blockId: BlockId): Option[Iterator[Any]] = {
-    val entry = entries.synchronized {
-      entries.get(blockId)
-    }
-    if (entry == null) {
-      None
-    } else {
-      require(entry.deserialized, "should only call getValues on deserialized blocks")
-      Some(entry.value.asInstanceOf[Array[Any]].iterator)
+  def getValues(blockId: BlockId): Option[Iterator[_]] = {
+    val entry = entries.synchronized { entries.get(blockId) }
+    entry match {
+      case null => None
+      case e: SerializedMemoryEntry[_] =>
+        throw new IllegalArgumentException("should only call getValues on deserialized blocks")
+      case DeserializedMemoryEntry(values, _, _) =>
+        Some(values.iterator)
     }
   }
 
@@ -334,6 +347,24 @@ private[spark] class MemoryStore(
         }
       }
 
+      def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
+        val data = entry match {
+          case DeserializedMemoryEntry(values, _, _) => Left(values)
+          case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
+        }
+        val newEffectiveStorageLevel =
+          blockManager.dropFromMemory(blockId, () => data)(entry.classTag)
+        if (newEffectiveStorageLevel.isValid) {
+          // The block is still present in at least one store, so release the lock
+          // but don't delete the block info
+          blockManager.releaseLock(blockId)
+        } else {
+          // The block isn't present in any store, so delete the block info so that the
+          // block can be stored again
+          blockManager.blockInfoManager.removeBlock(blockId)
+        }
+      }
+
       if (freedMemory >= space) {
         logInfo(s"${selectedBlocks.size} blocks selected for dropping")
         for (blockId <- selectedBlocks) {
@@ -342,21 +373,7 @@ private[spark] class MemoryStore(
           // blocks and removing entries. However the check is still here for
           // future safety.
           if (entry != null) {
-            val data = if (entry.deserialized) {
-              Left(entry.value.asInstanceOf[Array[Any]])
-            } else {
-              Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
-            }
-            val newEffectiveStorageLevel = blockManager.dropFromMemory(blockId, () => data)
-            if (newEffectiveStorageLevel.isValid) {
-              // The block is still present in at least one store, so release the lock
-              // but don't delete the block info
-              blockManager.releaseLock(blockId)
-            } else {
-              // The block isn't present in any store, so delete the block info so that the
-              // block can be stored again
-              blockManager.blockInfoManager.removeBlock(blockId)
-            }
+            dropBlock(blockId, entry)
           }
         }
         freedMemory
@@ -472,16 +489,16 @@ private[spark] class MemoryStore(
  * @param unrolled an iterator for the partially-unrolled values.
  * @param rest the rest of the original iterator passed to [[MemoryStore.putIterator()]].
  */
-private[storage] class PartiallyUnrolledIterator(
+private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
     unrollMemory: Long,
-    unrolled: Iterator[Any],
-    rest: Iterator[Any])
-  extends Iterator[Any] {
+    unrolled: Iterator[T],
+    rest: Iterator[T])
+  extends Iterator[T] {
 
   private[this] var unrolledIteratorIsConsumed: Boolean = false
-  private[this] var iter: Iterator[Any] = {
-    val completionIterator = CompletionIterator[Any, Iterator[Any]](unrolled, {
+  private[this] var iter: Iterator[T] = {
+    val completionIterator = CompletionIterator[T, Iterator[T]](unrolled, {
       unrolledIteratorIsConsumed = true
       memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
     })
@@ -489,7 +506,7 @@ private[storage] class PartiallyUnrolledIterator(
   }
 
   override def hasNext: Boolean = iter.hasNext
-  override def next(): Any = iter.next()
+  override def next(): T = iter.next()
 
   /**
    * Called to dispose of this iterator and free its memory.
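
The MemoryEntry split above attaches a ClassTag to every cached block so that eviction can re-serialize with the right serializer. A standalone miniature of the pattern (the `Entry`, `Deserialized`, `Serialized`, and `describe` names are illustration-only, not the private Spark classes):

```scala
import scala.reflect.ClassTag

// Miniature of the new MemoryEntry hierarchy: the tag travels with the entry.
sealed trait Entry[T] {
  def size: Long
  def classTag: ClassTag[T]
}
case class Deserialized[T](value: Array[T], size: Long, classTag: ClassTag[T]) extends Entry[T]
case class Serialized[T](bytes: Array[Byte], size: Long, classTag: ClassTag[T]) extends Entry[T]

// On eviction, MemoryStore.dropBlock unpacks the entry and forwards its ClassTag to
// BlockManager.dropFromMemory; this function mirrors that dispatch.
def describe[T](entry: Entry[T]): String = entry match {
  case Deserialized(values, size, tag) => s"${values.length} objects of $tag ($size bytes estimated)"
  case Serialized(_, size, tag)        => s"serialized $tag ($size bytes)"
}
```
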
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 31312fb064b15..d9fecc5e3011e 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.xml.Node
 
-import org.eclipse.jetty.server.{Connector, Request, Server}
+import org.eclipse.jetty.server.{AbstractConnector, Connector, Request, Server}
 import org.eclipse.jetty.server.handler._
 import org.eclipse.jetty.server.nio.SelectChannelConnector
 import org.eclipse.jetty.server.ssl.SslSelectChannelConnector
@@ -271,9 +271,24 @@ private[spark] object JettyUtils extends Logging {
 
       gzipHandlers.foreach(collection.addHandler)
       connectors.foreach(_.setHost(hostName))
+      // Each acceptor and each selector uses one thread, so the pool needs at least
+      // (acceptors + selectors + 1) threads. (See SPARK-13776)
+      var minThreads = 1
+      connectors.foreach { c =>
+        // Currently we only use "SelectChannelConnector"
+        val connector = c.asInstanceOf[SelectChannelConnector]
+        // Limit the max acceptor number to 8 so that we don't waste a lot of threads
+        connector.setAcceptors(math.min(connector.getAcceptors, 8))
+        // The number of selectors always equals the number of acceptors
+        minThreads += connector.getAcceptors * 2
+      }
       server.setConnectors(connectors.toArray)
 
       val pool = new QueuedThreadPool
+      if (serverName.nonEmpty) {
+        pool.setName(serverName)
+      }
+      pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))
       pool.setDaemon(true)
       server.setThreadPool(pool)
       val errorHandler = new ErrorHandler()
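
A worked check of the thread floor described in the comment above (illustrative arithmetic only; the default acceptor count of 12 is a hypothetical value for a large machine):

```scala
// One SelectChannelConnector whose acceptor count is capped at 8:
val acceptors = math.min(12, 8)              // = 8 after the cap
val selectors = acceptors                    // selectors always match acceptors
val minThreads = 1 + acceptors + selectors   // = 17; the pool max must be at least this
```
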
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
index 54de4d4ee8ca7..dce2ac63a664c 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
@@ -20,10 +20,10 @@ package org.apache.spark.util
 import java.io.InputStream
 import java.nio.ByteBuffer
 
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.StorageUtils
 
 /**
- * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose()
+ * Reads data from a ByteBuffer, and optionally cleans it up using StorageUtils.dispose()
  * at the end of the stream (e.g. to close a memory-mapped file).
  */
 private[spark]
@@ -68,12 +68,12 @@ class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = f
   }
 
   /**
-   * Clean up the buffer, and potentially dispose of it using BlockManager.dispose().
+   * Clean up the buffer, and potentially dispose of it using StorageUtils.dispose().
    */
   private def cleanUp() {
     if (buffer != null) {
       if (dispose) {
-        BlockManager.dispose(buffer)
+        StorageUtils.dispose(buffer)
       }
       buffer = null
     }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index efc2482c74ddf..22d7a4988bb56 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -19,17 +19,13 @@ package org.apache.spark.util.collection
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.annotation.DeveloperApi
-
 /**
- * :: DeveloperApi ::
  * A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
  * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less
  * space overhead.
  *
  * Under the hood, it uses our OpenHashSet implementation.
  */
-@DeveloperApi
 private[spark]
 class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
     initialCapacity: Int)
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
new file mode 100644
index 0000000000000..c643c4b63c601
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+import java.nio.channels.WritableByteChannel
+
+import com.google.common.primitives.UnsignedBytes
+import io.netty.buffer.{ByteBuf, Unpooled}
+
+import org.apache.spark.network.util.ByteArrayWritableChannel
+import org.apache.spark.storage.StorageUtils
+
+/**
+ * Read-only byte buffer which is physically stored as multiple chunks rather than a single
+ * contiguous array.
+ *
+ * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must be non-empty and have
+ *               position == 0. Ownership of these buffers is transferred to the ChunkedByteBuffer,
+ *               so if these buffers may also be used elsewhere then the caller is responsible for
+ *               copying them as needed.
+ */
+private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
+  require(chunks != null, "chunks must not be null")
+  require(chunks.forall(_.limit() > 0), "chunks must be non-empty")
+  require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
+
+  /**
+   * The size of this buffer, in bytes.
+   */
+  val size: Long = chunks.map(_.limit().toLong).sum
+
+  def this(byteBuffer: ByteBuffer) = {
+    this(Array(byteBuffer))
+  }
+
+  /**
+   * Write this buffer to a channel.
+   */
+  def writeFully(channel: WritableByteChannel): Unit = {
+    for (bytes <- getChunks()) {
+      while (bytes.remaining > 0) {
+        channel.write(bytes)
+      }
+    }
+  }
+
+  /**
+   * Wrap this buffer to view it as a Netty ByteBuf.
+   */
+  def toNetty: ByteBuf = {
+    Unpooled.wrappedBuffer(getChunks(): _*)
+  }
+
+  /**
+   * Copy this buffer into a new byte array.
+   *
+   * @throws UnsupportedOperationException if this buffer's size exceeds the maximum array size.
+   */
+  def toArray: Array[Byte] = {
+    if (size >= Integer.MAX_VALUE) {
+      throw new UnsupportedOperationException(
+        s"cannot call toArray because buffer size ($size bytes) exceeds maximum array size")
+    }
+    val byteChannel = new ByteArrayWritableChannel(size.toInt)
+    writeFully(byteChannel)
+    byteChannel.close()
+    byteChannel.getData
+  }
+
+  /**
+   * Copy this buffer into a new ByteBuffer.
+   *
+   * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size.
+   */
+  def toByteBuffer: ByteBuffer = {
+    if (chunks.length == 1) {
+      chunks.head.duplicate()
+    } else {
+      ByteBuffer.wrap(toArray)
+    }
+  }
+
+  /**
+   * Creates an input stream to read data from this ChunkedByteBuffer.
+   *
+   * @param dispose if true, [[dispose()]] will be called at the end of the stream
+   *                in order to close any memory-mapped files which back this buffer.
+   */
+  def toInputStream(dispose: Boolean = false): InputStream = {
+    new ChunkedByteBufferInputStream(this, dispose)
+  }
+
+  /**
+   * Get duplicates of the ByteBuffers backing this ChunkedByteBuffer.
+   */
+  def getChunks(): Array[ByteBuffer] = {
+    chunks.map(_.duplicate())
+  }
+
+  /**
+   * Make a copy of this ChunkedByteBuffer, copying all of the backing data into new buffers.
+   * The new buffer will share no resources with the original buffer.
+   */
+  def copy(): ChunkedByteBuffer = {
+    val copiedChunks = getChunks().map { chunk =>
+      // TODO: accept an allocator in this copy method to integrate with mem. accounting systems
+      val newChunk = ByteBuffer.allocate(chunk.limit())
+      newChunk.put(chunk)
+      newChunk.flip()
+      newChunk
+    }
+    new ChunkedByteBuffer(copiedChunks)
+  }
+
+  /**
+   * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
+   * might cause errors if one attempts to read from the unmapped buffer, but it's better than
+   * waiting for the GC to find it because that could lead to huge numbers of open files. There's
+   * unfortunately no standard API to do this.
+   */
+  def dispose(): Unit = {
+    chunks.foreach(StorageUtils.dispose)
+  }
+}
+
+/**
+ * Reads data from a ChunkedByteBuffer.
+ *
+ * @param dispose if true, [[ChunkedByteBuffer.dispose()]] will be called at the end of the stream
+ *                in order to close any memory-mapped files which back the buffer.
+ */
+private class ChunkedByteBufferInputStream(
+    var chunkedByteBuffer: ChunkedByteBuffer,
+    dispose: Boolean)
+  extends InputStream {
+
+  private[this] var chunks = chunkedByteBuffer.getChunks().iterator
+  private[this] var currentChunk: ByteBuffer = {
+    if (chunks.hasNext) {
+      chunks.next()
+    } else {
+      null
+    }
+  }
+
+  override def read(): Int = {
+    if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
+      currentChunk = chunks.next()
+    }
+    if (currentChunk != null && currentChunk.hasRemaining) {
+      UnsignedBytes.toInt(currentChunk.get())
+    } else {
+      close()
+      -1
+    }
+  }
+
+  override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
+    if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
+      currentChunk = chunks.next()
+    }
+    if (currentChunk != null && currentChunk.hasRemaining) {
+      val amountToGet = math.min(currentChunk.remaining(), length)
+      currentChunk.get(dest, offset, amountToGet)
+      amountToGet
+    } else {
+      close()
+      -1
+    }
+  }
+
+  override def skip(bytes: Long): Long = {
+    if (currentChunk != null) {
+      val amountToSkip = math.min(bytes, currentChunk.remaining).toInt
+      currentChunk.position(currentChunk.position + amountToSkip)
+      if (currentChunk.remaining() == 0) {
+        if (chunks.hasNext) {
+          currentChunk = chunks.next()
+        } else {
+          close()
+        }
+      }
+      amountToSkip
+    } else {
+      0L
+    }
+  }
+
+  override def close(): Unit = {
+    if (chunkedByteBuffer != null && dispose) {
+      chunkedByteBuffer.dispose()
+    }
+    chunkedByteBuffer = null
+    chunks = null
+    currentChunk = null
+  }
+}
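
Since `ChunkedByteBuffer` is new in this patch, here is a short usage sketch tying together the operations defined above (the byte values are hypothetical; the class is `private[spark]`, so this only compiles inside Spark):

```scala
import java.nio.ByteBuffer

import org.apache.spark.util.io.ChunkedByteBuffer

// Build a two-chunk buffer and read it back in a few different ways.
val chunked = new ChunkedByteBuffer(Array(
  ByteBuffer.wrap(Array[Byte](1, 2, 3)),
  ByteBuffer.wrap(Array[Byte](4, 5))))

assert(chunked.size == 5L)                                  // total size across all chunks
assert(chunked.toArray.toSeq == Seq[Byte](1, 2, 3, 4, 5))   // contiguous copy

// Stream over the chunks without materializing one large array.
val in = chunked.toInputStream()   // dispose = false, so the chunks stay usable afterwards
assert(in.read() == 1)             // bytes come back as unsigned ints
in.close()

// copy() shares no backing buffers with the original.
val copied = chunked.copy()
assert(copied.toByteBuffer.remaining() == 5)
chunked.dispose()   // would unmap memory-mapped chunks; a no-op for these heap buffers
```
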
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index c1036b8fac6b6..0f65554516153 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -54,7 +54,7 @@
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.hadoop.mapreduce.Job;
 import org.junit.After;
-import org.junit.Assert;
+import static org.junit.Assert.*;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -102,19 +102,19 @@ public void sparkContextUnion() {
     JavaRDD<String> s2 = sc.parallelize(strings);
     // Varargs
     JavaRDD<String> sUnion = sc.union(s1, s2);
-    Assert.assertEquals(4, sUnion.count());
+    assertEquals(4, sUnion.count());
     // List
     List<JavaRDD<String>> list = new ArrayList<>();
     list.add(s2);
     sUnion = sc.union(s1, list);
-    Assert.assertEquals(4, sUnion.count());
+    assertEquals(4, sUnion.count());
 
     // Union of JavaDoubleRDDs
     List<Double> doubles = Arrays.asList(1.0, 2.0);
     JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
     JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
     JavaDoubleRDD dUnion = sc.union(d1, d2);
-    Assert.assertEquals(4, dUnion.count());
+    assertEquals(4, dUnion.count());
 
     // Union of JavaPairRDDs
     List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();
@@ -123,7 +123,7 @@ public void sparkContextUnion() {
     JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
     JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
     JavaPairRDD<Integer, Integer> pUnion = sc.union(p1, p2);
-    Assert.assertEquals(4, pUnion.count());
+    assertEquals(4, pUnion.count());
   }
 
   @SuppressWarnings("unchecked")
@@ -135,17 +135,17 @@ public void intersection() {
     JavaRDD<Integer> s2 = sc.parallelize(ints2);
 
     JavaRDD<Integer> intersections = s1.intersection(s2);
-    Assert.assertEquals(3, intersections.count());
+    assertEquals(3, intersections.count());
 
     JavaRDD<Integer> empty = sc.emptyRDD();
     JavaRDD<Integer> emptyIntersection = empty.intersection(s2);
-    Assert.assertEquals(0, emptyIntersection.count());
+    assertEquals(0, emptyIntersection.count());
 
     List<Double> doubles = Arrays.asList(1.0, 2.0);
     JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
     JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
     JavaDoubleRDD dIntersection = d1.intersection(d2);
-    Assert.assertEquals(2, dIntersection.count());
+    assertEquals(2, dIntersection.count());
 
     List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();
     pairs.add(new Tuple2<>(1, 2));
@@ -153,7 +153,7 @@ public void intersection() {
     JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
     JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
     JavaPairRDD<Integer, Integer> pIntersection = p1.intersection(p2);
-    Assert.assertEquals(2, pIntersection.count());
+    assertEquals(2, pIntersection.count());
   }
 
   @Test
@@ -162,9 +162,9 @@ public void sample() {
     JavaRDD<Integer> rdd = sc.parallelize(ints);
     // the seeds here are "magic" to make this work out nicely
     JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 8);
-    Assert.assertEquals(2, sample20.count());
+    assertEquals(2, sample20.count());
     JavaRDD<Integer> sample20WithoutReplacement = rdd.sample(false, 0.2, 2);
-    Assert.assertEquals(2, sample20WithoutReplacement.count());
+    assertEquals(2, sample20WithoutReplacement.count());
   }
 
   @Test
@@ -176,13 +176,13 @@ public void randomSplit() {
     JavaRDD<Integer> rdd = sc.parallelize(ints);
     JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
     // the splits aren't perfect -- not enough data for them to be -- just check they're about right
-    Assert.assertEquals(3, splits.length);
+    assertEquals(3, splits.length);
     long s0 = splits[0].count();
     long s1 = splits[1].count();
     long s2 = splits[2].count();
-    Assert.assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250);
-    Assert.assertTrue(s1 + " not within expected range", s1 > 250 && s0 < 350);
-    Assert.assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570);
+    assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250);
+    assertTrue(s1 + " not within expected range", s1 > 250 && s1 < 350);
+    assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570);
   }
 
   @Test
@@ -196,17 +196,17 @@ public void sortByKey() {
 
     // Default comparator
     JavaPairRDD sortedRDD = rdd.sortByKey();
-    Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
+    assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
     List> sortedPairs = sortedRDD.collect();
-    Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1));
-    Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2));
+    assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1));
+    assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2));
 
     // Custom comparator
     sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false);
-    Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
+    assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
     sortedPairs = sortedRDD.collect();
-    Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1));
-    Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2));
+    assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1));
+    assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2));
   }
 
   @SuppressWarnings("unchecked")
@@ -235,19 +235,19 @@ public int getPartition(Object key) {
 
     JavaPairRDD repartitioned =
         rdd.repartitionAndSortWithinPartitions(partitioner);
-    Assert.assertTrue(repartitioned.partitioner().isPresent());
-    Assert.assertEquals(repartitioned.partitioner().get(), partitioner);
+    assertTrue(repartitioned.partitioner().isPresent());
+    assertEquals(repartitioned.partitioner().get(), partitioner);
     List>> partitions = repartitioned.glom().collect();
-    Assert.assertEquals(partitions.get(0),
+    assertEquals(partitions.get(0),
         Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6)));
-    Assert.assertEquals(partitions.get(1),
+    assertEquals(partitions.get(1),
         Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8)));
   }
 
   @Test
   public void emptyRDD() {
     JavaRDD rdd = sc.emptyRDD();
-    Assert.assertEquals("Empty RDD shouldn't have any values", 0, rdd.count());
+    assertEquals("Empty RDD shouldn't have any values", 0, rdd.count());
   }
 
   @Test
@@ -260,17 +260,18 @@ public void sortBy() {
     JavaRDD> rdd = sc.parallelize(pairs);
 
     // compare on first value
-    JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() {
+    JavaRDD> sortedRDD =
+        rdd.sortBy(new Function, Integer>() {
       @Override
       public Integer call(Tuple2 t) {
         return t._1();
       }
     }, true, 2);
 
-    Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
+    assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
     List> sortedPairs = sortedRDD.collect();
-    Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1));
-    Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2));
+    assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1));
+    assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2));
 
     // compare on second value
     sortedRDD = rdd.sortBy(new Function, Integer>() {
@@ -279,10 +280,10 @@ public Integer call(Tuple2 t) {
         return t._2();
       }
     }, true, 2);
-    Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
+    assertEquals(new Tuple2<>(-1, 1), sortedRDD.first());
     sortedPairs = sortedRDD.collect();
-    Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1));
-    Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2));
+    assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1));
+    assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2));
   }
 
   @Test
@@ -295,7 +296,7 @@ public void call(String s) {
         accum.add(1);
       }
     });
-    Assert.assertEquals(2, accum.value().intValue());
+    assertEquals(2, accum.value().intValue());
   }
 
   @Test
@@ -311,7 +312,7 @@ public void call(Iterator iter) {
         }
       }
     });
-    Assert.assertEquals(2, accum.value().intValue());
+    assertEquals(2, accum.value().intValue());
   }
 
   @Test
@@ -319,7 +320,7 @@ public void toLocalIterator() {
     List correct = Arrays.asList(1, 2, 3, 4);
     JavaRDD rdd = sc.parallelize(correct);
     List result = Lists.newArrayList(rdd.toLocalIterator());
-    Assert.assertEquals(correct, result);
+    assertEquals(correct, result);
   }
 
   @Test
@@ -327,7 +328,7 @@ public void zipWithUniqueId() {
     List dataArray = Arrays.asList(1, 2, 3, 4);
     JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId();
     JavaRDD indexes = zip.values();
-    Assert.assertEquals(4, new HashSet<>(indexes.collect()).size());
+    assertEquals(4, new HashSet<>(indexes.collect()).size());
   }
 
   @Test
@@ -336,7 +337,7 @@ public void zipWithIndex() {
     JavaPairRDD zip = sc.parallelize(dataArray).zipWithIndex();
     JavaRDD indexes = zip.values();
     List correctIndexes = Arrays.asList(0L, 1L, 2L, 3L);
-    Assert.assertEquals(correctIndexes, indexes.collect());
+    assertEquals(correctIndexes, indexes.collect());
   }
 
   @SuppressWarnings("unchecked")
@@ -347,8 +348,8 @@ public void lookup() {
       new Tuple2<>("Oranges", "Fruit"),
       new Tuple2<>("Oranges", "Citrus")
     ));
-    Assert.assertEquals(2, categories.lookup("Oranges").size());
-    Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0)));
+    assertEquals(2, categories.lookup("Oranges").size());
+    assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0)));
   }
 
   @Test
@@ -361,14 +362,14 @@ public Boolean call(Integer x) {
       }
     };
     JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd);
-    Assert.assertEquals(2, oddsAndEvens.count());
-    Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0)));  // Evens
-    Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds
+    assertEquals(2, oddsAndEvens.count());
+    assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0)));  // Evens
+    assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds
 
     oddsAndEvens = rdd.groupBy(isOdd, 1);
-    Assert.assertEquals(2, oddsAndEvens.count());
-    Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0)));  // Evens
-    Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds
+    assertEquals(2, oddsAndEvens.count());
+    assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0)));  // Evens
+    assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds
   }
 
   @Test
@@ -384,14 +385,14 @@ public Boolean call(Tuple2 x) {
       };
     JavaPairRDD pairRDD = rdd.zip(rdd);
     JavaPairRDD>> oddsAndEvens = pairRDD.groupBy(areOdd);
-    Assert.assertEquals(2, oddsAndEvens.count());
-    Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0)));  // Evens
-    Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds
+    assertEquals(2, oddsAndEvens.count());
+    assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0)));  // Evens
+    assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds
 
     oddsAndEvens = pairRDD.groupBy(areOdd, 1);
-    Assert.assertEquals(2, oddsAndEvens.count());
-    Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0)));  // Evens
-    Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds
+    assertEquals(2, oddsAndEvens.count());
+    assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0)));  // Evens
+    assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds
   }
 
   @SuppressWarnings("unchecked")
@@ -408,8 +409,8 @@ public String call(Tuple2 x) {
       };
     JavaPairRDD pairRDD = rdd.zip(rdd);
     JavaPairRDD> keyed = pairRDD.keyBy(sumToString);
-    Assert.assertEquals(7, keyed.count());
-    Assert.assertEquals(1, (long) keyed.lookup("2").get(0)._1());
+    assertEquals(7, keyed.count());
+    assertEquals(1, (long) keyed.lookup("2").get(0)._1());
   }
 
   @SuppressWarnings("unchecked")
@@ -426,8 +427,8 @@ public void cogroup() {
     ));
     JavaPairRDD, Iterable>> cogrouped =
         categories.cogroup(prices);
-    Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
-    Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+    assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+    assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
 
     cogrouped.collect();
   }
@@ -451,9 +452,9 @@ public void cogroup3() {
 
     JavaPairRDD, Iterable, Iterable>> cogrouped =
         categories.cogroup(prices, quantities);
-    Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
-    Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
-    Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
+    assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+    assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+    assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
 
 
     cogrouped.collect();
@@ -480,12 +481,12 @@ public void cogroup4() {
       new Tuple2<>("Apples", "US")
     ));
 
-    JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped =
-        categories.cogroup(prices, quantities, countries);
-    Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
-    Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
-    Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
-    Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4()));
+    JavaPairRDD, Iterable, Iterable,
+        Iterable>> cogrouped = categories.cogroup(prices, quantities, countries);
+    assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1()));
+    assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2()));
+    assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3()));
+    assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4()));
 
     cogrouped.collect();
   }
@@ -507,7 +508,7 @@ public void leftOuterJoin() {
     ));
     List>>> joined =
       rdd1.leftOuterJoin(rdd2).collect();
-    Assert.assertEquals(5, joined.size());
+    assertEquals(5, joined.size());
     Tuple2>> firstUnmatched =
       rdd1.leftOuterJoin(rdd2).filter(
         new Function>>, Boolean>() {
@@ -516,7 +517,7 @@ public Boolean call(Tuple2>> tup) {
             return !tup._2()._2().isPresent();
           }
       }).first();
-    Assert.assertEquals(3, firstUnmatched._1().intValue());
+    assertEquals(3, firstUnmatched._1().intValue());
   }
 
   @Test
@@ -530,10 +531,10 @@ public Integer call(Integer a, Integer b) {
     };
 
     int sum = rdd.fold(0, add);
-    Assert.assertEquals(33, sum);
+    assertEquals(33, sum);
 
     sum = rdd.reduce(add);
-    Assert.assertEquals(33, sum);
+    assertEquals(33, sum);
   }
 
   @Test
@@ -547,7 +548,7 @@ public Integer call(Integer a, Integer b) {
     };
     for (int depth = 1; depth <= 10; depth++) {
       int sum = rdd.treeReduce(add, depth);
-      Assert.assertEquals(-5, sum);
+      assertEquals(-5, sum);
     }
   }
 
@@ -562,7 +563,7 @@ public Integer call(Integer a, Integer b) {
     };
     for (int depth = 1; depth <= 10; depth++) {
       int sum = rdd.treeAggregate(0, add, add, depth);
-      Assert.assertEquals(-5, sum);
+      assertEquals(-5, sum);
     }
   }
 
@@ -592,10 +593,10 @@ public Set call(Set a, Set b) {
           return a;
         }
       }).collectAsMap();
-    Assert.assertEquals(3, sets.size());
-    Assert.assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1));
-    Assert.assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3));
-    Assert.assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5));
+    assertEquals(3, sets.size());
+    assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1));
+    assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3));
+    assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5));
   }
 
   @SuppressWarnings("unchecked")
@@ -616,9 +617,9 @@ public Integer call(Integer a, Integer b) {
           return a + b;
         }
     });
-    Assert.assertEquals(1, sums.lookup(1).get(0).intValue());
-    Assert.assertEquals(2, sums.lookup(2).get(0).intValue());
-    Assert.assertEquals(3, sums.lookup(3).get(0).intValue());
+    assertEquals(1, sums.lookup(1).get(0).intValue());
+    assertEquals(2, sums.lookup(2).get(0).intValue());
+    assertEquals(3, sums.lookup(3).get(0).intValue());
   }
 
   @SuppressWarnings("unchecked")
@@ -639,14 +640,14 @@ public Integer call(Integer a, Integer b) {
          return a + b;
         }
     });
-    Assert.assertEquals(1, counts.lookup(1).get(0).intValue());
-    Assert.assertEquals(2, counts.lookup(2).get(0).intValue());
-    Assert.assertEquals(3, counts.lookup(3).get(0).intValue());
+    assertEquals(1, counts.lookup(1).get(0).intValue());
+    assertEquals(2, counts.lookup(2).get(0).intValue());
+    assertEquals(3, counts.lookup(3).get(0).intValue());
 
     Map localCounts = counts.collectAsMap();
-    Assert.assertEquals(1, localCounts.get(1).intValue());
-    Assert.assertEquals(2, localCounts.get(2).intValue());
-    Assert.assertEquals(3, localCounts.get(3).intValue());
+    assertEquals(1, localCounts.get(1).intValue());
+    assertEquals(2, localCounts.get(2).intValue());
+    assertEquals(3, localCounts.get(3).intValue());
 
     localCounts = rdd.reduceByKeyLocally(new Function2() {
       @Override
@@ -654,45 +655,45 @@ public Integer call(Integer a, Integer b) {
         return a + b;
       }
     });
-    Assert.assertEquals(1, localCounts.get(1).intValue());
-    Assert.assertEquals(2, localCounts.get(2).intValue());
-    Assert.assertEquals(3, localCounts.get(3).intValue());
+    assertEquals(1, localCounts.get(1).intValue());
+    assertEquals(2, localCounts.get(2).intValue());
+    assertEquals(3, localCounts.get(3).intValue());
   }
 
   @Test
   public void approximateResults() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
     Map countsByValue = rdd.countByValue();
-    Assert.assertEquals(2, countsByValue.get(1).longValue());
-    Assert.assertEquals(1, countsByValue.get(13).longValue());
+    assertEquals(2, countsByValue.get(1).longValue());
+    assertEquals(1, countsByValue.get(13).longValue());
 
     PartialResult> approx = rdd.countByValueApprox(1);
     Map finalValue = approx.getFinalValue();
-    Assert.assertEquals(2.0, finalValue.get(1).mean(), 0.01);
-    Assert.assertEquals(1.0, finalValue.get(13).mean(), 0.01);
+    assertEquals(2.0, finalValue.get(1).mean(), 0.01);
+    assertEquals(1.0, finalValue.get(13).mean(), 0.01);
   }
 
   @Test
   public void take() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
-    Assert.assertEquals(1, rdd.first().intValue());
+    assertEquals(1, rdd.first().intValue());
     rdd.take(2);
     rdd.takeSample(false, 2, 42);
   }
 
   @Test
   public void isEmpty() {
-    Assert.assertTrue(sc.emptyRDD().isEmpty());
-    Assert.assertTrue(sc.parallelize(new ArrayList()).isEmpty());
-    Assert.assertFalse(sc.parallelize(Arrays.asList(1)).isEmpty());
-    Assert.assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter(
+    assertTrue(sc.emptyRDD().isEmpty());
+    assertTrue(sc.parallelize(new ArrayList()).isEmpty());
+    assertFalse(sc.parallelize(Arrays.asList(1)).isEmpty());
+    assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter(
         new Function() {
           @Override
           public Boolean call(Integer i) {
             return i < 0;
           }
         }).isEmpty());
-    Assert.assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter(
+    assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter(
         new Function() {
           @Override
           public Boolean call(Integer i) {
@@ -706,35 +707,35 @@ public void cartesian() {
     JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
     JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World"));
     JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD);
-    Assert.assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first());
+    assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first());
   }
 
   @Test
   public void javaDoubleRDD() {
     JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
     JavaDoubleRDD distinct = rdd.distinct();
-    Assert.assertEquals(5, distinct.count());
+    assertEquals(5, distinct.count());
     JavaDoubleRDD filter = rdd.filter(new Function() {
       @Override
       public Boolean call(Double x) {
         return x > 2.0;
       }
     });
-    Assert.assertEquals(3, filter.count());
+    assertEquals(3, filter.count());
     JavaDoubleRDD union = rdd.union(rdd);
-    Assert.assertEquals(12, union.count());
+    assertEquals(12, union.count());
     union = union.cache();
-    Assert.assertEquals(12, union.count());
+    assertEquals(12, union.count());
 
-    Assert.assertEquals(20, rdd.sum(), 0.01);
+    assertEquals(20, rdd.sum(), 0.01);
     StatCounter stats = rdd.stats();
-    Assert.assertEquals(20, stats.sum(), 0.01);
-    Assert.assertEquals(20/6.0, rdd.mean(), 0.01);
-    Assert.assertEquals(20/6.0, rdd.mean(), 0.01);
-    Assert.assertEquals(6.22222, rdd.variance(), 0.01);
-    Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01);
-    Assert.assertEquals(2.49444, rdd.stdev(), 0.01);
-    Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01);
+    assertEquals(20, stats.sum(), 0.01);
+    assertEquals(20/6.0, rdd.mean(), 0.01);
+    assertEquals(20/6.0, rdd.mean(), 0.01);
+    assertEquals(6.22222, rdd.variance(), 0.01);
+    assertEquals(7.46667, rdd.sampleVariance(), 0.01);
+    assertEquals(2.49444, rdd.stdev(), 0.01);
+    assertEquals(2.73252, rdd.sampleStdev(), 0.01);
 
     rdd.first();
     rdd.take(5);
@@ -747,13 +748,13 @@ public void javaDoubleRDDHistoGram() {
     Tuple2 results = rdd.histogram(2);
     double[] expected_buckets = {1.0, 2.5, 4.0};
     long[] expected_counts = {2, 2};
-    Assert.assertArrayEquals(expected_buckets, results._1(), 0.1);
-    Assert.assertArrayEquals(expected_counts, results._2());
+    assertArrayEquals(expected_buckets, results._1(), 0.1);
+    assertArrayEquals(expected_counts, results._2());
     // Test with provided buckets
     long[] histogram = rdd.histogram(expected_buckets);
-    Assert.assertArrayEquals(expected_counts, histogram);
+    assertArrayEquals(expected_counts, histogram);
     // SPARK-5744
-    Assert.assertArrayEquals(
+    assertArrayEquals(
         new long[] {0},
         sc.parallelizeDoubles(new ArrayList(0), 1).histogram(new double[]{0.0, 1.0}));
   }
@@ -769,42 +770,42 @@ public int compare(Double o1, Double o2) {
   public void max() {
     JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
     double max = rdd.max(new DoubleComparator());
-    Assert.assertEquals(4.0, max, 0.001);
+    assertEquals(4.0, max, 0.001);
   }
 
   @Test
   public void min() {
     JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
     double max = rdd.min(new DoubleComparator());
-    Assert.assertEquals(1.0, max, 0.001);
+    assertEquals(1.0, max, 0.001);
   }
 
   @Test
   public void naturalMax() {
     JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
     double max = rdd.max();
-    Assert.assertEquals(4.0, max, 0.0);
+    assertEquals(4.0, max, 0.0);
   }
 
   @Test
   public void naturalMin() {
     JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
     double max = rdd.min();
-    Assert.assertEquals(1.0, max, 0.0);
+    assertEquals(1.0, max, 0.0);
   }
 
   @Test
   public void takeOrdered() {
     JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
-    Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator()));
-    Assert.assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2));
+    assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2, new DoubleComparator()));
+    assertEquals(Arrays.asList(1.0, 2.0), rdd.takeOrdered(2));
   }
 
   @Test
   public void top() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
     List top2 = rdd.top(2);
-    Assert.assertEquals(Arrays.asList(4, 3), top2);
+    assertEquals(Arrays.asList(4, 3), top2);
   }
 
   private static class AddInts implements Function2 {
@@ -818,7 +819,7 @@ public Integer call(Integer a, Integer b) {
   public void reduce() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
     int sum = rdd.reduce(new AddInts());
-    Assert.assertEquals(10, sum);
+    assertEquals(10, sum);
   }
 
   @Test
@@ -830,21 +831,21 @@ public Double call(Double v1, Double v2) {
         return v1 + v2;
       }
     });
-    Assert.assertEquals(10.0, sum, 0.001);
+    assertEquals(10.0, sum, 0.001);
   }
 
   @Test
   public void fold() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
     int sum = rdd.fold(0, new AddInts());
-    Assert.assertEquals(10, sum);
+    assertEquals(10, sum);
   }
 
   @Test
   public void aggregate() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
     int sum = rdd.aggregate(0, new AddInts(), new AddInts());
-    Assert.assertEquals(10, sum);
+    assertEquals(10, sum);
   }
 
   @Test
@@ -884,8 +885,8 @@ public Iterator call(String x) {
         return Arrays.asList(x.split(" ")).iterator();
       }
     });
-    Assert.assertEquals("Hello", words.first());
-    Assert.assertEquals(11, words.count());
+    assertEquals("Hello", words.first());
+    assertEquals(11, words.count());
 
     JavaPairRDD pairsRDD = rdd.flatMapToPair(
       new PairFlatMapFunction() {
@@ -899,8 +900,8 @@ public Iterator> call(String s) {
         }
       }
     );
-    Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first());
-    Assert.assertEquals(11, pairsRDD.count());
+    assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first());
+    assertEquals(11, pairsRDD.count());
 
     JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() {
       @Override
@@ -912,8 +913,8 @@ public Iterator call(String s) {
         return lengths.iterator();
       }
     });
-    Assert.assertEquals(5.0, doubles.first(), 0.01);
-    Assert.assertEquals(11, pairsRDD.count());
+    assertEquals(5.0, doubles.first(), 0.01);
+    assertEquals(11, pairsRDD.count());
   }
 
   @SuppressWarnings("unchecked")
@@ -959,7 +960,7 @@ public Iterator call(Iterator iter) {
           return Collections.singletonList(sum).iterator();
         }
     });
-    Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
+    assertEquals("[3, 7]", partitionSums.collect().toString());
   }
 
 
@@ -977,7 +978,7 @@ public Iterator call(Integer index, Iterator iter) {
           return Collections.singletonList(sum).iterator();
         }
     }, false);
-    Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
+    assertEquals("[3, 7]", partitionSums.collect().toString());
   }
 
   @Test
@@ -989,9 +990,9 @@ public void getNumPartitions(){
             new Tuple2<>("aa", 2),
             new Tuple2<>("aaa", 3)
     ), 2);
-    Assert.assertEquals(3, rdd1.getNumPartitions());
-    Assert.assertEquals(2, rdd2.getNumPartitions());
-    Assert.assertEquals(2, rdd3.getNumPartitions());
+    assertEquals(3, rdd1.getNumPartitions());
+    assertEquals(2, rdd2.getNumPartitions());
+    assertEquals(2, rdd3.getNumPartitions());
   }
 
   @Test
@@ -1000,18 +1001,18 @@ public void repartition() {
     JavaRDD in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
     JavaRDD repartitioned1 = in1.repartition(4);
     List> result1 = repartitioned1.glom().collect();
-    Assert.assertEquals(4, result1.size());
+    assertEquals(4, result1.size());
     for (List l : result1) {
-      Assert.assertFalse(l.isEmpty());
+      assertFalse(l.isEmpty());
     }
 
     // Growing number of partitions
     JavaRDD in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
     JavaRDD repartitioned2 = in2.repartition(2);
     List> result2 = repartitioned2.glom().collect();
-    Assert.assertEquals(2, result2.size());
+    assertEquals(2, result2.size());
     for (List l: result2) {
-      Assert.assertFalse(l.isEmpty());
+      assertFalse(l.isEmpty());
     }
   }
 
@@ -1020,7 +1021,7 @@ public void repartition() {
   public void persist() {
     JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
     doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
-    Assert.assertEquals(20, doubleRDD.sum(), 0.1);
+    assertEquals(20, doubleRDD.sum(), 0.1);
 
     List> pairs = Arrays.asList(
       new Tuple2<>(1, "a"),
@@ -1029,24 +1030,24 @@ public void persist() {
     );
     JavaPairRDD pairRDD = sc.parallelizePairs(pairs);
     pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY());
-    Assert.assertEquals("a", pairRDD.first()._2());
+    assertEquals("a", pairRDD.first()._2());
 
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
     rdd = rdd.persist(StorageLevel.DISK_ONLY());
-    Assert.assertEquals(1, rdd.first().intValue());
+    assertEquals(1, rdd.first().intValue());
   }
 
   @Test
   public void iterator() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
     TaskContext context = TaskContext$.MODULE$.empty();
-    Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
+    assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }
 
   @Test
   public void glom() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
-    Assert.assertEquals("[1, 2]", rdd.glom().first().toString());
+    assertEquals("[1, 2]", rdd.glom().first().toString());
   }
 
   // File input / output tests are largely adapted from FileSuite:
@@ -1059,11 +1060,11 @@ public void textFiles() throws IOException {
     // Read the plain text file and check it's OK
     File outputFile = new File(outputDir, "part-00000");
     String content = Files.toString(outputFile, StandardCharsets.UTF_8);
-    Assert.assertEquals("1\n2\n3\n4\n", content);
+    assertEquals("1\n2\n3\n4\n", content);
     // Also try reading it in as a text file RDD
     List expected = Arrays.asList("1", "2", "3", "4");
     JavaRDD readRDD = sc.textFile(outputDir);
-    Assert.assertEquals(expected, readRDD.collect());
+    assertEquals(expected, readRDD.collect());
   }
 
   @Test
@@ -1083,7 +1084,7 @@ public void wholeTextFiles() throws Exception {
     List> result = readRDD.collect();
 
     for (Tuple2 res : result) {
-      Assert.assertEquals(res._2(), container.get(new URI(res._1()).getPath()));
+      assertEquals(res._2(), container.get(new URI(res._1()).getPath()));
     }
   }
 
@@ -1096,7 +1097,7 @@ public void textFilesCompressed() throws IOException {
     // Try reading it in as a text file RDD
     List expected = Arrays.asList("1", "2", "3", "4");
     JavaRDD readRDD = sc.textFile(outputDir);
-    Assert.assertEquals(expected, readRDD.collect());
+    assertEquals(expected, readRDD.collect());
   }
 
   @SuppressWarnings("unchecked")
@@ -1125,7 +1126,7 @@ public Tuple2 call(Tuple2 pair) {
         return new Tuple2<>(pair._1().get(), pair._2().toString());
       }
     });
-    Assert.assertEquals(pairs, readRDD.collect());
+    assertEquals(pairs, readRDD.collect());
   }
 
   @Test
@@ -1145,7 +1146,7 @@ public void binaryFiles() throws Exception {
     JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3);
     List> result = readRDD.collect();
     for (Tuple2 res : result) {
-      Assert.assertArrayEquals(content1, res._2().toArray());
+      assertArrayEquals(content1, res._2().toArray());
     }
   }
 
@@ -1174,7 +1175,7 @@ public void call(Tuple2 pair) {
 
     List> result = readRDD.collect();
     for (Tuple2 res : result) {
-      Assert.assertArrayEquals(content1, res._2().toArray());
+      assertArrayEquals(content1, res._2().toArray());
     }
   }
 
@@ -1197,10 +1198,10 @@ public void binaryRecords() throws Exception {
     channel1.close();
 
     JavaRDD readRDD = sc.binaryRecords(tempDirName, content1.length);
-    Assert.assertEquals(numOfCopies,readRDD.count());
+    assertEquals(numOfCopies, readRDD.count());
     List result = readRDD.collect();
     for (byte[] res : result) {
-      Assert.assertArrayEquals(content1, res);
+      assertArrayEquals(content1, res);
     }
   }
 
@@ -1224,8 +1225,9 @@ public Tuple2 call(Tuple2 pair) {
         outputDir, IntWritable.class, Text.class,
         org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);
 
-    JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class);
-    Assert.assertEquals(pairs.toString(), output.map(new Function, String>() {
+    JavaPairRDD output =
+        sc.sequenceFile(outputDir, IntWritable.class, Text.class);
+    assertEquals(pairs.toString(), output.map(new Function, String>() {
       @Override
       public String call(Tuple2 x) {
         return x.toString();
@@ -1254,7 +1256,7 @@ public Tuple2 call(Tuple2 pair) {
     JavaPairRDD output = sc.newAPIHadoopFile(outputDir,
         org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class,
         IntWritable.class, Text.class, Job.getInstance().getConfiguration());
-    Assert.assertEquals(pairs.toString(), output.map(new Function, String>() {
+    assertEquals(pairs.toString(), output.map(new Function, String>() {
       @Override
       public String call(Tuple2 x) {
         return x.toString();
@@ -1270,7 +1272,7 @@ public void objectFilesOfInts() {
     // Try reading the output back as an object file
     List expected = Arrays.asList(1, 2, 3, 4);
     JavaRDD readRDD = sc.objectFile(outputDir);
-    Assert.assertEquals(expected, readRDD.collect());
+    assertEquals(expected, readRDD.collect());
   }
 
   @SuppressWarnings("unchecked")
@@ -1286,7 +1288,7 @@ public void objectFilesOfComplexTypes() {
     rdd.saveAsObjectFile(outputDir);
     // Try reading the output back as an object file
     JavaRDD> readRDD = sc.objectFile(outputDir);
-    Assert.assertEquals(pairs, readRDD.collect());
+    assertEquals(pairs, readRDD.collect());
   }
 
   @SuppressWarnings("unchecked")
@@ -1309,7 +1311,7 @@ public Tuple2 call(Tuple2 pair) {
 
     JavaPairRDD output = sc.hadoopFile(outputDir,
         SequenceFileInputFormat.class, IntWritable.class, Text.class);
-    Assert.assertEquals(pairs.toString(), output.map(new Function, String>() {
+    assertEquals(pairs.toString(), output.map(new Function, String>() {
       @Override
       public String call(Tuple2 x) {
         return x.toString();
@@ -1339,7 +1341,7 @@ public Tuple2 call(Tuple2 pair) {
     JavaPairRDD output = sc.hadoopFile(outputDir,
         SequenceFileInputFormat.class, IntWritable.class, Text.class);
 
-    Assert.assertEquals(pairs.toString(), output.map(new Function, String>() {
+    assertEquals(pairs.toString(), output.map(new Function, String>() {
       @Override
       public String call(Tuple2 x) {
         return x.toString();
@@ -1373,7 +1375,7 @@ public Iterator call(Iterator i, Iterator s) {
       };
 
     JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn);
-    Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString());
+    assertEquals("[3, 2, 3, 2]", sizes.collect().toString());
   }
 
   @Test
@@ -1387,7 +1389,7 @@ public void call(Integer x) {
         intAccum.add(x);
       }
     });
-    Assert.assertEquals((Integer) 25, intAccum.value());
+    assertEquals((Integer) 25, intAccum.value());
 
     final Accumulator doubleAccum = sc.doubleAccumulator(10.0);
     rdd.foreach(new VoidFunction() {
@@ -1396,7 +1398,7 @@ public void call(Integer x) {
         doubleAccum.add((double) x);
       }
     });
-    Assert.assertEquals((Double) 25.0, doubleAccum.value());
+    assertEquals((Double) 25.0, doubleAccum.value());
 
     // Try a custom accumulator type
     AccumulatorParam floatAccumulatorParam = new AccumulatorParam() {
@@ -1423,11 +1425,11 @@ public void call(Integer x) {
         floatAccum.add((float) x);
       }
     });
-    Assert.assertEquals((Float) 25.0f, floatAccum.value());
+    assertEquals((Float) 25.0f, floatAccum.value());
 
     // Test the setValue method
     floatAccum.setValue(5.0f);
-    Assert.assertEquals((Float) 5.0f, floatAccum.value());
+    assertEquals((Float) 5.0f, floatAccum.value());
   }
 
   @Test
@@ -1439,33 +1441,33 @@ public String call(Integer t) {
         return t.toString();
       }
     }).collect();
-    Assert.assertEquals(new Tuple2<>("1", 1), s.get(0));
-    Assert.assertEquals(new Tuple2<>("2", 2), s.get(1));
+    assertEquals(new Tuple2<>("1", 1), s.get(0));
+    assertEquals(new Tuple2<>("2", 2), s.get(1));
   }
 
   @Test
   public void checkpointAndComputation() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
     sc.setCheckpointDir(tempDir.getAbsolutePath());
-    Assert.assertFalse(rdd.isCheckpointed());
+    assertFalse(rdd.isCheckpointed());
     rdd.checkpoint();
     rdd.count(); // Forces the DAG to cause a checkpoint
-    Assert.assertTrue(rdd.isCheckpointed());
-    Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
+    assertTrue(rdd.isCheckpointed());
+    assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
   }
 
   @Test
   public void checkpointAndRestore() {
     JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
     sc.setCheckpointDir(tempDir.getAbsolutePath());
-    Assert.assertFalse(rdd.isCheckpointed());
+    assertFalse(rdd.isCheckpointed());
     rdd.checkpoint();
     rdd.count(); // Forces the DAG to cause a checkpoint
-    Assert.assertTrue(rdd.isCheckpointed());
+    assertTrue(rdd.isCheckpointed());
 
-    Assert.assertTrue(rdd.getCheckpointFile().isPresent());
+    assertTrue(rdd.getCheckpointFile().isPresent());
     JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
-    Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
+    assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
   }
 
   @Test
@@ -1484,7 +1486,8 @@ public Integer call(Integer v1) {
       }
     };
 
-    Function2 mergeValueFunction = new Function2() {
+    Function2 mergeValueFunction =
+        new Function2() {
       @Override
       public Integer call(Integer v1, Integer v2) {
         return v1 + v2;
@@ -1495,7 +1498,7 @@ public Integer call(Integer v1, Integer v2) {
         .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction);
     Map results = combinedRDD.collectAsMap();
     ImmutableMap expected = ImmutableMap.of(0, 9, 1, 5, 2, 7);
-    Assert.assertEquals(expected, results);
+    assertEquals(expected, results);
 
     Partitioner defaultPartitioner = Partitioner.defaultPartitioner(
         combinedRDD.rdd(),
@@ -1510,7 +1513,7 @@ public Integer call(Integer v1, Integer v2) {
              false,
              new KryoSerializer(new SparkConf()));
     results = combinedRDD.collectAsMap();
-    Assert.assertEquals(expected, results);
+    assertEquals(expected, results);
   }
 
   @SuppressWarnings("unchecked")
@@ -1531,7 +1534,7 @@ public Tuple2 call(Tuple2 in) {
             return new Tuple2<>(in._2(), in._1());
           }
         });
-    Assert.assertEquals(Arrays.asList(
+    assertEquals(Arrays.asList(
         new Tuple2<>(1, 1),
         new Tuple2<>(0, 2),
         new Tuple2<>(1, 3),
@@ -1553,21 +1556,19 @@ public Tuple2 call(Integer i) {
         });
 
     List[] parts = rdd1.collectPartitions(new int[] {0});
-    Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+    assertEquals(Arrays.asList(1, 2), parts[0]);
 
     parts = rdd1.collectPartitions(new int[] {1, 2});
-    Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
-    Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
+    assertEquals(Arrays.asList(3, 4), parts[0]);
+    assertEquals(Arrays.asList(5, 6, 7), parts[1]);
 
-    Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1),
+    assertEquals(Arrays.asList(new Tuple2<>(1, 1),
                                       new Tuple2<>(2, 0)),
                         rdd2.collectPartitions(new int[] {0})[0]);
 
     List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2});
-    Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1),
-                                      new Tuple2<>(4, 0)),
-                        parts2[0]);
-    Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1),
+    assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]);
+    assertEquals(Arrays.asList(new Tuple2<>(5, 1),
                                       new Tuple2<>(6, 0),
                                       new Tuple2<>(7, 1)),
                         parts2[1]);
@@ -1581,7 +1582,7 @@ public void countApproxDistinct() {
       arrayData.add(i % size);
     }
     JavaRDD simpleRdd = sc.parallelize(arrayData, 10);
-    Assert.assertTrue(Math.abs((simpleRdd.countApproxDistinct(0.05) - size) / (size * 1.0)) <= 0.1);
+    assertTrue(Math.abs((simpleRdd.countApproxDistinct(0.05) - size) / (size * 1.0)) <= 0.1);
   }
 
   @Test
@@ -1599,7 +1600,7 @@ public void countApproxDistinctByKey() {
       double count = resItem._1();
       long resCount = resItem._2();
       double error = Math.abs((resCount - count) / count);
-      Assert.assertTrue(error < 0.1);
+      assertTrue(error < 0.1);
     }
 
   }
@@ -1629,7 +1630,7 @@ public void collectAsMapAndSerialize() throws Exception {
     new ObjectOutputStream(bytes).writeObject(map);
     Map deserializedMap = (Map)
         new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray())).readObject();
-    Assert.assertEquals(1, deserializedMap.get("foo").intValue());
+    assertEquals(1, deserializedMap.get("foo").intValue());
   }
 
   @Test
@@ -1648,14 +1649,14 @@ public Tuple2 call(Integer i) {
     fractions.put(1, 1.0);
     JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L);
     Map wrCounts = wr.countByKey();
-    Assert.assertEquals(2, wrCounts.size());
-    Assert.assertTrue(wrCounts.get(0) > 0);
-    Assert.assertTrue(wrCounts.get(1) > 0);
+    assertEquals(2, wrCounts.size());
+    assertTrue(wrCounts.get(0) > 0);
+    assertTrue(wrCounts.get(1) > 0);
     JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L);
     Map worCounts = wor.countByKey();
-    Assert.assertEquals(2, worCounts.size());
-    Assert.assertTrue(worCounts.get(0) > 0);
-    Assert.assertTrue(worCounts.get(1) > 0);
+    assertEquals(2, worCounts.size());
+    assertTrue(worCounts.get(0) > 0);
+    assertTrue(worCounts.get(1) > 0);
   }
 
   @Test
@@ -1674,14 +1675,14 @@ public Tuple2 call(Integer i) {
     fractions.put(1, 1.0);
     JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L);
     Map wrExactCounts = wrExact.countByKey();
-    Assert.assertEquals(2, wrExactCounts.size());
-    Assert.assertTrue(wrExactCounts.get(0) == 2);
-    Assert.assertTrue(wrExactCounts.get(1) == 4);
+    assertEquals(2, wrExactCounts.size());
+    assertTrue(wrExactCounts.get(0) == 2);
+    assertTrue(wrExactCounts.get(1) == 4);
     JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L);
     Map worExactCounts = worExact.countByKey();
-    Assert.assertEquals(2, worExactCounts.size());
-    Assert.assertTrue(worExactCounts.get(0) == 2);
-    Assert.assertTrue(worExactCounts.get(1) == 4);
+    assertEquals(2, worExactCounts.size());
+    assertTrue(worExactCounts.get(0) == 2);
+    assertTrue(worExactCounts.get(1) == 4);
   }
 
   private static class SomeCustomClass implements Serializable {
@@ -1697,8 +1698,9 @@ public void collectUnderlyingScalaRDD() {
       data.add(new SomeCustomClass());
     }
     JavaRDD rdd = sc.parallelize(data);
-    SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
-    Assert.assertEquals(data.size(), collected.length);
+    SomeCustomClass[] collected =
+      (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
+    assertEquals(data.size(), collected.length);
   }
 
   private static final class BuggyMapFunction implements Function {
@@ -1715,10 +1717,10 @@ public void collectAsync() throws Exception {
     JavaRDD rdd = sc.parallelize(data, 1);
     JavaFutureAction> future = rdd.collectAsync();
     List result = future.get();
-    Assert.assertEquals(data, result);
-    Assert.assertFalse(future.isCancelled());
-    Assert.assertTrue(future.isDone());
-    Assert.assertEquals(1, future.jobIds().size());
+    assertEquals(data, result);
+    assertFalse(future.isCancelled());
+    assertTrue(future.isDone());
+    assertEquals(1, future.jobIds().size());
   }
 
   @Test
@@ -1727,11 +1729,11 @@ public void takeAsync() throws Exception {
     JavaRDD rdd = sc.parallelize(data, 1);
     JavaFutureAction> future = rdd.takeAsync(1);
     List result = future.get();
-    Assert.assertEquals(1, result.size());
-    Assert.assertEquals((Integer) 1, result.get(0));
-    Assert.assertFalse(future.isCancelled());
-    Assert.assertTrue(future.isDone());
-    Assert.assertEquals(1, future.jobIds().size());
+    assertEquals(1, result.size());
+    assertEquals((Integer) 1, result.get(0));
+    assertFalse(future.isCancelled());
+    assertTrue(future.isDone());
+    assertEquals(1, future.jobIds().size());
   }
 
   @Test
@@ -1747,9 +1749,9 @@ public void call(Integer integer) {
         }
     );
     future.get();
-    Assert.assertFalse(future.isCancelled());
-    Assert.assertTrue(future.isDone());
-    Assert.assertEquals(1, future.jobIds().size());
+    assertFalse(future.isCancelled());
+    assertTrue(future.isDone());
+    assertEquals(1, future.jobIds().size());
   }
 
   @Test
@@ -1758,10 +1760,10 @@ public void countAsync() throws Exception {
     JavaRDD rdd = sc.parallelize(data, 1);
     JavaFutureAction future = rdd.countAsync();
     long count = future.get();
-    Assert.assertEquals(data.size(), count);
-    Assert.assertFalse(future.isCancelled());
-    Assert.assertTrue(future.isDone());
-    Assert.assertEquals(1, future.jobIds().size());
+    assertEquals(data.size(), count);
+    assertFalse(future.isCancelled());
+    assertTrue(future.isDone());
+    assertEquals(1, future.jobIds().size());
   }
 
   @Test
@@ -1775,11 +1777,11 @@ public void call(Integer integer) throws InterruptedException {
       }
     });
     future.cancel(true);
-    Assert.assertTrue(future.isCancelled());
-    Assert.assertTrue(future.isDone());
+    assertTrue(future.isCancelled());
+    assertTrue(future.isDone());
     try {
       future.get(2000, TimeUnit.MILLISECONDS);
-      Assert.fail("Expected future.get() for cancelled job to throw CancellationException");
+      fail("Expected future.get() for cancelled job to throw CancellationException");
     } catch (CancellationException ignored) {
       // pass
     }
@@ -1792,11 +1794,11 @@ public void testAsyncActionErrorWrapping() throws Exception {
     JavaFutureAction future = rdd.map(new BuggyMapFunction()).countAsync();
     try {
       future.get(2, TimeUnit.SECONDS);
-      Assert.fail("Expected future.get() for failed job to throw ExcecutionException");
+      fail("Expected future.get() for failed job to throw ExcecutionException");
     } catch (ExecutionException ee) {
-      Assert.assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!"));
+      assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!"));
     }
-    Assert.assertTrue(future.isDone());
+    assertTrue(future.isDone());
   }
 
   static class Class1 {}
@@ -1806,7 +1808,7 @@ static class Class2 {}
   public void testRegisterKryoClasses() {
     SparkConf conf = new SparkConf();
     conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class });
-    Assert.assertEquals(
+    assertEquals(
         Class1.class.getName() + "," + Class2.class.getName(),
         conf.get("spark.kryo.classesToRegister"));
   }
@@ -1814,13 +1816,13 @@ public void testRegisterKryoClasses() {
   @Test
   public void testGetPersistentRDDs() {
     java.util.Map> cachedRddsMap = sc.getPersistentRDDs();
-    Assert.assertTrue(cachedRddsMap.isEmpty());
+    assertTrue(cachedRddsMap.isEmpty());
     JavaRDD rdd1 = sc.parallelize(Arrays.asList("a", "b")).setName("RDD1").cache();
     JavaRDD rdd2 = sc.parallelize(Arrays.asList("c", "d")).setName("RDD2").cache();
     cachedRddsMap = sc.getPersistentRDDs();
-    Assert.assertEquals(2, cachedRddsMap.size());
-    Assert.assertEquals("RDD1", cachedRddsMap.get(0).name());
-    Assert.assertEquals("RDD2", cachedRddsMap.get(1).name());
+    assertEquals(2, cachedRddsMap.size());
+    assertEquals("RDD1", cachedRddsMap.get(0).name());
+    assertEquals("RDD2", cachedRddsMap.get(1).name());
   }
 
 }
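
The JavaAPISuite hunks above are almost entirely mechanical: the static wildcard import of `org.junit.Assert` lets every assertion drop its `Assert.` qualifier. A minimal, hypothetical JUnit 4 test class showing the same style (the class and its contents are invented for illustration and are not part of the patch):

```java
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.util.Arrays;
import java.util.List;

import org.junit.Test;

// Hypothetical test class, invented only to illustrate the import style.
public class StaticImportStyleExample {

  @Test
  public void bareAssertionsReadCompactly() {
    List<Integer> values = Arrays.asList(1, 2, 3);
    // With the static import, no call needs the Assert. qualifier.
    assertEquals(3, values.size());
    assertTrue(values.contains(2));
  }
}
```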
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 9aab2265c9892..6667179b9d30c 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -92,9 +92,11 @@ public void setup() {
     spillFilesCreated.clear();
     MockitoAnnotations.initMocks(this);
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
-    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() {
+    when(diskBlockManager.createTempLocalBlock()).thenAnswer(
+        new Answer>() {
       @Override
-      public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable {
+      public Tuple2 answer(InvocationOnMock invocationOnMock)
+          throws Throwable {
         TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
         File file = File.createTempFile("spillFile", ".spill", tempDir);
         spillFilesCreated.add(file);
@@ -544,7 +546,8 @@ public void failureToGrow() {
 
   @Test
   public void spillInIterator() throws IOException {
-    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false);
+    BytesToBytesMap map =
+        new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false);
     try {
       int i;
       for (i = 0; i < 1024; i++) {
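
The hunks in this file, and the near-identical ones in the two sorter suites that follow, only re-wrap long statements, apparently to keep each line within a length limit. A hedged sketch of the same wrapping applied to a Mockito `thenAnswer` stub; the `BlockFactory` interface and all other names here are invented for illustration:

```java
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

// Hypothetical example of the wrapping style: the stubbing expression is
// broken after thenAnswer( so that no single line grows past the limit.
public class AnswerWrappingExample {

  // Invented collaborator, used only to give the stub something to answer.
  interface BlockFactory {
    String createBlock();
  }

  public static BlockFactory stubbedFactory() {
    BlockFactory factory = mock(BlockFactory.class);
    when(factory.createBlock()).thenAnswer(
        new Answer<String>() {
          @Override
          public String answer(InvocationOnMock invocation) throws Throwable {
            // Each call hands back a fresh, unique block name.
            return "block-" + System.nanoTime();
          }
        });
    return factory;
  }
}
```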
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a79ed58133f1b..db50e551f256e 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -103,9 +103,11 @@ public void setUp() {
     taskContext = mock(TaskContext.class);
     when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
-    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() {
+    when(diskBlockManager.createTempLocalBlock()).thenAnswer(
+        new Answer>() {
       @Override
-      public Tuple2 answer(InvocationOnMock invocationOnMock) throws Throwable {
+      public Tuple2 answer(InvocationOnMock invocationOnMock)
+          throws Throwable {
         TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
         File file = File.createTempFile("spillFile", ".spill", tempDir);
         spillFilesCreated.add(file);
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 483319434d00c..f90214fffd396 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -108,8 +108,8 @@ public int compare(long prefix1, long prefix2) {
         return (int) prefix1 - (int) prefix2;
       }
     };
-    UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, recordComparator,
-      prefixComparator, dataToSort.length);
+    UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager,
+      recordComparator, prefixComparator, dataToSort.length);
     // Given a page of records, insert those records into the sorter one-by-one:
     position = dataPage.getBaseOffset();
     for (int i = 0; i < dataToSort.length; i++) {
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 1c3f2bc315ddc..2732cd674992d 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.Matchers
 import org.scalatest.time.{Millis, Span}
 
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 class NotSerializableClass
 class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
@@ -196,8 +197,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     blockManager.master.getLocations(blockId).foreach { cmId =>
       val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
         blockId.toString)
-      val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer())
-        .asInstanceOf[Iterator[Int]].toList
+      val deserialized = blockManager.dataDeserialize[Int](blockId,
+        new ChunkedByteBuffer(bytes.nioByteBuffer())).toList
       assert(deserialized === (1 to 100).toList)
     }
   }
@@ -222,7 +223,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     val numPartitions = 10
     val conf = new SparkConf()
       .set("spark.storage.unrollMemoryThreshold", "1024")
-      .set("spark.testing.memory", (size * numPartitions).toString)
+      .set("spark.testing.memory", size.toString)
     sc = new SparkContext(clusterUrl, "test", conf)
     val data = sc.parallelize(1 to size, numPartitions).persist(StorageLevel.MEMORY_ONLY)
     assert(data.count() === size)
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
new file mode 100644
index 0000000000000..aab70e7431e07
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.io
+
+import java.nio.ByteBuffer
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.network.util.ByteArrayWritableChannel
+import org.apache.spark.util.io.ChunkedByteBuffer
+
+class ChunkedByteBufferSuite extends SparkFunSuite {
+
+  test("no chunks") {
+    val emptyChunkedByteBuffer = new ChunkedByteBuffer(Array.empty[ByteBuffer])
+    assert(emptyChunkedByteBuffer.size === 0)
+    assert(emptyChunkedByteBuffer.getChunks().isEmpty)
+    assert(emptyChunkedByteBuffer.toArray === Array.empty)
+    assert(emptyChunkedByteBuffer.toByteBuffer.capacity() === 0)
+    assert(emptyChunkedByteBuffer.toNetty.capacity() === 0)
+    emptyChunkedByteBuffer.toInputStream(dispose = false).close()
+    emptyChunkedByteBuffer.toInputStream(dispose = true).close()
+  }
+
+  test("chunks must be non-empty") {
+    intercept[IllegalArgumentException] {
+      new ChunkedByteBuffer(Array(ByteBuffer.allocate(0)))
+    }
+  }
+
+  test("getChunks() duplicates chunks") {
+    val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8)))
+    chunkedByteBuffer.getChunks().head.position(4)
+    assert(chunkedByteBuffer.getChunks().head.position() === 0)
+  }
+
+  test("copy() does not affect original buffer's position") {
+    val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8)))
+    chunkedByteBuffer.copy()
+    assert(chunkedByteBuffer.getChunks().head.position() === 0)
+  }
+
+  test("writeFully() does not affect original buffer's position") {
+    val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8)))
+    chunkedByteBuffer.writeFully(new ByteArrayWritableChannel(chunkedByteBuffer.size.toInt))
+    assert(chunkedByteBuffer.getChunks().head.position() === 0)
+  }
+
+  test("toArray()") {
+    val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte))
+    val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes, bytes))
+    assert(chunkedByteBuffer.toArray === bytes.array() ++ bytes.array())
+  }
+
+  test("toArray() throws UnsupportedOperationException if size exceeds 2GB") {
+    val fourMegabyteBuffer = ByteBuffer.allocate(1024 * 1024 * 4)
+    fourMegabyteBuffer.limit(fourMegabyteBuffer.capacity())
+    val chunkedByteBuffer = new ChunkedByteBuffer(Array.fill(1024)(fourMegabyteBuffer))
+    assert(chunkedByteBuffer.size === (1024L * 1024L * 1024L * 4L))
+    intercept[UnsupportedOperationException] {
+      chunkedByteBuffer.toArray
+    }
+  }
+
+  test("toInputStream()") {
+    val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte))
+    val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte))
+    val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes1, bytes2))
+    assert(chunkedByteBuffer.size === bytes1.limit() + bytes2.limit())
+
+    val inputStream = chunkedByteBuffer.toInputStream(dispose = false)
+    val bytesFromStream = new Array[Byte](chunkedByteBuffer.size.toInt)
+    ByteStreams.readFully(inputStream, bytesFromStream)
+    assert(bytesFromStream === bytes1.array() ++ bytes2.array())
+    assert(chunkedByteBuffer.getChunks().head.position() === 0)
+  }
+}
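
For a rough sense of how the buffer added by this suite is driven, here is a hedged sketch in Java. It assumes the Spark-internal `ChunkedByteBuffer` is reachable from JVM test code on the Spark classpath, and it only exercises calls already shown in the Scala tests above:

```java
import java.nio.ByteBuffer;

import org.apache.spark.util.io.ChunkedByteBuffer;

// Hedged sketch only: assumes the Spark-internal ChunkedByteBuffer is visible
// to test code on the Spark classpath; it mirrors calls from the suite above.
public class ChunkedByteBufferSketch {

  public static void main(String[] args) {
    ByteBuffer first = ByteBuffer.wrap(new byte[] {0, 1, 2, 3});
    ByteBuffer second = ByteBuffer.wrap(new byte[] {4, 5, 6, 7});
    ChunkedByteBuffer buffer = new ChunkedByteBuffer(new ByteBuffer[] {first, second});

    // size() spans every chunk; toArray() concatenates them into one byte[].
    assert buffer.size() == 8L;
    assert buffer.toArray().length == 8;

    // getChunks() returns duplicates, so reading them leaves the positions of
    // the underlying buffers untouched.
    assert buffer.getChunks()[0].position() == 0;
  }
}
```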
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index fe83fc722a8e8..7ee76aa4c6f9d 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
 
 import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.language.implicitConversions
+import scala.reflect.ClassTag
 
 import org.scalatest.BeforeAndAfterEach
 import org.scalatest.time.SpanSugar._
@@ -52,7 +53,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   private def newBlockInfo(): BlockInfo = {
-    new BlockInfo(StorageLevel.MEMORY_ONLY, tellMaster = false)
+    new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)
   }
 
   private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index b78a3648cd8bc..98e8450fa1453 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
-import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.storage.StorageLevel._
 
@@ -62,7 +62,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
     val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
-    val store = new BlockManager(name, rpcEnv, master, serializer, conf,
+    val serializerManager = new SerializerManager(serializer, conf)
+    val store = new BlockManager(name, rpcEnv, master, serializerManager, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
     memManager.setMemoryStore(store.memoryStore)
     store.initialize("app-id")
@@ -262,7 +263,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
     when(failableTransfer.hostName).thenReturn("some-hostname")
     when(failableTransfer.port).thenReturn(1000)
     val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1)
-    val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf,
+    val serializerManager = new SerializerManager(serializer, conf)
+    val failableStore = new BlockManager("failable-store", rpcEnv, master, serializerManager, conf,
       memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
     memManager.setMemoryStore(failableStore.memoryStore)
     failableStore.initialize("app-id")
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 2e0c0596a75bb..9419dfaa00648 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -24,6 +24,7 @@ import scala.concurrent.duration._
 import scala.concurrent.Future
 import scala.language.implicitConversions
 import scala.language.postfixOps
+import scala.reflect.ClassTag
 
 import org.mockito.{Matchers => mc}
 import org.mockito.Mockito.{mock, times, verify, when}
@@ -40,10 +41,11 @@ import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.network.shuffle.BlockFetchingListener
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
-import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager}
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
 import org.apache.spark.util._
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
   with PrivateMethodTester with ResetSystemProperties {
@@ -76,7 +78,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     val transfer = transferService
       .getOrElse(new NettyBlockTransferService(conf, securityMgr, numCores = 1))
     val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
-    val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf,
+    val serializerManager = new SerializerManager(serializer, conf)
+    val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
     memManager.setMemoryStore(blockManager.memoryStore)
     blockManager.initialize("app-id")
@@ -192,8 +195,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(master.getLocations("a3").size === 0, "master was told about a3")
 
     // Drop a1 and a2 from memory; this should be reported back to the master
-    store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ByteBuffer])
-    store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ByteBuffer])
+    store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ChunkedByteBuffer])
+    store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ChunkedByteBuffer])
     assert(store.getSingleAndReleaseLock("a1") === None, "a1 not removed from store")
     assert(store.getSingleAndReleaseLock("a2") === None, "a2 not removed from store")
     assert(master.getLocations("a1").size === 0, "master did not remove a1")
@@ -434,8 +437,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       t2.join()
       t3.join()
 
-      store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ByteBuffer])
-      store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ByteBuffer])
+      store.dropFromMemoryIfExists("a1", () => null: Either[Array[Any], ChunkedByteBuffer])
+      store.dropFromMemoryIfExists("a2", () => null: Either[Array[Any], ChunkedByteBuffer])
       store.waitForAsyncReregister()
     }
   }
@@ -820,8 +823,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       maxOnHeapExecutionMemory = Long.MaxValue,
       maxStorageMemory = 1200,
       numCores = 1)
+    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
     store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
-      new JavaSerializer(conf), conf, memoryManager, mapOutputTracker,
+      serializerManager, conf, memoryManager, mapOutputTracker,
       shuffleManager, transfer, securityMgr, 0)
     memoryManager.setMemoryStore(store.memoryStore)
 
@@ -1073,7 +1077,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Unroll with all the space in the world. This should succeed.
-    var putResult = memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY)
+    var putResult =
+      memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any)
     assert(putResult.isRight)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
@@ -1084,7 +1089,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     // Unroll with not enough space. This should succeed after kicking out someBlock1.
     assert(store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY))
     assert(store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY))
-    putResult = memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY)
+    putResult =
+      memoryStore.putIterator("unroll", smallList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any)
     assert(putResult.isRight)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     assert(memoryStore.contains("someBlock2"))
@@ -1098,7 +1104,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator.
     // In the mean time, however, we kicked out someBlock2 before giving up.
     assert(store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY))
-    putResult = memoryStore.putIterator("unroll", bigList.iterator, StorageLevel.MEMORY_ONLY)
+    putResult =
+      memoryStore.putIterator("unroll", bigList.iterator, StorageLevel.MEMORY_ONLY, ClassTag.Any)
     assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
     assert(!memoryStore.contains("someBlock2"))
     assert(putResult.isLeft)
@@ -1120,8 +1127,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Unroll with plenty of space. This should succeed and cache both blocks.
-    val result1 = memoryStore.putIterator("b1", smallIterator, memOnly)
-    val result2 = memoryStore.putIterator("b2", smallIterator, memOnly)
+    val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, ClassTag.Any)
+    val result2 = memoryStore.putIterator("b2", smallIterator, memOnly, ClassTag.Any)
     assert(memoryStore.contains("b1"))
     assert(memoryStore.contains("b2"))
     assert(result1.isRight) // unroll was successful
@@ -1136,7 +1143,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     store.putIterator("b2", smallIterator, memOnly)
 
     // Unroll with not enough space. This should succeed but kick out b1 in the process.
-    val result3 = memoryStore.putIterator("b3", smallIterator, memOnly)
+    val result3 = memoryStore.putIterator("b3", smallIterator, memOnly, ClassTag.Any)
     assert(result3.isRight)
     assert(!memoryStore.contains("b1"))
     assert(memoryStore.contains("b2"))
@@ -1146,7 +1153,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     store.putIterator("b3", smallIterator, memOnly)
 
     // Unroll huge block with not enough space. This should fail and kick out b2 in the process.
-    val result4 = memoryStore.putIterator("b4", bigIterator, memOnly)
+    val result4 = memoryStore.putIterator("b4", bigIterator, memOnly, ClassTag.Any)
     assert(result4.isLeft) // unroll was unsuccessful
     assert(!memoryStore.contains("b1"))
     assert(!memoryStore.contains("b2"))
@@ -1174,7 +1181,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
 
     // Unroll with not enough space. This should succeed but kick out b1 in the process.
     // Memory store should contain b2 and b3, while disk store should contain only b1
-    val result3 = memoryStore.putIterator("b3", smallIterator, memAndDisk)
+    val result3 = memoryStore.putIterator("b3", smallIterator, memAndDisk, ClassTag.Any)
     assert(result3.isRight)
     assert(!memoryStore.contains("b1"))
     assert(memoryStore.contains("b2"))
@@ -1190,7 +1197,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     // the block may be stored to disk. During the unrolling process, block "b2" should be kicked
     // out, so the memory store should contain only b3, while the disk store should contain
     // b1, b2 and b4.
-    val result4 = memoryStore.putIterator("b4", bigIterator, memAndDisk)
+    val result4 = memoryStore.putIterator("b4", bigIterator, memAndDisk, ClassTag.Any)
     assert(result4.isLeft)
     assert(!memoryStore.contains("b1"))
     assert(!memoryStore.contains("b2"))
@@ -1210,28 +1217,28 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // All unroll memory used is released because putIterator did not return an iterator
-    assert(memoryStore.putIterator("b1", smallIterator, memOnly).isRight)
+    assert(memoryStore.putIterator("b1", smallIterator, memOnly, ClassTag.Any).isRight)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    assert(memoryStore.putIterator("b2", smallIterator, memOnly).isRight)
+    assert(memoryStore.putIterator("b2", smallIterator, memOnly, ClassTag.Any).isRight)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Unroll memory is not released because putIterator returned an iterator
     // that still depends on the underlying vector used in the process
-    assert(memoryStore.putIterator("b3", smallIterator, memOnly).isLeft)
+    assert(memoryStore.putIterator("b3", smallIterator, memOnly, ClassTag.Any).isLeft)
     val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask
     assert(unrollMemoryAfterB3 > 0)
 
     // The unroll memory owned by this thread builds on top of its value after the previous unrolls
-    assert(memoryStore.putIterator("b4", smallIterator, memOnly).isLeft)
+    assert(memoryStore.putIterator("b4", smallIterator, memOnly, ClassTag.Any).isLeft)
     val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask
     assert(unrollMemoryAfterB4 > unrollMemoryAfterB3)
 
     // ... but only to a certain extent (until we run out of free space to grant new unroll memory)
-    assert(memoryStore.putIterator("b5", smallIterator, memOnly).isLeft)
+    assert(memoryStore.putIterator("b5", smallIterator, memOnly, ClassTag.Any).isLeft)
     val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask
-    assert(memoryStore.putIterator("b6", smallIterator, memOnly).isLeft)
+    assert(memoryStore.putIterator("b6", smallIterator, memOnly, ClassTag.Any).isLeft)
     val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask
-    assert(memoryStore.putIterator("b7", smallIterator, memOnly).isLeft)
+    assert(memoryStore.putIterator("b7", smallIterator, memOnly, ClassTag.Any).isLeft)
     val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask
     assert(unrollMemoryAfterB5 === unrollMemoryAfterB4)
     assert(unrollMemoryAfterB6 === unrollMemoryAfterB4)
@@ -1243,7 +1250,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     val memoryStore = store.memoryStore
     val blockId = BlockId("rdd_3_10")
     store.blockInfoManager.lockNewBlockForWriting(
-      blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, tellMaster = false))
+      blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false))
     memoryStore.putBytes(blockId, 13000, () => {
       fail("A big ByteBuffer that cannot be put into MemoryStore should not be created")
     })
@@ -1253,9 +1260,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     store = makeBlockManager(12000)
     val memoryStore = store.memoryStore
     val blockId = BlockId("rdd_3_10")
-    var bytes: ByteBuffer = null
+    var bytes: ChunkedByteBuffer = null
     memoryStore.putBytes(blockId, 10000, () => {
-      bytes = ByteBuffer.allocate(10000)
+      bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000))
       bytes
     })
     assert(memoryStore.getSize(blockId) === 10000)
@@ -1339,7 +1346,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
         port: Int, execId: String,
         blockId: BlockId,
         blockData: ManagedBuffer,
-        level: StorageLevel): Future[Unit] = {
+        level: StorageLevel,
+        classTag: ClassTag[_]): Future[Unit] = {
       import scala.concurrent.ExecutionContext.Implicits.global
       Future {}
     }
@@ -1364,7 +1372,7 @@ private object BlockManagerSuite {
 
     def dropFromMemoryIfExists(
         blockId: BlockId,
-        data: () => Either[Array[Any], ByteBuffer]): Unit = {
+        data: () => Either[Array[Any], ChunkedByteBuffer]): Unit = {
       store.blockInfoManager.lockForWriting(blockId).foreach { info =>
         val newEffectiveStorageLevel = store.dropFromMemory(blockId, data)
         if (newEffectiveStorageLevel.isValid) {
@@ -1394,7 +1402,9 @@ private object BlockManagerSuite {
     val getLocalAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.getLocalValues)
     val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get)
     val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle)
-    val getLocalBytesAndReleaseLock: (BlockId) => Option[ByteBuffer] = wrapGet(store.getLocalBytes)
+    val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = {
+      wrapGet(store.getLocalBytes)
+    }
   }
 
 }
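
Throughout the updated suite, ClassTag.Any is passed explicitly because BlockInfo and MemoryStore.putIterator now carry a ClassTag parameter. Outside of tests the tag is normally captured implicitly through a context bound; a small sketch of that pattern follows (the putBlock helper and its body are hypothetical, purely for illustration):

```scala
import scala.reflect.ClassTag

import org.apache.spark.storage.StorageLevel

// Hypothetical helper: a context bound captures the element type's ClassTag implicitly,
// so callers never need to spell out ClassTag.Any the way the test suite does.
def putBlock[T: ClassTag](values: Iterator[T], level: StorageLevel): Unit = {
  val tag = implicitly[ClassTag[T]]
  // A real caller would forward `tag` to MemoryStore.putIterator(blockId, values, level, tag);
  // here we only show which tag was captured.
  println(s"storing block with element type ${tag.runtimeClass.getName} at $level")
}

putBlock(Iterator(1, 2, 3), StorageLevel.MEMORY_ONLY)   // ClassTag[Int] captured
putBlock(Iterator("a", "b"), StorageLevel.MEMORY_ONLY)  // ClassTag[String] captured
```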
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
index 97e74fe706002..9ed5016510d56 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala
@@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer}
 import java.util.Arrays
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.util.io.ChunkedByteBuffer
 
 class DiskStoreSuite extends SparkFunSuite {
 
@@ -29,7 +30,7 @@ class DiskStoreSuite extends SparkFunSuite {
 
     // Create a non-trivial (not all zeros) byte array
     val bytes = Array.tabulate[Byte](1000)(_.toByte)
-    val byteBuffer = ByteBuffer.wrap(bytes)
+    val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes))
 
     val blockId = BlockId("rdd_1_2")
     val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true)
@@ -44,9 +45,10 @@ class DiskStoreSuite extends SparkFunSuite {
     val notMapped = diskStoreNotMapped.getBytes(blockId)
 
     // Not possible to do isInstanceOf due to visibility of HeapByteBuffer
-    assert(notMapped.getClass.getName.endsWith("HeapByteBuffer"),
+    assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")),
       "Expected HeapByteBuffer for un-mapped read")
-    assert(mapped.isInstanceOf[MappedByteBuffer], "Expected MappedByteBuffer for mapped read")
+    assert(mapped.getChunks().forall(_.isInstanceOf[MappedByteBuffer]),
+      "Expected MappedByteBuffer for mapped read")
 
     def arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = {
       val array = new Array[Byte](in.remaining())
@@ -54,9 +56,7 @@ class DiskStoreSuite extends SparkFunSuite {
       array
     }
 
-    val mappedAsArray = arrayFromByteBuffer(mapped)
-    val notMappedAsArray = arrayFromByteBuffer(notMapped)
-    assert(Arrays.equals(mappedAsArray, bytes))
-    assert(Arrays.equals(notMappedAsArray, bytes))
+    assert(Arrays.equals(mapped.toArray, bytes))
+    assert(Arrays.equals(notMapped.toArray, bytes))
   }
 }
diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml
index 9242be3d0357a..a1a88ac8cdac5 100644
--- a/dev/checkstyle-suppressions.xml
+++ b/dev/checkstyle-suppressions.xml
@@ -28,6 +28,12 @@
 -->
 
 
-
+    
+    
+    
+    
 
diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml
index 2261cc95d43ad..b66dca9041f2f 100644
--- a/dev/checkstyle.xml
+++ b/dev/checkstyle.xml
@@ -76,13 +76,10 @@
             
             
         
-        
-        
         
         
             
@@ -167,5 +164,7 @@
         
         
         
+        
+        
     
 
diff --git a/docs/building-spark.md b/docs/building-spark.md
index e478954c6267b..1e202acb9e2cf 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -98,8 +98,11 @@ mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package
 # Apache Hadoop 2.4.X or 2.5.X
 mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package
 
-Versions of Hadoop after 2.5.X may or may not work with the -Phadoop-2.4 profile (they were
-released after this version of Spark).
+# Apache Hadoop 2.6.X
+mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -DskipTests clean package
+
+# Apache Hadoop 2.7.X and later
+mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=VERSION -DskipTests clean package
 
 # Different versions of HDFS and YARN.
 mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests clean package
@@ -140,10 +143,10 @@ It's possible to build Spark sub-modules using the `mvn -pl` option.
 For instance, you can build the Spark Streaming module using:
 
 {% highlight bash %}
-mvn -pl :spark-streaming_2.10 clean install
+mvn -pl :spark-streaming_2.11 clean install
 {% endhighlight %}
 
-where `spark-streaming_2.10` is the `artifactId` as defined in `streaming/pom.xml` file.
+where `spark-streaming_2.11` is the `artifactId` as defined in `streaming/pom.xml` file.
 
 # Continuous Compilation
 
diff --git a/docs/index.md b/docs/index.md
index 9dfc52a2bdc9b..20eab567a50df 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -130,8 +130,8 @@ options for deployment:
 * [StackOverflow tag `apache-spark`](http://stackoverflow.com/questions/tagged/apache-spark)
 * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here
 * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and
-  exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/),
-  [slides](http://ampcamp.berkeley.edu/3/) and [exercises](http://ampcamp.berkeley.edu/3/exercises/) are
+  exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/6/),
+  [slides](http://ampcamp.berkeley.edu/6/) and [exercises](http://ampcamp.berkeley.edu/6/exercises/) are
   available online for free.
 * [Code Examples](http://spark.apache.org/examples.html): more are also available in the `examples` subfolder of Spark ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples),
  [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples),
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 3a832de95f10d..293a82882e412 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -167,8 +167,8 @@ For example:
 ./bin/spark-submit \
   --class org.apache.spark.examples.SparkPi \
   --master mesos://207.184.161.138:7077 \
-  --deploy-mode cluster
-  --supervise
+  --deploy-mode cluster \
+  --supervise \
   --executor-memory 20G \
   --total-executor-cores 100 \
   http://path/to/examples.jar \
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 8045f8c5b8483..c775fe710ffd5 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -49,8 +49,8 @@ In `cluster` mode, the driver runs on a different machine than the client, so `S
     $ ./bin/spark-submit --class my.main.Class \
         --master yarn \
         --deploy-mode cluster \
-        --jars my-other-jar.jar,my-other-other-jar.jar
-        my-main-jar.jar
+        --jars my-other-jar.jar,my-other-other-jar.jar \
+        my-main-jar.jar \
         app_arg1 app_arg2
 
 
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
index c3ef93c5b6325..229d1234414e5 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
@@ -84,13 +84,14 @@ public static void main(String[] args) throws Exception {
     JavaRDD<String> lines = ctx.textFile(args[0], 1);
 
     // Loads all URLs from input file and initialize their neighbors.
-    JavaPairRDD<String, Iterable<String>> links = lines.mapToPair(new PairFunction<String, String, String>() {
-      @Override
-      public Tuple2<String, String> call(String s) {
-        String[] parts = SPACES.split(s);
-        return new Tuple2<>(parts[0], parts[1]);
-      }
-    }).distinct().groupByKey().cache();
+    JavaPairRDD<String, Iterable<String>> links = lines.mapToPair(
+      new PairFunction<String, String, String>() {
+        @Override
+        public Tuple2<String, String> call(String s) {
+          String[] parts = SPACES.split(s);
+          return new Tuple2<>(parts[0], parts[1]);
+        }
+      }).distinct().groupByKey().cache();
 
     // Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one.
     JavaPairRDD<String, Double> ranks = links.mapValues(new Function<Iterable<String>, Double>() {
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
index 84dbea5caa135..3ff5412b934f0 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
@@ -52,19 +52,21 @@ public Iterator<String> call(String s) {
       }
     });
 
-    JavaPairRDD<String, Integer> ones = words.mapToPair(new PairFunction<String, String, Integer>() {
-      @Override
-      public Tuple2<String, Integer> call(String s) {
-        return new Tuple2<>(s, 1);
-      }
-    });
+    JavaPairRDD<String, Integer> ones = words.mapToPair(
+      new PairFunction<String, String, Integer>() {
+        @Override
+        public Tuple2<String, Integer> call(String s) {
+          return new Tuple2<>(s, 1);
+        }
+      });
 
-    JavaPairRDD<String, Integer> counts = ones.reduceByKey(new Function2<Integer, Integer, Integer>() {
-      @Override
-      public Integer call(Integer i1, Integer i2) {
-        return i1 + i2;
-      }
-    });
+    JavaPairRDD<String, Integer> counts = ones.reduceByKey(
+      new Function2<Integer, Integer, Integer>() {
+        @Override
+        public Integer call(Integer i1, Integer i2) {
+          return i1 + i2;
+        }
+      });
 
     List<Tuple2<String, Integer>> output = counts.collect();
     for (Tuple2<?, ?> tuple : output) {
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
index 5bd61fe508bd5..8214952f80695 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java
@@ -39,7 +39,10 @@ public static void main(String[] args) {
 
     // $example on$
     // Load the data stored in LIBSVM format as a DataFrame.
-    Dataset<Row> data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+    Dataset<Row> data = sqlContext
+      .read()
+      .format("libsvm")
+      .load("data/mllib/sample_libsvm_data.txt");
 
     // Index labels, adding metadata to the label column.
     // Fit on whole dataset to include all labels in index.
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 8a10dd48aa72f..fbd881766983f 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -107,11 +107,11 @@ public static void main(String[] args) throws Exception {
 class MyJavaLogisticRegression
   extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
 
-  public MyJavaLogisticRegression() {
+  MyJavaLogisticRegression() {
     init();
   }
 
-  public MyJavaLogisticRegression(String uid) {
+  MyJavaLogisticRegression(String uid) {
     this.uid_ = uid;
     init();
   }
@@ -177,7 +177,7 @@ class MyJavaLogisticRegressionModel
   private Vector coefficients_;
   public Vector coefficients() { return coefficients_; }
 
-  public MyJavaLogisticRegressionModel(String uid, Vector coefficients) {
+  MyJavaLogisticRegressionModel(String uid, Vector coefficients) {
     this.uid_ = uid;
     this.coefficients_ = coefficients;
   }
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java
index c2cb9553858f8..553070dace882 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java
@@ -40,7 +40,8 @@ public static void main(String[] args) {
 
     // $example on$
     // Load and parse the data file, converting it to a DataFrame.
-    Dataset<Row> data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+    Dataset<Row> data = sqlContext.read().format("libsvm")
+      .load("data/mllib/sample_libsvm_data.txt");
 
     // Index labels, adding metadata to the label column.
     // Fit on whole dataset to include all labels in index.
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java
index 3d8babba04a53..7561a1f6535d6 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java
@@ -65,7 +65,8 @@ public Tuple2<Object, Object> call(LabeledPoint p) {
     );
 
     // Get evaluation metrics.
-    BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());
+    BinaryClassificationMetrics metrics =
+      new BinaryClassificationMetrics(predictionAndLabels.rdd());
 
     // Precision by threshold
     JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java
index 0e15f755083bf..c6361a3729988 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java
@@ -48,7 +48,8 @@ public Tuple3<Double, Double, Double> call(String line) {
     );
 
     // Split data into training (60%) and test (40%) sets.
-    JavaRDD<Tuple3<Double, Double, Double>>[] splits = parsedData.randomSplit(new double[]{0.6, 0.4}, 11L);
+    JavaRDD<Tuple3<Double, Double, Double>>[] splits =
+        parsedData.randomSplit(new double[]{0.6, 0.4}, 11L);
     JavaRDD<Tuple3<Double, Double, Double>> training = splits[0];
     JavaRDD<Tuple3<Double, Double, Double>> test = splits[1];
 
@@ -80,7 +81,8 @@ public Object call(Tuple2<Double, Double> pl) {
 
     // Save and load model
     model.save(jsc.sc(), "target/tmp/myIsotonicRegressionModel");
-    IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(jsc.sc(), "target/tmp/myIsotonicRegressionModel");
+    IsotonicRegressionModel sameModel =
+      IsotonicRegressionModel.load(jsc.sc(), "target/tmp/myIsotonicRegressionModel");
     // $example off$
 
     jsc.stop();
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java
index 2197ef9481a79..984909cb947a1 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStreamingTestExample.java
@@ -18,7 +18,6 @@
 package org.apache.spark.examples.mllib;
 
 
-import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.function.VoidFunction;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -56,6 +55,9 @@
  * batches processed exceeds `numBatchesTimeout`.
  */
 public class JavaStreamingTestExample {
+
+  private static int timeoutCounter = 0;
+
   public static void main(String[] args) {
     if (args.length != 3) {
       System.err.println("Usage: JavaStreamingTestExample " +
@@ -76,7 +78,7 @@ public static void main(String[] args) {
     JavaDStream<BinarySample> data = ssc.textFileStream(dataDir).map(
       new Function<String, BinarySample>() {
         @Override
-        public BinarySample call(String line) throws Exception {
+        public BinarySample call(String line) {
           String[] ts = line.split(",");
           boolean label = Boolean.valueOf(ts[0]);
           double value = Double.valueOf(ts[1]);
@@ -94,22 +96,21 @@ public BinarySample call(String line) throws Exception {
     // $example off$
 
     // Stop processing if test becomes significant or we time out
-    final Accumulator<Integer> timeoutCounter =
-      ssc.sparkContext().accumulator(numBatchesTimeout);
+    timeoutCounter = numBatchesTimeout;
 
     out.foreachRDD(new VoidFunction<JavaRDD<StreamingTestResult>>() {
       @Override
-      public void call(JavaRDD<StreamingTestResult> rdd) throws Exception {
-        timeoutCounter.add(-1);
+      public void call(JavaRDD<StreamingTestResult> rdd) {
+        timeoutCounter -= 1;
 
-        long cntSignificant = rdd.filter(new Function<StreamingTestResult, Boolean>() {
+        boolean anySignificant = !rdd.filter(new Function<StreamingTestResult, Boolean>() {
           @Override
-          public Boolean call(StreamingTestResult v) throws Exception {
+          public Boolean call(StreamingTestResult v) {
             return v.pValue() < 0.05;
           }
-        }).count();
+        }).isEmpty();
 
-        if (timeoutCounter.value() <= 0 || cntSignificant > 0) {
+        if (timeoutCounter <= 0 || anySignificant) {
           rdd.context().stop();
         }
       }
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
index bfbad91e4fdfa..769b21cecfb80 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
@@ -40,7 +40,8 @@
  *   <topics> is a list of one or more kafka topics to consume from
  *
  * Example:
- *    $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2
+ *    $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port \
+ *      topic1,topic2
  */
 
 public final class JavaDirectKafkaWordCount {
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java
index 426eaa5f0adea..62413b4606ff2 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaQueueStream.java
@@ -30,7 +30,6 @@
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.examples.streaming.StreamingExamples;
 import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
index a597ecbc5bcb3..e5fb2bfbfae7b 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
@@ -155,9 +155,11 @@ public Integer call(Integer i1, Integer i2) {
       @Override
       public void call(JavaPairRDD<String, Integer> rdd, Time time) throws IOException {
         // Get or register the blacklist Broadcast
-        final Broadcast<List<String>> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context()));
+        final Broadcast<List<String>> blacklist =
+            JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context()));
         // Get or register the droppedWordsCounter Accumulator
-        final Accumulator<Integer> droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context()));
+        final Accumulator<Integer> droppedWordsCounter =
+            JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context()));
         // Use blacklist to drop words and use droppedWordsCounter to count them
         String counts = rdd.filter(new Function<Tuple2<String, Integer>, Boolean>() {
           @Override
@@ -210,7 +212,8 @@ public JavaStreamingContext call() {
       }
     };
 
-    JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, createContextFunc);
+    JavaStreamingContext ssc =
+      JavaStreamingContext.getOrCreate(checkpointDirectory, createContextFunc);
     ssc.start();
     ssc.awaitTermination();
   }
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index 6beab90f086d8..4230dab52e5d4 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -91,7 +91,8 @@ public Tuple2<String, Integer> call(String s) {
     Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
         new Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>>() {
           @Override
-          public Tuple2<String, Integer> call(String word, Optional<Integer> one, State<Integer> state) {
+          public Tuple2<String, Integer> call(String word, Optional<Integer> one,
+              State<Integer> state) {
             int sum = one.orElse(0) + (state.exists() ? state.get() : 0);
             Tuple2<String, Integer> output = new Tuple2<>(word, sum);
             state.update(sum);
diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
index 5dc825dfdc911..0e43e9272d7c3 100644
--- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
+++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
@@ -140,7 +140,8 @@ public static void main(String[] args) {
     for (int i = 0; i < numStreams; i++) {
       streamsList.add(
           KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
-              InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2())
+              InitialPositionInStream.LATEST, kinesisCheckpointInterval,
+              StorageLevel.MEMORY_AND_DISK_2())
       );
     }
 
@@ -167,7 +168,7 @@ public Iterator<String> call(byte[] line) {
         new PairFunction<String, String, Integer>() {
           @Override
           public Tuple2<String, Integer> call(String s) {
-            return new Tuple2(s, 1);
+            return new Tuple2<>(s, 1);
           }
         }
     ).reduceByKey(
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
index 15ac588b82587..a0007d33d6257 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
@@ -221,51 +221,6 @@ object KinesisUtils {
     }
   }
 
-  /**
-   * Create an input stream that pulls messages from a Kinesis stream.
-   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
-   *
-   * Note:
-   *
-   *  - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
-   *    on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
-   *    gets AWS credentials.
-   *  - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch.
-   *  - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name
-   *    in [[org.apache.spark.SparkConf]].
-   *
-   * @param ssc StreamingContext object
-   * @param streamName   Kinesis stream name
-   * @param endpointUrl  Endpoint url of Kinesis service
-   *                     (e.g., https://kinesis.us-east-1.amazonaws.com)
-   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
-   *                            See the Kinesis Spark Streaming documentation for more
-   *                            details on the different types of checkpoints.
-   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
-   *                                 worker's initial starting position in the stream.
-   *                                 The values are either the beginning of the stream
-   *                                 per Kinesis' limit of 24 hours
-   *                                 (InitialPositionInStream.TRIM_HORIZON) or
-   *                                 the tip of the stream (InitialPositionInStream.LATEST).
-   * @param storageLevel Storage level to use for storing the received objects
-   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
-   */
-  @deprecated("use other forms of createStream", "1.4.0")
-  def createStream(
-      ssc: StreamingContext,
-      streamName: String,
-      endpointUrl: String,
-      checkpointInterval: Duration,
-      initialPositionInStream: InitialPositionInStream,
-      storageLevel: StorageLevel
-    ): ReceiverInputDStream[Array[Byte]] = {
-    ssc.withNamedScope("kinesis stream") {
-      new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl,
-        getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName,
-        checkpointInterval, storageLevel, defaultMessageHandler, None)
-    }
-  }
-
   /**
    * Create an input stream that pulls messages from a Kinesis stream.
    * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
@@ -453,47 +408,6 @@ object KinesisUtils {
       defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
   }
 
-  /**
-   * Create an input stream that pulls messages from a Kinesis stream.
-   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
-   *
-   * Note:
-   * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
-   *   on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
-   *   gets AWS credentials.
-   * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch.
-   * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in
-   *   [[org.apache.spark.SparkConf]].
-   *
-   * @param jssc Java StreamingContext object
-   * @param streamName   Kinesis stream name
-   * @param endpointUrl  Endpoint url of Kinesis service
-   *                     (e.g., https://kinesis.us-east-1.amazonaws.com)
-   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
-   *                            See the Kinesis Spark Streaming documentation for more
-   *                            details on the different types of checkpoints.
-   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
-   *                                 worker's initial starting position in the stream.
-   *                                 The values are either the beginning of the stream
-   *                                 per Kinesis' limit of 24 hours
-   *                                 (InitialPositionInStream.TRIM_HORIZON) or
-   *                                 the tip of the stream (InitialPositionInStream.LATEST).
-   * @param storageLevel Storage level to use for storing the received objects
-   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
-   */
-  @deprecated("use other forms of createStream", "1.4.0")
-  def createStream(
-      jssc: JavaStreamingContext,
-      streamName: String,
-      endpointUrl: String,
-      checkpointInterval: Duration,
-      initialPositionInStream: InitialPositionInStream,
-      storageLevel: StorageLevel
-    ): JavaReceiverInputDStream[Array[Byte]] = {
-    createStream(
-      jssc.ssc, streamName, endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)
-  }
-
   private def getRegionByEndpoint(endpointUrl: String): String = {
     RegionUtils.getRegionByEndpoint(endpointUrl).getName()
   }
diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
index 5c2371c5430b3..f078973c6c285 100644
--- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
+++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.streaming.kinesis;
 
+import com.amazonaws.regions.RegionUtils;
 import com.amazonaws.services.kinesis.model.Record;
 import org.junit.Test;
 
@@ -34,11 +35,13 @@
 public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
   @Test
   public void testKinesisStream() {
-    // Tests the API, does not actually test data receiving
-    JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream",
-        "https://kinesis.us-west-2.amazonaws.com", new Duration(2000),
-        InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2());
+    String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl();
+    String dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName();
 
+    // Tests the API, does not actually test data receiving
+    JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
+        dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000),
+        StorageLevel.MEMORY_AND_DISK_2());
     ssc.stop();
   }
 
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
index ee428f31d6ce3..1c81298a7c201 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
@@ -40,7 +40,7 @@ trait KinesisFunSuite extends SparkFunSuite  {
     if (shouldRunTests) {
       body
     } else {
-      ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")()
+      ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")(())
     }
   }
 }
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 4460b6bccaa81..0e71bf9b84332 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -99,14 +99,10 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
   }
 
   test("KinesisUtils API") {
-    // Tests the API, does not actually test data receiving
-    val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream",
-      dummyEndpointUrl, Seconds(2),
-      InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2)
-    val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
+    val kinesisStream1 = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
       dummyEndpointUrl, dummyRegionName,
       InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2)
-    val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
+    val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
       dummyEndpointUrl, dummyRegionName,
       InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2,
       dummyAWSAccessKey, dummyAWSSecretKey)
@@ -154,7 +150,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
 
     // Verify that KinesisBackedBlockRDD is generated even when there are no blocks
     val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty)
-    emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
+    // Verify it's KinesisBackedBlockRDD[_] rather than KinesisBackedBlockRDD[Array[Byte]], because
+    // the type parameter will be erased at runtime
+    emptyRDD shouldBe a [KinesisBackedBlockRDD[_]]
     emptyRDD.partitions shouldBe empty
 
     // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index f6c7e07654ee9..587fda7a3c1da 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -57,7 +57,7 @@ abstract class AbstractCommandBuilder {
   // properties files multiple times.
   private Map<String, String> effectiveConfig;
 
-  public AbstractCommandBuilder() {
+  AbstractCommandBuilder() {
     this.appArgs = new ArrayList<>();
     this.childEnv = new HashMap<>();
     this.conf = new HashMap<>();
diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
index 37afafea28fdc..39fdf300e26cd 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
@@ -32,7 +32,7 @@ class CommandBuilderUtils {
   static final String ENV_SPARK_HOME = "SPARK_HOME";
 
   /** The set of known JVM vendors. */
-  static enum JavaVendor {
+  enum JavaVendor {
     Oracle, IBM, OpenJDK, Unknown
   };
 
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java
index e9caf0b3cb063..625d02632114a 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java
@@ -32,7 +32,7 @@ public interface SparkAppHandle {
    *
    * @since 1.6.0
    */
-  public enum State {
+  enum State {
     /** The application has not reported back yet. */
     UNKNOWN(false),
     /** The application has connected to the handle. */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 77e59d9188ef4..861b1d4b66f23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -509,12 +509,8 @@ class LogisticRegressionModel private[spark] (
    * thrown if `trainingSummary == None`.
    */
   @Since("1.5.0")
-  def summary: LogisticRegressionTrainingSummary = trainingSummary match {
-    case Some(summ) => summ
-    case None =>
-      throw new SparkException(
-        "No training summary available for this LogisticRegressionModel",
-        new NullPointerException())
+  def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
+    throw new SparkException("No training summary available for this LogisticRegressionModel")
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index ab00127899edf..38428826a8a7d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -148,12 +148,9 @@ class KMeansModel private[ml] (
    * thrown if `trainingSummary == None`.
    */
   @Since("2.0.0")
-  def summary: KMeansSummary = trainingSummary match {
-    case Some(summ) => summ
-    case None =>
-      throw new SparkException(
-        s"No training summary available for the ${this.getClass.getSimpleName}",
-        new NullPointerException())
+  def summary: KMeansSummary = trainingSummary.getOrElse {
+    throw new SparkException(
+      s"No training summary available for the ${this.getClass.getSimpleName}")
   }
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index a3845d39777a4..5694b3890fba4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -207,13 +207,12 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
   def setMinTF(value: Double): this.type = set(minTF, value)
 
   /**
-    * Binary toggle to control the output vector values.
-    * If True, all non zero counts are set to 1. This is useful for discrete probabilistic
-    * models that model binary events rather than integer counts
-    *
-    * Default: false
-    * @group param
-    */
+   * Binary toggle to control the output vector values.
+   * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
+   * discrete probabilistic models that model binary events rather than integer counts.
+   * Default: false
+   * @group param
+   */
   val binary: BooleanParam =
     new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
       "This is useful for discrete probabilistic models that model binary events rather " +
@@ -248,17 +247,13 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
         }
         tokenCount += 1
       }
-      val effectiveMinTF = if (minTf >= 1.0) {
-        minTf
-      } else {
-        tokenCount * minTf
-      }
+      val effectiveMinTF = if (minTf >= 1.0) minTf else tokenCount * minTf
       val effectiveCounts = if ($(binary)) {
         termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq
-      }
-      else {
+      } else {
         termCounts.filter(_._2 >= effectiveMinTF).toSeq
       }
+
       Vectors.sparse(dictBr.value.size, effectiveCounts)
     }
     dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
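
The reworked scaladoc above describes the binary toggle on CountVectorizerModel. A short sketch of its effect, assuming the usual setBinary setter, the vocabulary-only constructor, and an existing sqlContext (none of which appear in this hunk):

```scala
import org.apache.spark.ml.feature.CountVectorizerModel

// Assumed setup: sqlContext is an existing SQLContext; column names are illustrative.
val df = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "a", "b")),
  (1, Seq("b", "c", "c", "c"))
)).toDF("id", "words")

val cvm = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")
  .setBinary(true)  // assumed setter for the `binary` param documented above

// With binary = true, every term that survives the minTF filter contributes 1.0,
// so row 1 becomes (0.0, 1.0, 1.0) instead of the raw counts (0.0, 1.0, 3.0).
cvm.transform(df).show()
```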
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index ab5f4a1a9a6c4..e7ca7ada74c8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types._
@@ -68,7 +70,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
  * will be created from the specified response variable in the formula.
  */
 @Experimental
-class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
+class RFormula(override val uid: String)
+  extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("rFormula"))
 
@@ -180,6 +183,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
   override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)"
 }
 
+@Since("2.0.0")
+object RFormula extends DefaultParamsReadable[RFormula] {
+
+  @Since("2.0.0")
+  override def load(path: String): RFormula = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * A fitted RFormula. Fitting is required to determine the factor levels of formula terms.
@@ -189,9 +199,9 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 @Experimental
 class RFormulaModel private[feature](
     override val uid: String,
-    resolvedFormula: ResolvedRFormula,
-    pipelineModel: PipelineModel)
-  extends Model[RFormulaModel] with RFormulaBase {
+    private[ml] val resolvedFormula: ResolvedRFormula,
+    private[ml] val pipelineModel: PipelineModel)
+  extends Model[RFormulaModel] with RFormulaBase with MLWritable {
 
   override def transform(dataset: DataFrame): DataFrame = {
     checkCanTransform(dataset.schema)
@@ -246,14 +256,71 @@ class RFormulaModel private[feature](
       !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
       "Label column already exists and is not of type DoubleType.")
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new RFormulaModel.RFormulaModelWriter(this)
+}
+
+@Since("2.0.0")
+object RFormulaModel extends MLReadable[RFormulaModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[RFormulaModel] = new RFormulaModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): RFormulaModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[RFormulaModel]] */
+  private[RFormulaModel] class RFormulaModelWriter(instance: RFormulaModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: resolvedFormula
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(instance.resolvedFormula))
+        .repartition(1).write.parquet(dataPath)
+      // Save pipeline model
+      val pmPath = new Path(path, "pipelineModel").toString
+      instance.pipelineModel.save(pmPath)
+    }
+  }
+
+  private class RFormulaModelReader extends MLReader[RFormulaModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[RFormulaModel].getName
+
+    override def load(path: String): RFormulaModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
+      val label = data.getString(0)
+      val terms = data.getAs[Seq[Seq[String]]](1)
+      val hasIntercept = data.getBoolean(2)
+      val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept)
+
+      val pmPath = new Path(path, "pipelineModel").toString
+      val pipelineModel = PipelineModel.load(pmPath)
+
+      val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**
  * Utility transformer for removing temporary columns from a DataFrame.
  * TODO(ekl) make this a public transformer
  */
-private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
-  override val uid = Identifiable.randomUID("columnPruner")
+private class ColumnPruner(override val uid: String, val columnsToPrune: Set[String])
+  extends Transformer with MLWritable {
+
+  def this(columnsToPrune: Set[String]) =
+    this(Identifiable.randomUID("columnPruner"), columnsToPrune)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
@@ -265,6 +332,48 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
+
+  override def write: MLWriter = new ColumnPruner.ColumnPrunerWriter(this)
+}
+
+private object ColumnPruner extends MLReadable[ColumnPruner] {
+
+  override def read: MLReader[ColumnPruner] = new ColumnPrunerReader
+
+  override def load(path: String): ColumnPruner = super.load(path)
+
+  /** [[MLWriter]] instance for [[ColumnPruner]] */
+  private[ColumnPruner] class ColumnPrunerWriter(instance: ColumnPruner) extends MLWriter {
+
+    private case class Data(columnsToPrune: Seq[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: columnsToPrune
+      val data = Data(instance.columnsToPrune.toSeq)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class ColumnPrunerReader extends MLReader[ColumnPruner] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[ColumnPruner].getName
+
+    override def load(path: String): ColumnPruner = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head()
+      val columnsToPrune = data.getAs[Seq[String]](0).toSet
+      val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
+
+      DefaultParamsReader.getAndSetParams(pruner, metadata)
+      pruner
+    }
+  }
 }
 
 /**
@@ -278,11 +387,13 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
  *                          by the value in the map.
  */
 private class VectorAttributeRewriter(
-    vectorCol: String,
-    prefixesToRewrite: Map[String, String])
-  extends Transformer {
+    override val uid: String,
+    val vectorCol: String,
+    val prefixesToRewrite: Map[String, String])
+  extends Transformer with MLWritable {
 
-  override val uid = Identifiable.randomUID("vectorAttrRewriter")
+  def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
+    this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
 
   override def transform(dataset: DataFrame): DataFrame = {
     val metadata = {
@@ -315,4 +426,48 @@ private class VectorAttributeRewriter(
   }
 
   override def copy(extra: ParamMap): VectorAttributeRewriter = defaultCopy(extra)
+
+  override def write: MLWriter = new VectorAttributeRewriter.VectorAttributeRewriterWriter(this)
+}
+
+private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewriter] {
+
+  override def read: MLReader[VectorAttributeRewriter] = new VectorAttributeRewriterReader
+
+  override def load(path: String): VectorAttributeRewriter = super.load(path)
+
+  /** [[MLWriter]] instance for [[VectorAttributeRewriter]] */
+  private[VectorAttributeRewriter]
+  class VectorAttributeRewriterWriter(instance: VectorAttributeRewriter) extends MLWriter {
+
+    private case class Data(vectorCol: String, prefixesToRewrite: Map[String, String])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: vectorCol, prefixesToRewrite
+      val data = Data(instance.vectorCol, instance.prefixesToRewrite)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class VectorAttributeRewriterReader extends MLReader[VectorAttributeRewriter] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[VectorAttributeRewriter].getName
+
+    override def load(path: String): VectorAttributeRewriter = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head()
+      val vectorCol = data.getString(0)
+      val prefixesToRewrite = data.getAs[Map[String, String]](1)
+      val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
+
+      DefaultParamsReader.getAndSetParams(rewriter, metadata)
+      rewriter
+    }
+  }
 }
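
A hedged usage sketch of the persistence support added above, with an illustrative save path: fitting an `RFormulaModel` and round-tripping it through `save`/`load` exercises the metadata, resolved-formula data, and nested `pipelineModel` directories written by `RFormulaModelWriter`.

```scala
import org.apache.spark.ml.feature.{RFormula, RFormulaModel}

// Toy frame shaped like the one in RFormulaSuite; assumes a sqlContext is in scope.
val df = sqlContext.createDataFrame(
  Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
).toDF("id", "a", "b")

val model: RFormulaModel = new RFormula().setFormula("id ~ a:b").fit(df)

// Writes <path>/metadata, <path>/data (the resolved formula) and <path>/pipelineModel.
val path = "/tmp/rformula-model"   // illustrative location
model.write.overwrite().save(path)

val loaded = RFormulaModel.load(path)
assert(loaded.uid == model.uid)
loaded.transform(df).show()
```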
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 6e74cb54ad682..0e71e8d8e1339 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -681,8 +681,7 @@ class GeneralizedLinearRegressionModel private[ml] (
   @Since("2.0.0")
   def summary: GeneralizedLinearRegressionSummary = trainingSummary.getOrElse {
     throw new SparkException(
-      "No training summary available for this GeneralizedLinearRegressionModel",
-      new RuntimeException())
+      "No training summary available for this GeneralizedLinearRegressionModel")
   }
 
   private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index c8f3f70a9b446..b81c588e44fcc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -398,12 +398,8 @@ class LinearRegressionModel private[ml] (
    * thrown if `trainingSummary == None`.
    */
   @Since("1.5.0")
-  def summary: LinearRegressionTrainingSummary = trainingSummary match {
-    case Some(summ) => summ
-    case None =>
-      throw new SparkException(
-        "No training summary available for this LinearRegressionModel",
-        new NullPointerException())
+  def summary: LinearRegressionTrainingSummary = trainingSummary.getOrElse {
+    throw new SparkException("No training summary available for this LinearRegressionModel")
   }
 
   private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
index cf189e8e96f95..be356575ca09a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -22,15 +22,12 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 
 
 /**
- * :: DeveloperApi ::
- *
  * Single-label regression
  *
  * @tparam FeaturesType  Type of input features.  E.g., [[org.apache.spark.mllib.linalg.Vector]]
  * @tparam Learner  Concrete Estimator type
  * @tparam M  Concrete Model type
  */
-@DeveloperApi
 private[spark] abstract class Regressor[
     FeaturesType,
     Learner <: Regressor[FeaturesType, Learner, M],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
index 526b9c40628a3..2c8286766f3bf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -171,7 +171,6 @@ private[spark] class NodeIdCache(
   }
 }
 
-@DeveloperApi
 private[spark] object NodeIdCache {
   /**
    * Initialize the node Id cache with initial node Id values.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 91dc98569a21b..afbb9d974d42a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         val featureSplits = splits(featureIndex)
@@ -253,8 +252,6 @@ private[ml] object RandomForest extends Logging {
         while (splitIndex < numSplits) {
           if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
           }
           splitIndex += 1
         }
@@ -394,6 +391,7 @@ private[ml] object RandomForest extends Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
 
@@ -479,8 +477,8 @@ private[ml] object RandomForest extends Logging {
         // Construct a nodeStatsAggregators array to hold node aggregate stats,
         // each node will have a nodeStatsAggregator
         val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
-          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
-            Some(nodeToFeatures(nodeIndex))
+          val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+            nodeToFeatures(nodeIndex)
           }
           new DTStatsAggregator(metadata, featuresForNode)
         }
@@ -658,7 +656,7 @@ private[ml] object RandomForest extends Logging {
 
     // Calculate InformationGain and ImpurityStats if current node is top node
     val level = LearningNode.indexToLevel(node.id)
-    var gainAndImpurityStats: ImpurityStats = if (level ==0) {
+    var gainAndImpurityStats: ImpurityStats = if (level == 0) {
       null
     } else {
       node.stats
@@ -697,13 +695,12 @@ private[ml] object RandomForest extends Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
                 leftChildStats, rightChildStats, binAggregates.metadata)
               (splitIndex, gainAndImpurityStats)
@@ -830,8 +827,8 @@ private[ml] object RandomForest extends Logging {
     val numFeatures = metadata.numFeatures
 
     // Sample the input only if there are continuous features.
-    val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
-    val sampledInput = if (hasContinuousFeatures) {
+    val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+    val sampledInput = if (continuousFeatures.nonEmpty) {
       // Calculate the number of samples for approximate quantile calculation.
       val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
       val fraction = if (requiredSamples < metadata.numExamples) {
@@ -840,58 +837,57 @@ private[ml] object RandomForest extends Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
     } else {
-      new Array[LabeledPoint](0)
+      input.sparkContext.emptyRDD[LabeledPoint]
     }
 
-    val splits = new Array[Array[Split]](numFeatures)
-
-    // Find all splits.
-    // Iterate over all features.
-    var featureIndex = 0
-    while (featureIndex < numFeatures) {
-      if (metadata.isContinuous(featureIndex)) {
-        val featureSamples = sampledInput.map(_.features(featureIndex))
-        val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
+    findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
+  }
 
-        val numSplits = featureSplits.length
-        logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
-        splits(featureIndex) = new Array[Split](numSplits)
+  private def findSplitsBinsBySorting(
+      input: RDD[LabeledPoint],
+      metadata: DecisionTreeMetadata,
+      continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+    val continuousSplits: scala.collection.Map[Int, Array[Split]] = {
+      // Reduce the parallelism for split computations when there are fewer
+      // continuous features than input partitions. This prevents tasks from

+      // being spun up that will definitely do no work.
+      val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+      input
+        .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
+        .groupByKey(numPartitions)
+        .map { case (idx, samples) =>
+          val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
+          val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+          logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+          (idx, splits)
+        }.collectAsMap()
+    }
 
-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          val threshold = featureSplits(splitIndex)
-          splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
-          splitIndex += 1
-        }
-      } else {
-        // Categorical feature
-        if (metadata.isUnordered(featureIndex)) {
-          val numSplits = metadata.numSplits(featureIndex)
-          val featureArity = metadata.featureArity(featureIndex)
-          // TODO: Use an implicit representation mapping each category to a subset of indices.
-          //       I.e., track indices such that we can calculate the set of bins for which
-          //       feature value x splits to the left.
-          // Unordered features
-          // 2^(maxFeatureValue - 1) - 1 combinations
-          splits(featureIndex) = new Array[Split](numSplits)
-          var splitIndex = 0
-          while (splitIndex < numSplits) {
-            val categories: List[Double] =
-              extractMultiClassCategories(splitIndex + 1, featureArity)
-            splits(featureIndex)(splitIndex) =
-              new CategoricalSplit(featureIndex, categories.toArray, featureArity)
-            splitIndex += 1
-          }
-        } else {
-          // Ordered features
-          //   Bins correspond to feature values, so we do not need to compute splits or bins
-          //   beforehand.  Splits are constructed as needed during training.
-          splits(featureIndex) = new Array[Split](0)
+    val numFeatures = metadata.numFeatures
+    val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+      case i if metadata.isContinuous(i) =>
+        val split = continuousSplits(i)
+        metadata.setNumSplits(i, split.length)
+        split
+
+      case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+        // Unordered features
+        // 2^(maxFeatureValue - 1) - 1 combinations
+        val featureArity = metadata.featureArity(i)
+        Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+          val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+          new CategoricalSplit(i, categories.toArray, featureArity)
         }
-      }
-      featureIndex += 1
+
+      case i if metadata.isCategorical(i) =>
+        // Ordered features
+        //   Bins correspond to feature values, so we do not need to compute splits or bins
+        //   beforehand.  Splits are constructed as needed during training.
+        Array.empty[Split]
     }
     splits
   }
@@ -933,7 +929,7 @@ private[ml] object RandomForest extends Logging {
    * @return array of splits
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Array[Double],
+      featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -943,8 +939,9 @@ private[ml] object RandomForest extends Logging {
       val numSplits = metadata.numSplits(featureIndex)
 
       // get count for each distinct value
-      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-        m + ((x, m.getOrElse(x, 0) + 1))
+      val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+        case ((m, cnt), x) =>
+          (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
       }
       // sort distinct values
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
@@ -955,7 +952,7 @@ private[ml] object RandomForest extends Logging {
         valueCounts.map(_._1)
       } else {
         // stride between splits
-        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        val stride: Double = numSamples.toDouble / (numSplits + 1)
         logDebug("stride = " + stride)
 
         // iterate `valueCount` to find splits
@@ -991,8 +988,6 @@ private[ml] object RandomForest extends Logging {
     assert(splits.length > 0,
       s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
         "  Please remove this feature and then try again.")
-    // set number of splits accordingly
-    metadata.setNumSplits(featureIndex, splits.length)
 
     splits
   }
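
To make the new sorting-based split finding concrete, here is a self-contained sketch (the helper name `candidateThresholds` is illustrative, not part of the patch) of the per-feature threshold selection that `findSplitsForContinuousFeature` now performs on an `Iterable` of samples: count each distinct value in one pass, then walk the sorted counts and emit a boundary roughly every `stride` samples. The distributed driver, shown as comments at the end, keys values by feature index and caps `groupByKey` parallelism at the number of continuous features, as the comment in `findSplitsBinsBySorting` explains.

```scala
// Local sketch of the stride-based threshold selection.
def candidateThresholds(featureSamples: Iterable[Double], numSplits: Int): Array[Double] = {
  // One pass: count of each distinct value plus the total number of samples.
  val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
    case ((m, cnt), x) => (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
  }
  val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
  if (valueCounts.length <= numSplits) {
    // Few enough distinct values: every one can serve as a threshold.
    valueCounts.map(_._1)
  } else {
    val stride = numSamples.toDouble / (numSplits + 1)
    val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
    var index = 1
    var currentCount = valueCounts(0)._2
    var targetCount = stride
    while (index < valueCounts.length) {
      val previousCount = currentCount
      currentCount += valueCounts(index)._2
      // Emit the previous value as a threshold when it is the closest to the running target.
      if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
        splits += valueCounts(index - 1)._1
        targetCount += stride
      }
      index += 1
    }
    splits.toArray
  }
}

val thresholds = candidateThresholds(Seq(1.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0), numSplits = 3)

// Distributed driver (sketch only, assuming pairs: RDD[(Int, Double)] of (featureIndex, value)):
//   val numPartitions = math.min(numContinuousFeatures, pairs.partitions.length)
//   val splitsByFeature = pairs.groupByKey(numPartitions)
//     .mapValues(samples => candidateThresholds(samples, numSplits))
//     .collectAsMap()
```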
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 010e7d2686571..963f81cb3ec39 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,19 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   }
 
   @Since("1.4.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): CrossValidator = {
@@ -221,10 +209,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
           // TODO: SPARK-11892: This case may require special handling.
           throw new UnsupportedOperationException("CrossValidator write will fail because it" +
             " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
-        case rform: RFormulaModel =>
-          // TODO: SPARK-11891: This case may require special handling.
-          throw new UnsupportedOperationException("CrossValidator write will fail because it" +
-            " cannot yet handle an estimator containing an RFormulaModel")
+        case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
         case _: Params => Array()
       }
       val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
@@ -334,11 +319,6 @@ class CrossValidatorModel private[ml] (
     @Since("1.5.0") val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
 
-  @Since("1.4.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.4.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -347,7 +327,6 @@ class CrossValidatorModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 4587e259e8bf7..70fa5f0234753 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -117,19 +117,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   }
 
   @Since("1.5.0")
-  override def transformSchema(schema: StructType): StructType = {
-    validateParams()
-    $(estimator).transformSchema(schema)
-  }
-
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    super.validateParams()
-    val est = $(estimator)
-    for (paramMap <- $(estimatorParamMaps)) {
-      est.copy(paramMap).validateParams()
-    }
-  }
+  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): TrainValidationSplit = {
@@ -160,11 +148,6 @@ class TrainValidationSplitModel private[ml] (
     @Since("1.5.0") val validationMetrics: Array[Double])
   extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
 
-  @Since("1.5.0")
-  override def validateParams(): Unit = {
-    bestModel.validateParams()
-  }
-
   @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -173,7 +156,6 @@ class TrainValidationSplitModel private[ml] (
 
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
-    validateParams()
     bestModel.transformSchema(schema)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 553f254172410..953456e8f0dca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -17,20 +17,19 @@
 
 package org.apache.spark.ml.tuning
 
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.Estimator
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.sql.types.StructType
 
 /**
- * :: DeveloperApi ::
  * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
  */
-@DeveloperApi
 private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the estimator to be validated
+   *
    * @group param
    */
   val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
@@ -40,6 +39,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for estimator param maps
+   *
    * @group param
    */
   val estimatorParamMaps: Param[Array[ParamMap]] =
@@ -50,6 +50,7 @@ private[ml] trait ValidatorParams extends Params {
 
   /**
    * param for the evaluator used to select hyper-parameters that maximize the validated metric
+   *
    * @group param
    */
   val evaluator: Param[Evaluator] = new Param(this, "evaluator",
@@ -57,4 +58,14 @@ private[ml] trait ValidatorParams extends Params {
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
+
+  protected def transformSchemaImpl(schema: StructType): StructType = {
+    require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
+    val firstEstimatorParamMap = $(estimatorParamMaps).head
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps).tail) {
+      est.copy(paramMap).transformSchema(schema)
+    }
+    est.copy(firstEstimatorParamMap).transformSchema(schema)
+  }
 }
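
A hedged sketch of the effect of `transformSchemaImpl` (estimator, grid values and schema below are illustrative): every candidate `ParamMap` is applied to a copy of the estimator and its `transformSchema` is invoked, so a bad grid now fails at schema-checking time rather than mid-fit, replacing the removed `validateParams` hook.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.mllib.linalg.Vectors

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.0, 0.1))
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setTrainRatio(0.8)

// Build a schema with the default "label"/"features" columns; assumes a sqlContext is in scope.
val schema = sqlContext.createDataFrame(Seq((1.0, Vectors.dense(0.0, 1.0))))
  .toDF("label", "features").schema

// Applies each ParamMap in the grid to a copy of lr and calls its transformSchema.
tvs.transformSchema(schema)

// An empty grid now fails fast here instead of at fit time:
// tvs.setEstimatorParamMaps(Array.empty[ParamMap]).transformSchema(schema)  // IllegalArgumentException
```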
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 157f2dbf5d7a0..c6de7751f57f4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -156,7 +156,6 @@ sealed trait Matrix extends Serializable {
   def numActives: Int
 }
 
-@DeveloperApi
 private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
 
   override def sqlType: StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 18f66e65f19ca..8f02e098acc30 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
 
   /**
    * Method to train a decision tree model over an RDD
+   *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return DecisionTreeModel that can be used for prediction.
    */
@@ -368,8 +369,7 @@ object DecisionTree extends Serializable with Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
@@ -377,9 +377,6 @@ object DecisionTree extends Serializable with Logging {
           if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
               instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
-              instanceWeight)
           }
           splitIndex += 1
         }
@@ -521,6 +518,7 @@ object DecisionTree extends Serializable with Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
 
@@ -847,13 +845,12 @@ object DecisionTree extends Serializable with Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               predictWithImpurity = Some(predictWithImpurity.getOrElse(
                 calculatePredictImpurity(leftChildStats, rightChildStats)))
               val gainStats = calculateGainForSplit(leftChildStats,
@@ -1013,7 +1010,7 @@ object DecisionTree extends Serializable with Logging {
         featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
       val splits = {
         val featureSplits = findSplitsForContinuousFeature(
-          featureSamples.toArray,
+          featureSamples,
           metadata,
           featureIndex)
         logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
@@ -1118,7 +1115,7 @@ object DecisionTree extends Serializable with Logging {
    * @return Array of splits.
    */
   private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Array[Double],
+      featureSamples: Iterable[Double],
       metadata: DecisionTreeMetadata,
       featureIndex: Int): Array[Double] = {
     require(metadata.isContinuous(featureIndex),
@@ -1128,8 +1125,9 @@ object DecisionTree extends Serializable with Logging {
       val numSplits = metadata.numSplits(featureIndex)
 
       // get count for each distinct value
-      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-        m + ((x, m.getOrElse(x, 0) + 1))
+      val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+        case ((m, cnt), x) =>
+          (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
       }
       // sort distinct values
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
@@ -1140,7 +1138,7 @@ object DecisionTree extends Serializable with Logging {
         valueCounts.map(_._1)
       } else {
         // stride between splits
-        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        val stride: Double = numSamples.toDouble / (numSplits + 1)
         logDebug("stride = " + stride)
 
         // iterate `valueCount` to find splits
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 7985ed4b4c0fa..c745e9f8dbed5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -73,25 +73,33 @@ private[spark] class DTStatsAggregator(
    * Flat array of elements.
    * Index for start of stats for a (feature, bin) is:
    *   index = featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features,
-   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
-   *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
    */
   private val allStats: Array[Double] = new Array[Double](allStatsSize)
 
+  /**
+   * Array of parent node sufficient stats.
+   *
+   * Note: this is necessary because stats for the parent node are not available
+   *       on the first iteration of tree learning.
+   */
+  private val parentStats: Array[Double] = new Array[Double](statsSize)
 
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
-   * @param featureOffset  For ordered features, this is a pre-computed (node, feature) offset
+   * @param featureOffset  This is a pre-computed (node, feature) offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    */
   def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
     impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
   }
 
+  /**
+   * Get an [[ImpurityCalculator]] for the parent node.
+   */
+  def getParentImpurityCalculator(): ImpurityCalculator = {
+    impurityAggregator.getCalculator(parentStats, 0)
+  }
+
   /**
    * Update the stats for a given (feature, bin) for ordered features, using the given label.
    */
@@ -100,14 +108,18 @@ private[spark] class DTStatsAggregator(
     impurityAggregator.update(allStats, i, label, instanceWeight)
   }
 
+  /**
+   * Update the parent node stats using the given label.
+   */
+  def updateParent(label: Double, instanceWeight: Double): Unit = {
+    impurityAggregator.update(parentStats, 0, label, instanceWeight)
+  }
+
   /**
    * Faster version of [[update]].
    * Update the stats for a given (feature, bin), using the given label.
-   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   * @param featureOffset  This is a pre-computed feature offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    */
   def featureUpdate(
       featureOffset: Int,
@@ -124,22 +136,10 @@ private[spark] class DTStatsAggregator(
    */
   def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
 
-  /**
-   * Pre-compute feature offset for use with [[featureUpdate]].
-   * For unordered features only.
-   */
-  def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
-    val baseOffset = featureOffsets(featureIndex)
-    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
-  }
-
   /**
    * For a given feature, merge the stats for two bins.
-   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   * @param featureOffset  This is a pre-computed feature offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    * @param binIndex  The other bin is merged into this bin.
    * @param otherBinIndex  This bin is not modified.
    */
@@ -162,6 +162,17 @@ private[spark] class DTStatsAggregator(
       allStats(i) += other.allStats(i)
       i += 1
     }
+
+    require(statsSize == other.statsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
+        s"stats vectors. This aggregator's parent stats are length $statsSize, " +
+        s"but the other is ${other.statsSize}.")
+    var j = 0
+    while (j < statsSize) {
+      parentStats(j) += other.parentStats(j)
+      j += 1
+    }
+
     this
   }
 }
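
A small numeric sketch in plain Scala (not the aggregator API; the counts are made up) of the bookkeeping change above: with the per-split right-child bins removed, the right child's sufficient statistics for an unordered split are recovered by subtracting the left-child bin from the parent statistics that `updateParent` now accumulates, which is what `getParentImpurityCalculator().subtract(leftChildStats)` computes.

```scala
// Illustrative class-count vectors for a 3-class node.
val parentStats = Array(10.0, 6.0, 4.0)     // filled once per data point via updateParent
val leftChildStats = Array(7.0, 2.0, 1.0)   // filled per (unordered split, bin) via featureUpdate

// Element-wise parent minus left, the subtraction performed on the impurity stats.
val rightChildStats = parentStats.zip(leftChildStats).map { case (p, l) => p - l }
// rightChildStats == Array(3.0, 4.0, 3.0)

// Gini impurity of the recovered right child, as an ImpurityCalculator would report it.
val total = rightChildStats.sum
val gini = 1.0 - rightChildStats.map(c => (c / total) * (c / total)).sum   // 0.66
```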
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index df13d291ca396..4f27dc44eff4d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -67,11 +67,11 @@ private[spark] class DecisionTreeMetadata(
 
   /**
    * Number of splits for the given feature.
-   * For unordered features, there are 2 bins per split.
+   * For unordered features, there is 1 bin per split.
    * For ordered features, there is 1 more bin than split.
    */
   def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
-    numBins(featureIndex) >> 1
+    numBins(featureIndex)
   } else {
     numBins(featureIndex) - 1
   }
@@ -212,6 +212,6 @@ private[spark] object DecisionTreeMetadata extends Logging {
    * there are math.pow(2, arity - 1) - 1 such splits.
    * Each split has 2 corresponding bins.
    */
-  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
 
 }
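
A worked sketch of the revised bin accounting for unordered categorical features, with an illustrative arity of 4: there are 2^(arity - 1) - 1 non-trivial bipartitions, so 7 candidate splits, and after this change each split keeps a single left-child bin (the right side comes from the parent stats, as above). The helper `leftCategories` below mimics, but is not, `extractMultiClassCategories`.

```scala
// Bins (== splits) for an unordered categorical feature of the given arity.
def numUnorderedBins(arity: Int): Int = (1 << (arity - 1)) - 1

val arity = 4
assert(numUnorderedBins(arity) == 7)   // 2^(4 - 1) - 1

// Each split index encodes, bit by bit, which categories go to the left child.
def leftCategories(splitIndex: Int, arity: Int): Seq[Double] =
  (0 until arity).filter(c => ((splitIndex >> c) & 1) == 1).map(_.toDouble)

// splitIndex ranges over 1 .. numUnorderedBins(arity); e.g. 5 = 0b101 sends categories 0 and 2 left.
assert(leftCategories(5, arity) == Seq(0.0, 2.0))
```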
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
index fbbec1197404a..dc7e969f7b5e8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -173,7 +173,6 @@ private[spark] class NodeIdCache(
   }
 }
 
-@DeveloperApi
 private[spark] object NodeIdCache {
   /**
    * Initialize the node Id cache with initial node Id values.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 73df6b054a8ce..13aff110079ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
     new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index f21845b21a802..39c7f9c3be8ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
     new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index b2c6e2bba43b6..65f0163ec6059 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
    * @param offset    Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 09017d482a73c..92d74a1b83341 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -93,7 +93,6 @@ private[tree] class VarianceAggregator()
   def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
     new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 091a0462c204f..db1e27bf70b8d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -79,7 +79,6 @@ private[spark] object InformationGainStats {
 }
 
 /**
- * :: DeveloperApi ::
  * Impurity statistics for each split
  * @param gain information gain value
  * @param impurity current node impurity
@@ -89,7 +88,6 @@ private[spark] object InformationGainStats {
  * @param valid whether the current split satisfies minimum info gain or
  *              minimum number of instances per node
  */
-@DeveloperApi
 private[spark] class ImpurityStats(
     val gain: Double,
     val impurity: Double,
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 536f0dc58ff38..e160a5a47e304 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -18,7 +18,6 @@
 package org.apache.spark.ml.classification;
 
 import java.io.Serializable;
-import java.lang.Math;
 import java.util.List;
 
 import org.junit.After;
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index d493a7fcec7e1..00f4476841af1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -48,7 +48,8 @@ public void setUp() {
         jsql = new SQLContext(jsc);
         int nPoints = 3;
 
-        // The following coefficients and xMean/xVariance are computed from iris dataset with lambda=0.2.
+        // The following coefficients and xMean/xVariance are computed from iris dataset with
+        // lambda=0.2.
         // As a result, we are drawing samples from probability distribution of an actual model.
         double[] coefficients = {
                 -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 5812037dee90e..bdcbde5e26223 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -63,7 +63,8 @@ public void javaCompatibilityTest() {
       RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
     );
     StructType schema = new StructType(new StructField[] {
-      new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
+      new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false,
+                      Metadata.empty())
     });
     Dataset dataset = jsql.createDataFrame(data, schema);
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 16e565d8b588b..e1b269b5b681f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RFormula())
   }
@@ -252,4 +253,41 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
         new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
     assert(attrs === expectedAttrs)
   }
+
+  test("read/write: RFormula") {
+    val rFormula = new RFormula()
+      .setFormula("id ~ a:b")
+      .setFeaturesCol("myFeatures")
+      .setLabelCol("myLabels")
+
+    testDefaultReadWrite(rFormula)
+  }
+
+  test("read/write: RFormulaModel") {
+    def checkModelData(model: RFormulaModel, model2: RFormulaModel): Unit = {
+      assert(model.uid === model2.uid)
+
+      assert(model.resolvedFormula.label === model2.resolvedFormula.label)
+      assert(model.resolvedFormula.terms === model2.resolvedFormula.terms)
+      assert(model.resolvedFormula.hasIntercept === model2.resolvedFormula.hasIntercept)
+
+      assert(model.pipelineModel.uid === model2.pipelineModel.uid)
+
+      model.pipelineModel.stages.zip(model2.pipelineModel.stages).foreach {
+        case (transformer1, transformer2) =>
+          assert(transformer1.uid === transformer2.uid)
+          assert(transformer1.params === transformer2.params)
+      }
+    }
+
+    val dataset = sqlContext.createDataFrame(
+      Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+    ).toDF("id", "a", "b")
+
+    val rFormula = new RFormula().setFormula("id ~ a:b")
+
+    val model = rFormula.fit(dataset)
+    val newModel = testDefaultReadWrite(model)
+    checkModelData(model, newModel)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 748868554fe65..a3366c0e5934c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -268,15 +268,10 @@ class ParamsSuite extends SparkFunSuite {
       solver.getParam("abc")
     }
 
-    intercept[IllegalArgumentException] {
-      solver.validateParams()
-    }
-    solver.copy(ParamMap(inputCol -> "input")).validateParams()
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
     assert(solver.isDefined(inputCol))
     assert(solver.getInputCol === "input")
-    solver.validateParams()
     intercept[IllegalArgumentException] {
       ParamMap(maxIter -> -10)
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 9d23547f28447..7d990ce0bcfd8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -34,10 +34,5 @@ class TestParams(override val uid: String) extends Params with HasHandleInvalid
 
   def clearMaxIter(): this.type = clear(maxIter)
 
-  override def validateParams(): Unit = {
-    super.validateParams()
-    require(isDefined(inputCol))
-  }
-
   override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 56545de14bd30..7af3c6d6ede47 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLog
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 
 class CrossValidatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -96,7 +96,7 @@ class CrossValidatorSuite
     assert(cvModel2.avgMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import CrossValidatorSuite.{MyEstimator, MyEvaluator}
 
     val est = new MyEstimator("est")
@@ -110,12 +110,12 @@ class CrossValidatorSuite
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
 
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 
@@ -311,14 +311,13 @@ object CrossValidatorSuite extends SparkFunSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 5fb80091d0b4b..cf8dcefebc3aa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -83,7 +83,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(cvModel2.validationMetrics.length === lrParamMaps.length)
   }
 
-  test("validateParams should check estimatorParamMaps") {
+  test("transformSchema should check estimatorParamMaps") {
     import TrainValidationSplitSuite._
 
     val est = new MyEstimator("est")
@@ -97,12 +97,12 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       .setEstimatorParamMaps(paramMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
-    cv.validateParams() // This should pass.
+    cv.transformSchema(new StructType()) // This should pass.
 
     val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
     cv.setEstimatorParamMaps(invalidParamMaps)
     intercept[IllegalArgumentException] {
-      cv.validateParams()
+      cv.transformSchema(new StructType())
     }
   }
 }
@@ -113,14 +113,13 @@ object TrainValidationSplitSuite {
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def validateParams(): Unit = require($(inputCol).nonEmpty)
-
     override def fit(dataset: DataFrame): MyModel = {
       throw new UnsupportedOperationException
     }
 
     override def transformSchema(schema: StructType): StructType = {
-      throw new UnsupportedOperationException
+      require($(inputCol).nonEmpty)
+      schema
     }
 
     override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5518bdf527c8a..89b64fce96ebf 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(bins.length === 2)
     assert(splits(0).length === 3)
     assert(bins(0).length === 0)
+    assert(metadata.numSplits(0) === 3)
+    assert(metadata.numBins(0) === 3)
+    assert(metadata.numSplits(1) === 3)
+    assert(metadata.numBins(1) === 3)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b38eec34a08b5..68e9c50d60f6a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -300,12 +300,6 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
       ) ++ Seq(
         // [SPARK-13244][SQL] Migrates DataFrame to Dataset
-        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameHolder.apply"),
-        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameHolder.toDF"),
-        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameHolder.copy"),
-        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameHolder.copy$default$1"),
-        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameHolder.df$1"),
-        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameHolder.this"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.tables"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.sql"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.baseRelationToDataFrame"),
@@ -315,6 +309,14 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder$"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.localSeqToDataFrameHolder"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.stringRddToDataFrameHolder"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.rddToDataFrameHolder"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
 
         ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
@@ -567,6 +569,9 @@ object MimaExcludes {
             if missing.map(_.fullName).sameElements(Seq("org.apache.spark.Logging")) => false
           case _ => true
         }
+      ) ++ Seq(
+        // [SPARK-13990] Automatically pick serializer when caching RDDs
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.uploadBlock")
       )
     case v if v.startsWith("1.6") =>
       Seq(
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 5025493c42c38..3182faac0de0f 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2360,7 +2360,7 @@ def explainedVariance(self):
 
 
 @inherit_doc
-class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
+class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable):
     """
     .. note:: Experimental
 
@@ -2376,7 +2376,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
     ...     (0.0, 0.0, "a")
     ... ], ["y", "x", "s"])
     >>> rf = RFormula(formula="y ~ x + s")
-    >>> rf.fit(df).transform(df).show()
+    >>> model = rf.fit(df)
+    >>> model.transform(df).show()
     +---+---+---+---------+-----+
     |  y|  x|  s| features|label|
     +---+---+---+---------+-----+
@@ -2394,6 +2395,29 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
     |0.0|0.0|  a|   [0.0]|  0.0|
     +---+---+---+--------+-----+
     ...
+    >>> rFormulaPath = temp_path + "/rFormula"
+    >>> rf.save(rFormulaPath)
+    >>> loadedRF = RFormula.load(rFormulaPath)
+    >>> loadedRF.getFormula() == rf.getFormula()
+    True
+    >>> loadedRF.getFeaturesCol() == rf.getFeaturesCol()
+    True
+    >>> loadedRF.getLabelCol() == rf.getLabelCol()
+    True
+    >>> modelPath = temp_path + "/rFormulaModel"
+    >>> model.save(modelPath)
+    >>> loadedModel = RFormulaModel.load(modelPath)
+    >>> loadedModel.uid == model.uid
+    True
+    >>> loadedModel.transform(df).show()
+    +---+---+---+---------+-----+
+    |  y|  x|  s| features|label|
+    +---+---+---+---------+-----+
+    |1.0|1.0|  a|[1.0,1.0]|  1.0|
+    |0.0|2.0|  b|[2.0,0.0]|  0.0|
+    |0.0|0.0|  a|[0.0,1.0]|  0.0|
+    +---+---+---+---------+-----+
+    ...
 
     .. versionadded:: 1.5.0
     """
@@ -2439,7 +2463,7 @@ def _create_model(self, java_model):
         return RFormulaModel(java_model)
 
 
-class RFormulaModel(JavaModel):
+class RFormulaModel(JavaModel, MLReadable, MLWritable):
     """
     .. note:: Experimental
 
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 0f7b5e9b9e1a9..37dcb23b6776b 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -34,10 +34,15 @@ class JavaWrapper(Params):
 
     __metaclass__ = ABCMeta
 
-    #: The wrapped Java companion object. Subclasses should initialize
-    #: it properly. The param values in the Java object should be
-    #: synced with the Python wrapper in fit/transform/evaluate/copy.
-    _java_obj = None
+    def __init__(self):
+        """
+        Initialize the wrapped Java object to None.
+        """
+        super(JavaWrapper, self).__init__()
+        #: The wrapped Java companion object. Subclasses should initialize
+        #: it properly. The param values in the Java object should be
+        #: synced with the Python wrapper in fit/transform/evaluate/copy.
+        self._java_obj = None
 
     @staticmethod
     def _new_java_obj(java_class, *args):
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 438662bb157f0..bae9e69df8e2b 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -162,6 +162,14 @@ def json(self, path, schema=None):
                 (e.g. 00012)
             * ``allowBackslashEscapingAnyCharacter`` (default ``false``): allows accepting quoting \
                 of all character using backslash quoting mechanism
+            * ``mode`` (default ``PERMISSIVE``): allows a mode for dealing with corrupt records \
+                during parsing.
+                * ``PERMISSIVE``: sets other fields to ``null`` when it meets a corrupted \
+                  record and puts the malformed string into a new field configured by \
+                  ``spark.sql.columnNameOfCorruptRecord``. When a schema is set by the user, \
+                  it sets ``null`` for extra fields.
+                * ``DROPMALFORMED``: ignores whole corrupted records.
+                * ``FAILFAST``: throws an exception when it meets corrupted records.
 
         >>> df1 = sqlContext.read.json('python/test_support/sql/people.json')
         >>> df1.dtypes
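
The `mode` option documented above is also available from the Scala reader. A minimal spark-shell sketch, assuming the pre-defined `sqlContext` and a hypothetical `data/people.json` file that contains some malformed lines:

```
// DROPMALFORMED silently skips corrupted records instead of nulling their fields.
val people = sqlContext.read
  .option("mode", "DROPMALFORMED")
  .json("data/people.json")
people.show()
```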
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index cbcccb11f14ae..6b9aa5071e1d0 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -288,7 +288,7 @@ class ReplSuite extends SparkFunSuite {
         |import org.apache.spark.sql.Encoder
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
-        |val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
+        |val simpleSum = new Aggregator[Int, Int, Int] {
         |  def zero: Int = 0                     // The initial value.
         |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
         |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
@@ -347,7 +347,7 @@ class ReplSuite extends SparkFunSuite {
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
         |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
         |  val numeric = implicitly[Numeric[N]]
         |  override def zero: N = numeric.zero
         |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 6bee880640ced..f148a6df47607 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -249,10 +249,32 @@ class ReplSuite extends SparkFunSuite {
     // We need to use local-cluster to test this case.
     val output = runInterpreter("local-cluster[1,1,1024]",
       """
-        |val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-        |import sqlContext.implicits._
         |case class TestCaseClass(value: Int)
         |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
+        |
+        |// Test Dataset Serialization in the REPL
+        |Seq(TestCaseClass(1)).toDS().collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("Datasets and encoders") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |val simpleSum = new Aggregator[Int, Int, Int] {
+        |  def zero: Int = 0                     // The initial value.
+        |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
+        |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
+        |  def finish(b: Int) = b                // Return the final result.
+        |}.toColumn
+        |
+        |val ds = Seq(1, 2, 3, 4).toDS()
+        |ds.select(simpleSum).collect
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -295,6 +317,31 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
+  test("Datasets agg type-inference") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |/** An `Aggregator` that adds up any numeric type returned by the given function. */
+        |class SumOf[I, N : Numeric](f: I => N) extends
+        |  org.apache.spark.sql.expressions.Aggregator[I, N, N] {
+        |  val numeric = implicitly[Numeric[N]]
+        |  override def zero: N = numeric.zero
+        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+        |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
+        |  override def finish(reduction: N): N = reduction
+        |}
+        |
+        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
+        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+        |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
@@ -317,4 +364,21 @@ class ReplSuite extends SparkFunSuite {
     assertDoesNotContain("Exception", output)
     assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
   }
+
+  test("line wrapper only initialized once when used as encoder outer scope") {
+    val output = runInterpreter("local",
+      """
+        |val fileName = "repl-test-" + System.currentTimeMillis
+        |val tmpDir = System.getProperty("java.io.tmpdir")
+        |val file = new java.io.File(tmpDir, fileName)
+        |def createFile(): Unit = file.createNewFile()
+        |
+        |createFile();case class TestCaseClass(value: Int)
+        |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
+        |
+        |file.delete()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
 }
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
index e83f8a7cd1b5c..1bf461c912b61 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
@@ -91,10 +91,17 @@ fromClause
 joinSource
 @init { gParent.pushMsg("join source", state); }
 @after { gParent.popMsg(state); }
-    : fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )*
+    : fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )*
     | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+
     ;
 
+joinCond
+@init { gParent.pushMsg("join expression list", state); }
+@after { gParent.popMsg(state); }
+    : KW_ON! expression
+    | KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList)
+    ;
+
 uniqueJoinSource
 @init { gParent.pushMsg("unique join source", state); }
 @after { gParent.popMsg(state); }
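
For reference, a hedged spark-shell sketch of the query shape the new `joinCond` alternative is meant to accept; the table and column names are made up, and the `USING` column list is rewritten into an equi-join plus a Project by the analyzer changes later in this patch:

```
// Assumes the pre-defined sqlContext of spark-shell; the data is hypothetical.
import sqlContext.implicits._

Seq((1, "a"), (2, "b")).toDF("id", "l").registerTempTable("t1")
Seq((1, "x"), (3, "y")).toDF("id", "r").registerTempTable("t2")

// JOIN ... USING (id) behaves like an inner join on t1.id = t2.id,
// with the shared id column appearing only once in the output.
sqlContext.sql("SELECT * FROM t1 JOIN t2 USING (id)").show()
```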
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index 1db3aed65815d..f0c236859ddca 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -387,6 +387,7 @@ TOK_SETCONFIG;
 TOK_DFS;
 TOK_ADDFILE;
 TOK_ADDJAR;
+TOK_USING;
 }
 
 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 0ad0f4976c77a..d85147e961fa8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -171,7 +171,7 @@ private static final class RowComparator extends RecordComparator {
     private final UnsafeRow row1;
     private final UnsafeRow row2;
 
-    public RowComparator(Ordering ordering, int numFields) {
+    RowComparator(Ordering ordering, int numFields) {
       this.numFields = numFields;
       this.row1 = new UnsafeRow(numFields);
       this.row2 = new UnsafeRow(numFields);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 35884139b6be8..e10ab9790d767 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.analysis._
 private[spark] trait CatalystConf {
   def caseSensitiveAnalysis: Boolean
 
+  def orderByOrdinal: Boolean
+
   /**
   * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
@@ -43,8 +45,14 @@ object EmptyConf extends CatalystConf {
   override def caseSensitiveAnalysis: Boolean = {
     throw new UnsupportedOperationException
   }
+  override def orderByOrdinal: Boolean = {
+    throw new UnsupportedOperationException
+  }
 }
 
 /** A CatalystConf that can be used for local testing. */
-case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf {
+case class SimpleCatalystConf(
+    caseSensitiveAnalysis: Boolean,
+    orderByOrdinal: Boolean = true)
+  extends CatalystConf {
 }
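
A small sketch of the reshaped test conf above; `orderByOrdinal` defaults to `true`, so only call sites that want the pre-2.0 behaviour need to pass it explicitly:

```
import org.apache.spark.sql.catalyst.SimpleCatalystConf

// Default: integer positions in ORDER BY are resolved against the select list.
val conf = SimpleCatalystConf(caseSensitiveAnalysis = true)

// Opt out, e.g. in a test that expects ordinals to be ignored.
val legacyConf = SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false)
```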
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 53ea3cfef6786..5951a70c4809a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatal
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.planning.IntegerIndex
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -40,7 +41,10 @@ import org.apache.spark.sql.types._
  * references.
  */
 object SimpleAnalyzer
-  extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(true))
+  extends Analyzer(
+    EmptyCatalog,
+    EmptyFunctionRegistry,
+    new SimpleCatalystConf(caseSensitiveAnalysis = true))
 
 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
@@ -76,6 +80,7 @@ class Analyzer(
       EliminateUnions),
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
+      ResolveStar ::
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
@@ -87,7 +92,7 @@ class Analyzer(
       ResolveSubquery ::
       ResolveWindowOrder ::
       ResolveWindowFrame ::
-      ResolveNaturalJoin ::
+      ResolveNaturalAndUsingJoin ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
       ResolveAggregateFunctions ::
@@ -369,28 +374,83 @@ class Analyzer(
   }
 
   /**
-   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
-   * a logical plan node's children.
+   * Expands [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
    */
-  object ResolveReferences extends Rule[LogicalPlan] {
+  object ResolveStar extends Rule[LogicalPlan] {
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p: LogicalPlan if !p.childrenResolved => p
+
+      // If the projection list contains Stars, expand it.
+      case p: Project if containsStar(p.projectList) =>
+        val expanded = p.projectList.flatMap {
+          case s: Star => s.expand(p.child, resolver)
+          case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
+            UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil
+          case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
+            a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil)
+              .asInstanceOf[Alias] :: Nil
+          case o => o :: Nil
+        }
+        Project(projectList = expanded, p.child)
+      // If the aggregate function argument contains Stars, expand it.
+      case a: Aggregate if containsStar(a.aggregateExpressions) =>
+        val expanded = a.aggregateExpressions.flatMap {
+          case s: Star => s.expand(a.child, resolver)
+          case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil
+          case o => o :: Nil
+        }.map(_.asInstanceOf[NamedExpression])
+        a.copy(aggregateExpressions = expanded)
+      // If the script transformation input contains Stars, expand it.
+      case t: ScriptTransformation if containsStar(t.input) =>
+        t.copy(
+          input = t.input.flatMap {
+            case s: Star => s.expand(t.child, resolver)
+            case o => o :: Nil
+          }
+        )
+      case g: Generate if containsStar(g.generator.children) =>
+        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
+    }
+
+    /**
+     * Returns true if `exprs` contains a [[Star]].
+     */
+    def containsStar(exprs: Seq[Expression]): Boolean =
+      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+
     /**
-     * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree
-     * rooted at each expression.
+     * Expands the matching attribute.*'s in `child`'s output.
      */
-    def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = {
-      exprs.flatMap {
-        case s: Star => s.expand(child, resolver)
-        case e =>
-          e.transformDown {
-            case f1: UnresolvedFunction if containsStar(f1.children) =>
-              f1.copy(children = f1.children.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              })
-          } :: Nil
+    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
+      expr.transformUp {
+        case f1: UnresolvedFunction if containsStar(f1.children) =>
+          f1.copy(children = f1.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateStruct if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateArray if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        // count(*) has been replaced by count(1)
+        case o if containsStar(o.children) =>
+          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
       }
     }
+  }
 
+  /**
+   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
+   * a logical plan node's children.
+   */
+  object ResolveReferences extends Rule[LogicalPlan] {
     /**
      * Generate a new logical plan for the right child with different expression IDs
      * for all conflicting attributes.
@@ -442,7 +502,7 @@ class Analyzer(
           } transformUp {
             case other => other transformExpressions {
               case a: Attribute =>
-                attributeRewrites.get(a).getOrElse(a).withQualifiers(a.qualifiers)
+                attributeRewrites.get(a).getOrElse(a).withQualifier(a.qualifier)
             }
           }
           newRight
@@ -452,48 +512,6 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p
 
-      // If the projection list contains Stars, expand it.
-      case p @ Project(projectList, child) if containsStar(projectList) =>
-        Project(
-          projectList.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
-              val newChildren = expandStarExpressions(args, child)
-              UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
-            case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
-              val newChildren = expandStarExpressions(args, child)
-              Alias(child = f.copy(children = newChildren), name)(
-                isGenerated = a.isGenerated) :: Nil
-            case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
-              val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              }
-              UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
-            case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
-              val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              }
-              UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
-            case o => o :: Nil
-          },
-          child)
-
-      case t: ScriptTransformation if containsStar(t.input) =>
-        t.copy(
-          input = t.input.flatMap {
-            case s: Star => s.expand(t.child, resolver)
-            case o => o :: Nil
-          }
-        )
-
-      // If the aggregate function argument contains Stars, expand it.
-      case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
-            .map(_.asInstanceOf[NamedExpression])
-        a.copy(aggregateExpressions = expanded)
-
       // To resolve duplicate expression IDs for Join and Intersect
       case j @ Join(left, right, _, _) if !j.duplicateResolved =>
         j.copy(right = dedupRight(left, right))
@@ -567,7 +585,7 @@ class Analyzer(
           if n.outerPointer.isEmpty &&
              n.cls.isMemberClass &&
              !Modifier.isStatic(n.cls.getModifiers) =>
-          val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
+          val outer = OuterScopes.getOuterScope(n.cls)
           if (outer == null) {
             throw new AnalysisException(
               s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
@@ -588,15 +606,12 @@ class Analyzer(
     def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
       AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
     }
-
-    /**
-     * Returns true if `exprs` contains a [[Star]].
-     */
-    def containsStar(exprs: Seq[Expression]): Boolean =
-      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
   }
 
-  private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
+  protected[sql] def resolveExpression(
+      expr: Expression,
+      plan: LogicalPlan,
+      throws: Boolean = false) = {
     // Resolve expression in one round.
     // If throws == false or the desired attribute doesn't exist
     // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
@@ -618,13 +633,36 @@ class Analyzer(
    * clause.  This rule detects such queries and adds the required attributes to the original
    * projection, so that they will be available during sorting. Another projection is added to
    * remove these attributes after sorting.
+   *
+   * This rule also resolves position numbers in sort references. This support was introduced
+   * in Spark 2.0. Before Spark 2.0, integers in ORDER BY had no effect on the output sorting.
+   * - When the sort references are not integers but other foldable expressions, ignore them.
+   * - When spark.sql.orderByOrdinal is set to false, ignore the position numbers too.
    */
   object ResolveSortReferences extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case s: Sort if !s.child.resolved => s
+      // Replace the index with the corresponding attribute for ORDER BY,
+      // which is a 1-based position in the projection list.
+      case s @ Sort(orders, global, child)
+          if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
+        val newOrders = orders map {
+          case s @ SortOrder(IntegerIndex(index), direction) =>
+            if (index > 0 && index <= child.output.size) {
+              SortOrder(child.output(index - 1), direction)
+            } else {
+              throw new UnresolvedException(s,
+                s"Order/sort By position: $index does not exist " +
+                s"The Select List is indexed from 1 to ${child.output.size}")
+            }
+          case o => o
+        }
+        Sort(newOrders, global, child)
+
       // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
       case sa @ Sort(_, _, child: Aggregate) => sa
 
-      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+      case s @ Sort(order, _, child) if !s.resolved =>
         try {
           val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
           val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
@@ -893,8 +931,6 @@ class Analyzer(
    */
   object ResolveGenerate extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case g: Generate if ResolveReferences.containsStar(g.generator.children) =>
-        failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
       case p: Generate if !p.child.resolved || !p.generator.resolved => p
       case g: Generate if !g.resolved =>
         g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
@@ -1329,48 +1365,69 @@ class Analyzer(
   }
 
   /**
-   * Removes natural joins by calculating output columns based on output from two sides,
-   * Then apply a Project on a normal Join to eliminate natural join.
+   * Removes natural or using joins by calculating output columns based on the output from the
+   * two sides, then applies a Project on a normal Join to eliminate the natural or using join.
    */
-  object ResolveNaturalJoin extends Rule[LogicalPlan] {
+  object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case j @ Join(left, right, UsingJoin(joinType, usingCols), condition)
+          if left.resolved && right.resolved && j.duplicateResolved =>
+        // Resolve the column names referenced in the USING clause from both legs of the join.
+        val lCols = usingCols.flatMap(col => left.resolveQuoted(col.name, resolver))
+        val rCols = usingCols.flatMap(col => right.resolveQuoted(col.name, resolver))
+        if ((lCols.length == usingCols.length) && (rCols.length == usingCols.length)) {
+          val joinNames = lCols.map(exp => exp.name)
+          commonNaturalJoinProcessing(left, right, joinType, joinNames, None)
+        } else {
+          j
+        }
       case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
         // find common column names from both sides
         val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
-        val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
-        val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
-        val joinPairs = leftKeys.zip(rightKeys)
-
-        // Add joinPairs to joinConditions
-        val newCondition = (condition ++ joinPairs.map {
-          case (l, r) => EqualTo(l, r)
-        }).reduceOption(And)
-
-        // columns not in joinPairs
-        val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
-        val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
-
-        // the output list looks like: join keys, columns from left, columns from right
-        val projectList = joinType match {
-          case LeftOuter =>
-            leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
-          case RightOuter =>
-            rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
-          case FullOuter =>
-            // in full outer join, joinCols should be non-null if there is.
-            val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
-            joinedCols ++
-              lUniqueOutput.map(_.withNullability(true)) ++
-              rUniqueOutput.map(_.withNullability(true))
-          case Inner =>
-            rightKeys ++ lUniqueOutput ++ rUniqueOutput
-          case _ =>
-            sys.error("Unsupported natural join type " + joinType)
-        }
-        // use Project to trim unnecessary fields
-        Project(projectList, Join(left, right, joinType, newCondition))
+        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition)
     }
   }
+
+  private def commonNaturalJoinProcessing(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      joinType: JoinType,
+      joinNames: Seq[String],
+      condition: Option[Expression]) = {
+    val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
+    val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
+    val joinPairs = leftKeys.zip(rightKeys)
+
+    val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)
+
+    // columns not in joinPairs
+    val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
+    val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
+
+    // the output list looks like: join keys, columns from left, columns from right
+    val projectList = joinType match {
+      case LeftOuter =>
+        leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
+      case LeftSemi =>
+        leftKeys ++ lUniqueOutput
+      case RightOuter =>
+        rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
+      case FullOuter =>
+        // In a full outer join, the join columns should be non-null if either side is.
+        val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
+        joinedCols ++
+          lUniqueOutput.map(_.withNullability(true)) ++
+          rUniqueOutput.map(_.withNullability(true))
+      case Inner =>
+        leftKeys ++ lUniqueOutput ++ rUniqueOutput
+      case _ =>
+        sys.error("Unsupported natural join type " + joinType)
+    }
+    // use Project to trim unnecessary fields
+    Project(projectList, Join(left, right, joinType, newCondition))
+  }
+
+
 }
 
 /**
@@ -1416,8 +1473,7 @@ object CleanupAliases extends Rule[LogicalPlan] {
 
   def trimNonTopLevelAliases(e: Expression): Expression = e match {
     case a: Alias =>
-      Alias(trimAliases(a.child), a.name)(
-        a.exprId, a.qualifiers, a.explicitMetadata, a.isGenerated)
+      a.withNewChildren(trimAliases(a.child) :: Nil)
     case other => trimAliases(other)
   }
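
A hedged end-to-end sketch of the ORDER BY position support added to `ResolveSortReferences` above; the data and table name are made up, and the behaviour can be disabled through `spark.sql.orderByOrdinal` as the rule's comment notes:

```
// Assumes the pre-defined sqlContext of spark-shell; the data is hypothetical.
import sqlContext.implicits._

Seq(("b", 2), ("a", 1)).toDF("name", "num").registerTempTable("t")

// ORDER BY 2 resolves to the second item of the select list, i.e. num.
sqlContext.sql("SELECT name, num FROM t ORDER BY 2").show()

// An out-of-range position such as ORDER BY 3 fails analysis with the
// "Order/sort By position ... does not exist" error raised above.
```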
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index 52b284b757df5..2f0a4dbc107aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -101,13 +101,13 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog {
     if (table == null) {
       throw new AnalysisException("Table not found: " + tableName)
     }
-    val tableWithQualifiers = SubqueryAlias(tableName, table)
+    val qualifiedTable = SubqueryAlias(tableName, table)
 
     // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
     // properly qualified with this alias.
     alias
-      .map(a => SubqueryAlias(a, tableWithQualifiers))
-      .getOrElse(tableWithQualifiers)
+      .map(a => SubqueryAlias(a, qualifiedTable))
+      .getOrElse(qualifiedTable)
   }
 
   override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
@@ -149,11 +149,11 @@ trait OverrideCatalog extends Catalog {
     getOverriddenTable(tableIdent) match {
       case Some(table) =>
         val tableName = getTableName(tableIdent)
-        val tableWithQualifiers = SubqueryAlias(tableName, table)
+        val qualifiedTable = SubqueryAlias(tableName, table)
 
         // If an alias was specified by the lookup, wrap the plan in a sub-query so that attributes
         // are properly qualified with this alias.
-        alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
+        alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
 
       case None => super.lookupRelation(tableIdent, alias)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1e430c1fbbdf0..1d1e892e32cd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.UsingJoin
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -109,6 +110,12 @@ trait CheckAnalysis {
               s"filter expression '${f.condition.sql}' " +
                 s"of type ${f.condition.dataType.simpleString} is not a boolean.")
 
+          case j @ Join(_, _, UsingJoin(_, cols), _) =>
+            val from = operator.inputSet.map(_.name).mkString(", ")
+            failAnalysis(
+              s"using columns [${cols.mkString(",")}] " +
+                s"can not be resolved given input columns: [$from] ")
+
           case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType =>
             failAnalysis(
               s"join condition '${condition.sql}' " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 01afa01ae95c5..9518309fbf8ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -59,12 +59,12 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
   override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
   override def dataType: DataType = throw new UnresolvedException(this, "dataType")
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
-  override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
+  override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
   override lazy val resolved = false
 
   override def newInstance(): UnresolvedAttribute = this
   override def withNullability(newNullability: Boolean): UnresolvedAttribute = this
-  override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this
+  override def withQualifier(newQualifier: Option[String]): UnresolvedAttribute = this
   override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
 
   override def toString: String = s"'$name"
@@ -158,7 +158,7 @@ abstract class Star extends LeafExpression with NamedExpression {
   override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
   override def dataType: DataType = throw new UnresolvedException(this, "dataType")
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
-  override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
+  override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
   override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
   override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
   override lazy val resolved = false
@@ -188,7 +188,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
       case None => input.output
       // If there is a table, pick out attributes that are part of this table.
       case Some(t) => if (t.size == 1) {
-        input.output.filter(_.qualifiers.exists(resolver(_, t.head)))
+        input.output.filter(_.qualifier.exists(resolver(_, t.head)))
       } else {
         List()
       }
@@ -243,7 +243,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
 
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
 
-  override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
+  override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
 
   override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
 
@@ -298,7 +298,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
   extends UnaryExpression with NamedExpression with Unevaluable {
 
   override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
-  override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
+  override def qualifier: Option[String] = throw new UnresolvedException(this, "qualifier")
   override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override def dataType: DataType = throw new UnresolvedException(this, "dataType")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 4dec0429bd1fc..3ac2bcf7e8d03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -206,10 +206,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
       } else {
         tempTables.get(name.table)
       }
-    val tableWithQualifiers = SubqueryAlias(name.table, relation)
+    val qualifiedTable = SubqueryAlias(name.table, relation)
     // If an alias was specified by the lookup, wrap the plan in a subquery so that
     // attributes are properly qualified with this alias.
-    alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
+    alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable)
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 58f6d0eb9e929..918233ddcdaf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -282,8 +282,14 @@ case class ExpressionEncoder[T](
     // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of
     // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid.
     // Note that, `BoundReference` contains the expected type, but here we need the actual type, so
-    // we unbound it by the given `schema` and propagate the actual type to `GetStructField`.
-    val unbound = fromRowExpression transform {
+    // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after
+    // we resolve the `fromRowExpression`.
+    val resolved = SimpleAnalyzer.resolveExpression(
+      fromRowExpression,
+      LocalRelation(schema),
+      throws = true)
+
+    val unbound = resolved transform {
       case b: BoundReference => schema(b.ordinal)
     }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index a753b187bcd32..c047e96463544 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -21,6 +21,8 @@ import java.util.concurrent.ConcurrentMap
 
 import com.google.common.collect.MapMaker
 
+import org.apache.spark.util.Utils
+
 object OuterScopes {
   @transient
   lazy val outerScopes: ConcurrentMap[String, AnyRef] =
@@ -28,7 +30,7 @@ object OuterScopes {
 
   /**
    * Adds a new outer scope to this context that can be used when instantiating an `inner class`
-   * during deserialialization. Inner classes are created when a case class is defined in the
+   * during deserialization. Inner classes are created when a case class is defined in the
    * Spark REPL and registering the outer scope that this class was defined in allows us to create
    * new instances on the spark executors.  In normal use, users should not need to call this
    * function.
@@ -39,4 +41,47 @@ object OuterScopes {
   def addOuterScope(outer: AnyRef): Unit = {
     outerScopes.putIfAbsent(outer.getClass.getName, outer)
   }
+
+  def getOuterScope(innerCls: Class[_]): AnyRef = {
+    assert(innerCls.isMemberClass)
+    val outerClassName = innerCls.getDeclaringClass.getName
+    val outer = outerScopes.get(outerClassName)
+    if (outer == null) {
+      outerClassName match {
+        // If the outer class is generated by the REPL, users don't need to register it, as it
+        // has only one instance and there is a way to retrieve it: get the `$read` object, call
+        // the `INSTANCE()` method to get the single instance of class `$read`, then call the
+        // `$iw()` method multiple times to get the single instance of the innermost `$iw` class.
+        case REPLClass(baseClassName) =>
+          val objClass = Utils.classForName(baseClassName + "$")
+          val objInstance = objClass.getField("MODULE$").get(null)
+          val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
+          val baseClass = Utils.classForName(baseClassName)
+
+          var getter = iwGetter(baseClass)
+          var obj = baseInstance
+          while (getter != null) {
+            obj = getter.invoke(obj)
+            getter = iwGetter(getter.getReturnType)
+          }
+
+          outerScopes.putIfAbsent(outerClassName, obj)
+          obj
+        case _ => null
+      }
+    } else {
+      outer
+    }
+  }
+
+  private def iwGetter(cls: Class[_]) = {
+    try {
+      cls.getMethod("$iw")
+    } catch {
+      case _: NoSuchMethodException => null
+    }
+  }
+
+  // The format of a REPL-generated wrapper class's name, e.g. `$line12.$read$$iw$$iw`
+  private[this] val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r
 }
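
A standalone sketch, with no Spark dependency, of the wrapper-class names the `REPLClass` pattern above is meant to recognize; the sample class name comes from the comment in the code:

```
// The same pattern as above, tried against a REPL-style wrapper class name.
val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r

"$line12.$read$$iw$$iw" match {
  case REPLClass(base) => println(s"base object class: $base")  // $line12.$read
  case _ => println("not a REPL wrapper class")
}
```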
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 0d249a118cfa6..c1fd23f28d6b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -65,7 +65,9 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
       val oev = ctx.currentVars(ordinal)
       ev.isNull = oev.isNull
       ev.value = oev.value
-      oev.code
+      val code = oev.code
+      oev.code = ""
+      code
     } else if (nullable) {
       s"""
         boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b7b2b9a438dcf..6875915f79b15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -232,7 +232,7 @@ trait Unevaluable extends Expression {
  * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`.
  */
 trait NonSQLExpression extends Expression {
-  override def sql: String = {
+  final override def sql: String = {
     transform {
       case a: Attribute => new PrettyAttribute(a)
     }.toString
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 44cdc8d8812e6..c06dcc98674fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -110,7 +110,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
 
   override def dataType: DataType = childSchema(ordinal).dataType
   override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable
-  override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}"
+
+  override def toString: String = {
+    val fieldName = if (resolved) childSchema(ordinal).name else s"_$ordinal"
+    s"$child.${name.getOrElse(fieldName)}"
+  }
+
   override def sql: String =
     child.sql + s".${quoteIdentifier(name.getOrElse(childSchema(ordinal).name))}"
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 271ef33090980..a5b5758167276 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -61,10 +61,10 @@ trait NamedExpression extends Expression {
   * multiple qualifiers, it is possible that there are other possible ways to refer to this
    * attribute.
    */
-  def qualifiedName: String = (qualifiers.headOption.toSeq :+ name).mkString(".")
+  def qualifiedName: String = (qualifier.toSeq :+ name).mkString(".")
 
   /**
-   * All possible qualifiers for the expression.
+   * Optional qualifier for the expression.
    *
    * For now, since we do not allow using original table name to qualify a column name once the
    * table is aliased, this can only be:
@@ -73,7 +73,7 @@ trait NamedExpression extends Expression {
    *    e.g. top level attributes aliased in the SELECT clause, or column from a LocalRelation.
    * 2. Single element: either the table name or the alias name of the table.
    */
-  def qualifiers: Seq[String]
+  def qualifier: Option[String]
 
   def toAttribute: Attribute
 
@@ -102,7 +102,7 @@ abstract class Attribute extends LeafExpression with NamedExpression {
   override def references: AttributeSet = AttributeSet(this)
 
   def withNullability(newNullability: Boolean): Attribute
-  def withQualifiers(newQualifiers: Seq[String]): Attribute
+  def withQualifier(newQualifier: Option[String]): Attribute
   def withName(newName: String): Attribute
 
   override def toAttribute: Attribute = this
@@ -122,7 +122,7 @@ abstract class Attribute extends LeafExpression with NamedExpression {
  * @param name The name to be associated with the result of computing [[child]].
  * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
  *               alias. Auto-assigned if left blank.
- * @param qualifiers A list of strings that can be used to referred to this attribute in a fully
+ * @param qualifier An optional string that can be used to refer to this attribute in a fully
  *                   qualified way. Consider the examples tableName.name, subQueryAlias.name.
  *                   tableName and subQueryAlias are possible qualifiers.
  * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
@@ -130,7 +130,7 @@ abstract class Attribute extends LeafExpression with NamedExpression {
  */
 case class Alias(child: Expression, name: String)(
     val exprId: ExprId = NamedExpression.newExprId,
-    val qualifiers: Seq[String] = Nil,
+    val qualifier: Option[String] = None,
     val explicitMetadata: Option[Metadata] = None,
     override val isGenerated: java.lang.Boolean = false)
   extends UnaryExpression with NamedExpression {
@@ -158,12 +158,12 @@ case class Alias(child: Expression, name: String)(
 
   def newInstance(): NamedExpression =
     Alias(child, name)(
-      qualifiers = qualifiers, explicitMetadata = explicitMetadata, isGenerated = isGenerated)
+      qualifier = qualifier, explicitMetadata = explicitMetadata, isGenerated = isGenerated)
 
   override def toAttribute: Attribute = {
     if (resolved) {
       AttributeReference(name, child.dataType, child.nullable, metadata)(
-        exprId, qualifiers, isGenerated)
+        exprId, qualifier, isGenerated)
     } else {
       UnresolvedAttribute(name)
     }
@@ -172,20 +172,19 @@ case class Alias(child: Expression, name: String)(
   override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"
 
   override protected final def otherCopyArgs: Seq[AnyRef] = {
-    exprId :: qualifiers :: explicitMetadata :: isGenerated :: Nil
+    exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil
   }
 
   override def equals(other: Any): Boolean = other match {
     case a: Alias =>
-      name == a.name && exprId == a.exprId && child == a.child && qualifiers == a.qualifiers &&
+      name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier &&
         explicitMetadata == a.explicitMetadata
     case _ => false
   }
 
   override def sql: String = {
-    val qualifiersString =
-      if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".")
-    s"${child.sql} AS $qualifiersString${quoteIdentifier(name)}"
+    val qualifierPrefix = qualifier.map(_ + ".").getOrElse("")
+    s"${child.sql} AS $qualifierPrefix${quoteIdentifier(name)}"
   }
 }
 
@@ -198,9 +197,9 @@ case class Alias(child: Expression, name: String)(
  * @param metadata The metadata of this attribute.
  * @param exprId A globally unique id used to check if different AttributeReferences refer to the
  *               same attribute.
- * @param qualifiers A list of strings that can be used to referred to this attribute in a fully
- *                   qualified way. Consider the examples tableName.name, subQueryAlias.name.
- *                   tableName and subQueryAlias are possible qualifiers.
+ * @param qualifier An optional string that can be used to refer to this attribute in a fully
+ *                  qualified way. Consider the examples tableName.name, subQueryAlias.name.
+ *                  tableName and subQueryAlias are possible qualifiers.
  * @param isGenerated A flag to indicate if this reference is generated by Catalyst
  */
 case class AttributeReference(
@@ -209,7 +208,7 @@ case class AttributeReference(
     nullable: Boolean = true,
     override val metadata: Metadata = Metadata.empty)(
     val exprId: ExprId = NamedExpression.newExprId,
-    val qualifiers: Seq[String] = Nil,
+    val qualifier: Option[String] = None,
     override val isGenerated: java.lang.Boolean = false)
   extends Attribute with Unevaluable {
 
@@ -221,7 +220,7 @@ case class AttributeReference(
   override def equals(other: Any): Boolean = other match {
     case ar: AttributeReference =>
       name == ar.name && dataType == ar.dataType && nullable == ar.nullable &&
-        metadata == ar.metadata && exprId == ar.exprId && qualifiers == ar.qualifiers
+        metadata == ar.metadata && exprId == ar.exprId && qualifier == ar.qualifier
     case _ => false
   }
 
@@ -242,13 +241,13 @@ case class AttributeReference(
     h = h * 37 + nullable.hashCode()
     h = h * 37 + metadata.hashCode()
     h = h * 37 + exprId.hashCode()
-    h = h * 37 + qualifiers.hashCode()
+    h = h * 37 + qualifier.hashCode()
     h
   }
 
   override def newInstance(): AttributeReference =
     AttributeReference(name, dataType, nullable, metadata)(
-      qualifiers = qualifiers, isGenerated = isGenerated)
+      qualifier = qualifier, isGenerated = isGenerated)
 
   /**
    * Returns a copy of this [[AttributeReference]] with changed nullability.
@@ -257,7 +256,7 @@ case class AttributeReference(
     if (nullable == newNullability) {
       this
     } else {
-      AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers, isGenerated)
+      AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier, isGenerated)
     }
   }
 
@@ -265,18 +264,18 @@ case class AttributeReference(
     if (name == newName) {
       this
     } else {
-      AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifiers, isGenerated)
+      AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier, isGenerated)
     }
   }
 
   /**
-   * Returns a copy of this [[AttributeReference]] with new qualifiers.
+   * Returns a copy of this [[AttributeReference]] with a new qualifier.
    */
-  override def withQualifiers(newQualifiers: Seq[String]): AttributeReference = {
-    if (newQualifiers.toSet == qualifiers.toSet) {
+  override def withQualifier(newQualifier: Option[String]): AttributeReference = {
+    if (newQualifier == qualifier) {
       this
     } else {
-      AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers, isGenerated)
+      AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier, isGenerated)
     }
   }
 
@@ -284,12 +283,12 @@ case class AttributeReference(
     if (exprId == newExprId) {
       this
     } else {
-      AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers, isGenerated)
+      AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier, isGenerated)
     }
   }
 
   override protected final def otherCopyArgs: Seq[AnyRef] = {
-    exprId :: qualifiers :: isGenerated :: Nil
+    exprId :: qualifier :: isGenerated :: Nil
   }
 
   override def toString: String = s"$name#${exprId.id}$typeSuffix"
@@ -299,9 +298,8 @@ case class AttributeReference(
   override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
 
   override def sql: String = {
-    val qualifiersString =
-      if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".")
-    s"$qualifiersString${quoteIdentifier(name)}"
+    val qualifierPrefix = qualifier.map(_ + ".").getOrElse("")
+    s"$qualifierPrefix${quoteIdentifier(name)}"
   }
 }
 
@@ -326,10 +324,10 @@ case class PrettyAttribute(
   override def withNullability(newNullability: Boolean): Attribute =
     throw new UnsupportedOperationException
   override def newInstance(): Attribute = throw new UnsupportedOperationException
-  override def withQualifiers(newQualifiers: Seq[String]): Attribute =
+  override def withQualifier(newQualifier: Option[String]): Attribute =
     throw new UnsupportedOperationException
   override def withName(newName: String): Attribute = throw new UnsupportedOperationException
-  override def qualifiers: Seq[String] = throw new UnsupportedOperationException
+  override def qualifier: Option[String] = throw new UnsupportedOperationException
   override def exprId: ExprId = throw new UnsupportedOperationException
   override def nullable: Boolean = throw new UnsupportedOperationException
 }
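For orientation, here is a minimal standalone sketch (not part of the patch) of how the new single optional qualifier feeds the generated SQL, assuming a simplified backtick-quoting `quoteIdentifier`:

    // Sketch only: with Option[String] there is at most one qualifier, so map/getOrElse suffices.
    def quoteIdentifier(name: String): String = s"`$name`"

    def sqlFor(name: String, qualifier: Option[String]): String = {
      val qualifierPrefix = qualifier.map(_ + ".").getOrElse("")
      s"$qualifierPrefix${quoteIdentifier(name)}"
    }

    sqlFor("id", Some("t1"))  // "t1.`id`"
    sqlFor("id", None)        // "`id`"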
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d0e5859d2702e..41e8dc0f46746 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -87,6 +87,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       SimplifyConditionals,
       RemoveDispensableExpressions,
       PruneFilters,
+      EliminateSorts,
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
       EliminateSerialization) ::
@@ -825,6 +826,17 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
+/**
+ * Removes no-op SortOrder expressions from Sort; if no sort keys remain, the Sort is removed.
+ */
+object EliminateSorts extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
+      val newOrders = orders.filterNot(_.child.foldable)
+      if (newOrders.isEmpty) child else s.copy(order = newOrders)
+  }
+}
+
 /**
  * Removes filters that can be evaluated trivially.  This can be done through the following ways:
  * 1) by eliding the filter for cases where it will always evaluate to `true`.
@@ -1133,6 +1145,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
             reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
         case FullOuter => f // DO Nothing for Full Outer Join
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+        case UsingJoin(_, _) => sys.error("Untransformed Using join node")
       }
 
     // push down the join filter into sub query scanning if applicable
@@ -1168,6 +1181,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           Join(newLeft, newRight, LeftOuter, newJoinCond)
         case FullOuter => f
         case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
+        case UsingJoin(_, _) => sys.error("Untransformed Using join node")
       }
   }
 }
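To make the effect of the new EliminateSorts rule concrete, here is a small self-contained model of its filtering step (the SortKey type below is purely illustrative, not a Catalyst class):

    // No-op sort keys are the foldable ones: constants such as `1 + 2` never change the ordering.
    case class SortKey(expr: String, foldable: Boolean)

    def eliminate(orders: Seq[SortKey]): Option[Seq[SortKey]] = {
      val newOrders = orders.filterNot(_.foldable)
      if (newOrders.isEmpty) None   // drop the Sort node entirely, keep its child
      else Some(newOrders)          // keep the Sort with only the surviving keys
    }

    eliminate(Seq(SortKey("1 + 2", foldable = true), SortKey("a", foldable = false)))
    // Some(List(SortKey(a,false))) -> ORDER BY a
    eliminate(Seq(SortKey("3", foldable = true)))
    // None -> the Sort is removed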
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
index 7d5a46873c217..c188c5b108491 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
@@ -419,30 +419,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
           sys.error(s"Unsupported join operation: $other")
         }
 
-        val joinType = joinToken match {
-          case "TOK_JOIN" => Inner
-          case "TOK_CROSSJOIN" => Inner
-          case "TOK_RIGHTOUTERJOIN" => RightOuter
-          case "TOK_LEFTOUTERJOIN" => LeftOuter
-          case "TOK_FULLOUTERJOIN" => FullOuter
-          case "TOK_LEFTSEMIJOIN" => LeftSemi
-          case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
-          case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
-          case "TOK_NATURALJOIN" => NaturalJoin(Inner)
-          case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
-          case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
-          case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
-        }
+        val (joinType, joinCondition) = getJoinInfo(joinToken, other, node)
+
         Join(nodeToRelation(relation1),
           nodeToRelation(relation2),
           joinType,
-          other.headOption.map(nodeToExpr))
-
+          joinCondition)
       case _ =>
         noParseRule("Relation", node)
     }
   }
 
+  protected def getJoinInfo(
+      joinToken: String,
+      joinConditionToken: Seq[ASTNode],
+      node: ASTNode): (JoinType, Option[Expression]) = {
+    val joinType = joinToken match {
+      case "TOK_JOIN" => Inner
+      case "TOK_CROSSJOIN" => Inner
+      case "TOK_RIGHTOUTERJOIN" => RightOuter
+      case "TOK_LEFTOUTERJOIN" => LeftOuter
+      case "TOK_FULLOUTERJOIN" => FullOuter
+      case "TOK_LEFTSEMIJOIN" => LeftSemi
+      case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
+      case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
+      case "TOK_NATURALJOIN" => NaturalJoin(Inner)
+      case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
+      case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
+      case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
+    }
+
+    joinConditionToken match {
+      case Token("TOK_USING", columnList :: Nil) :: Nil =>
+        val colNames = columnList.children.collect {
+          case Token(name, Nil) => UnresolvedAttribute(name)
+        }
+        (UsingJoin(joinType, colNames), None)
+      // Join condition specified using an ON clause
+      case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr))
+    }
+  }
+
   protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
     case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
       SortOrder(nodeToExpr(sortExpr), Ascending)
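A minimal standalone model of the dispatch that getJoinInfo now performs (the token types below are simplified stand-ins, not the real ASTNode API):

    // A USING list wraps the base join type in UsingJoin and drops the condition;
    // anything else keeps the base type and treats the token as an ON expression.
    sealed trait CondToken
    case class UsingToken(columns: Seq[String]) extends CondToken
    case class OnToken(expr: String) extends CondToken

    def joinInfo(baseType: String, cond: Option[CondToken]): (String, Option[String]) =
      cond match {
        case Some(UsingToken(cols)) => (s"UsingJoin($baseType, [${cols.mkString(", ")}])", None)
        case Some(OnToken(expr))    => (baseType, Some(expr))
        case None                   => (baseType, None)
      }

    joinInfo("Inner", Some(UsingToken(Seq("c1"))))     // ("UsingJoin(Inner, [c1])", None)
    joinInfo("Inner", Some(OnToken("t1.c1 = t2.c1")))  // ("Inner", Some("t1.c1 = t2.c1"))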
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 681f06ed1ecf3..ada842477116c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.IntegerType
 
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
@@ -79,12 +80,12 @@ object PhysicalOperation extends PredicateHelper {
     expr.transform {
       case a @ Alias(ref: AttributeReference, name) =>
         aliases.get(ref)
-          .map(Alias(_, name)(a.exprId, a.qualifiers, isGenerated = a.isGenerated))
+          .map(Alias(_, name)(a.exprId, a.qualifier, isGenerated = a.isGenerated))
           .getOrElse(a)
 
       case a: AttributeReference =>
         aliases.get(a)
-          .map(Alias(_, a.name)(a.exprId, a.qualifiers, isGenerated = a.isGenerated)).getOrElse(a)
+          .map(Alias(_, a.name)(a.exprId, a.qualifier, isGenerated = a.isGenerated)).getOrElse(a)
     }
   }
 }
@@ -204,20 +205,13 @@ object Unions {
 }
 
 /**
- * A pattern that finds the original expression from a sequence of casts.
+ * Extractor for retrieving an Int value.
  */
-object Casts {
-  def unapply(expr: Expression): Option[Attribute] = expr match {
-    case c: Cast => collectCasts(expr)
+object IntegerIndex {
+  def unapply(a: Any): Option[Int] = a match {
+    case Literal(a: Int, IntegerType) => Some(a)
+    // When resolving ordinals in Sort, negative values are extracted for use in error messages.
+    case UnaryMinus(IntegerLiteral(v)) => Some(-v)
     case _ => None
   }
-
-  @tailrec
-  private def collectCasts(e: Expression): Option[Attribute] = {
-    e match {
-      case e: Cast => collectCasts(e.child)
-      case e: Attribute => Some(e)
-      case _ => None
-    }
-  }
 }
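A quick sketch of the two expression shapes the new IntegerIndex extractor recognizes, written against the patched Catalyst classes:

    import org.apache.spark.sql.catalyst.expressions.{Literal, UnaryMinus}
    import org.apache.spark.sql.catalyst.planning.IntegerIndex

    // An ORDER BY ordinal such as `2` is an integer literal and matches directly.
    val direct = Literal(2) match {
      case IntegerIndex(i) => i              // 2
      case _ => sys.error("unexpected")
    }

    // A negated ordinal such as `-1` also matches, so the analyzer can report it in an error.
    val negated = UnaryMinus(Literal(1)) match {
      case IntegerIndex(i) => i              // -1
      case _ => sys.error("unexpected")
    }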
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index c14dfaf6a1c29..e17c886123288 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.Casts
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.{DataType, StructType}
 
@@ -294,7 +293,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
       // As the root of the expression, Alias will always take an arbitrary exprId, we need
       // to erase that for equality testing.
       val cleanedExprId =
-        Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated)
+        Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated)
       BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true)
     case other =>
       BindReferences.bindReference(other, allAttributes, allowFailures = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 27a75326eba07..9ca4f13dd73cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+
 object JoinType {
   def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
     case "inner" => Inner
@@ -66,3 +68,9 @@ case class NaturalJoin(tpe: JoinType) extends JoinType {
     "Unsupported natural join type " + tpe)
   override def sql: String = "NATURAL " + tpe.sql
 }
+
+case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType {
+  require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe),
+    "Unsupported using join type " + tpe)
+  override def sql: String = "USING " + tpe.sql
+}
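For reference, the new join type can be constructed directly; the analyzer is then responsible for rewriting it into an ordinary equi-join on the listed columns:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.plans.{Inner, UsingJoin}

    // `... JOIN t2 USING (c1, c2)` parses to a UsingJoin with no join condition;
    // the USING columns stay unresolved until analysis.
    val joinType = UsingJoin(Inner, Seq(UnresolvedAttribute("c1"), UnresolvedAttribute("c2")))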
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 0e02ad6057d1a..01c1fa40dcfbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -177,7 +177,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
       resolver: Resolver,
       attribute: Attribute): Option[(Attribute, List[String])] = {
     assert(nameParts.length > 1)
-    if (attribute.qualifiers.exists(resolver(_, nameParts.head))) {
+    if (attribute.qualifier.exists(resolver(_, nameParts.head))) {
       // At least one qualifier matches. See if remaining parts match.
       val remainingParts = nameParts.tail
       resolveAsColumn(remainingParts, resolver, attribute)
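A tiny illustration of the single-qualifier lookup above (the case-insensitive resolver is only an example):

    // Resolving a reference like "t.a": the attribute's optional qualifier must match the
    // first name part before the remaining parts are checked against the attribute itself.
    val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)
    Some("t").exists(resolver(_, "T"))                // true with a case-insensitive resolver
    (None: Option[String]).exists(resolver(_, "t"))   // false: an unqualified attribute never matches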
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09ea3fea6a694..09c200fa839c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -94,7 +94,7 @@ case class Generate(
   def output: Seq[Attribute] = {
     val qualified = qualifier.map(q =>
-      // prepend the new qualifier to the existed one
-      generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers))
+      // replace the existing qualifier with the new one
+      generatorOutput.map(a => a.withQualifier(Some(q)))
     ).getOrElse(generatorOutput)
 
     if (join) child.output ++ qualified else qualified
@@ -298,10 +298,11 @@ case class Join(
       condition.forall(_.dataType == BooleanType)
   }
 
-  // if not a natural join, use `resolvedExceptNatural`. if it is a natural join, we still need
-  // to eliminate natural before we mark it resolved.
+  // If this is not a natural or using join, defer to `resolvedExceptNatural`. Otherwise the
+  // natural/using join still has to be eliminated before the plan can be marked resolved.
   override lazy val resolved: Boolean = joinType match {
     case NaturalJoin(_) => false
+    case UsingJoin(_, _) => false
     case _ => resolvedExceptNatural
   }
 }
@@ -614,7 +615,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
 
 case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode {
 
-  override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
+  override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias)))
 }
 
 /**
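As a small illustration of the SubqueryAlias change above, aliasing now stamps a single optional qualifier on each output attribute (a sketch against the patched API, using a hypothetical attribute):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    // Before: a.withQualifiers("t" :: Nil) gave qualifiers = Seq("t").
    // After:  a.withQualifier(Some("t")) gives qualifier = Some("t").
    val a = AttributeReference("a", IntegerType)()
    val qualified = a.withQualifier(Some("t"))
    assert(qualified.qualifier == Some("t"))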
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 1b297525bdafb..c87a2e24bdb48 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -189,6 +189,11 @@ class AnalysisErrorSuite extends AnalysisTest {
       .orderBy('havingCondition.asc),
     "cannot resolve" :: "havingCondition" :: Nil)
 
+  errorTest(
+    "unresolved star expansion in max",
+    testRelation2.groupBy('a)(sum(UnresolvedStar(None))),
+    "Invalid usage of '*'" :: "in expression 'sum'" :: Nil)
+
   errorTest(
     "bad casts",
     testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index ef825e606202f..39166c4f8ef73 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.plans.logical._
 trait AnalysisTest extends PlanTest {
 
   val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = {
-    val caseSensitiveConf = new SimpleCatalystConf(true)
-    val caseInsensitiveConf = new SimpleCatalystConf(false)
+    val caseSensitiveConf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
+    val caseInsensitiveConf = new SimpleCatalystConf(caseSensitiveAnalysis = false)
 
     val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
     val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index b2613e4909288..9aa685e1e8f55 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
 
 
 class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
-  val conf = new SimpleCatalystConf(true)
+  val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
   val catalog = new SimpleCatalog(conf)
   val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
index fcf4ac1967a53..1423a8705af27 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
@@ -35,56 +36,81 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
   lazy val r3 = LocalRelation(aNotNull, bNotNull)
   lazy val r4 = LocalRelation(cNotNull, bNotNull)
 
-  test("natural inner join") {
-    val plan = r1.join(r2, NaturalJoin(Inner), None)
+  test("natural/using inner join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(Inner), None)
+    val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural left join") {
-    val plan = r1.join(r2, NaturalJoin(LeftOuter), None)
+  test("natural/using left join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(LeftOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural right join") {
-    val plan = r1.join(r2, NaturalJoin(RightOuter), None)
+  test("natural/using right join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(RightOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(RightOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural full outer join") {
-    val plan = r1.join(r2, NaturalJoin(FullOuter), None)
+  test("natural/using full outer join") {
+    val naturalPlan = r1.join(r2, NaturalJoin(FullOuter), None)
+    val usingPlan = r1.join(r2, UsingJoin(FullOuter, Seq(UnresolvedAttribute("a"))), None)
     val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select(
       Alias(Coalesce(Seq(a, a)), "a")(), b, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural inner join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(Inner), None)
+  test("natural/using inner join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(Inner), None)
+    val usingPlan = r3.join(r4, UsingJoin(Inner, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, aNotNull, cNotNull)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural left join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(LeftOuter), None)
+  test("natural/using left join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(LeftOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(LeftOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, aNotNull, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural right join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(RightOuter), None)
+  test("natural/using right join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(RightOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(RightOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       bNotNull, a, cNotNull)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
   }
 
-  test("natural full outer join with no nullability") {
-    val plan = r3.join(r4, NaturalJoin(FullOuter), None)
+  test("natural/using full outer join with no nullability") {
+    val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
+    val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
       Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
-    checkAnalysis(plan, expected)
+    checkAnalysis(naturalPlan, expected)
+    checkAnalysis(usingPlan, expected)
+  }
+
+  test("using unresolved attribute") {
+    val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("d"))), None)
+    val error = intercept[AnalysisException] {
+      SimpleAnalyzer.checkAnalysis(usingPlan)
+    }
+    assert(error.message.contains(
+      "using columns ['d] can not be resolved given input columns: [b, a, c]"))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 5d688e2fe4412..90e97d718a9fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -25,7 +25,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
     val a: AttributeReference = AttributeReference("name", IntegerType)()
     val b1 = a.withName("name2").withExprId(id)
     val b2 = a.withExprId(id)
-    val b3 = a.withQualifiers("qualifierName" :: Nil)
+    val b3 = a.withQualifier(Some("qualifierName"))
 
     assert(b1 != b2)
     assert(a != b1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index da43751b0a310..47b79fe462457 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -110,7 +110,10 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
   }
 
   private val caseInsensitiveAnalyzer =
-    new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false))
+    new Analyzer(
+      EmptyCatalog,
+      EmptyFunctionRegistry,
+      new SimpleCatalystConf(caseSensitiveAnalysis = false))
 
   test("(a && b) || (a && c) => a && (b || c) when case insensitive") {
     val plan = caseInsensitiveAnalyzer.execute(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
new file mode 100644
index 0000000000000..a4c8d1c6d2aa8
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry, SimpleCatalog}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class EliminateSortsSuite extends PlanTest {
+  val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false)
+  val catalog = new SimpleCatalog(conf)
+  val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Eliminate Sorts", Once,
+        EliminateSorts) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+  test("Empty order by clause") {
+    val x = testRelation
+
+    val query = x.orderBy()
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = x.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("All the SortOrder are no-op") {
+    val x = testRelation
+
+    val query = x.orderBy(SortOrder(3, Ascending), SortOrder(-1, Ascending))
+    val optimized = Optimize.execute(analyzer.execute(query))
+    val correctAnswer = analyzer.execute(x)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Partial order-by clauses contain no-op SortOrder") {
+    val x = testRelation
+
+    val query = x.orderBy(SortOrder(3, Ascending), 'a.asc)
+    val optimized = Optimize.execute(analyzer.execute(query))
+    val correctAnswer = analyzer.execute(x.orderBy('a.asc))
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
index 048b4f12b9edf..c068e895b6643 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
@@ -219,4 +219,25 @@ class CatalystQlSuite extends PlanTest {
     parser.parsePlan("select * from t where a = (select b from s)")
     parser.parsePlan("select * from t group by g having a > (select b from s)")
   }
+
+  test("using clause in JOIN") {
+    // Tests parsing of using clause for different join types.
+    parser.parsePlan("select * from t1 join t2 using (c1)")
+    parser.parsePlan("select * from t1 join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 left join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 right join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 full outer join t2 using (c1, c2)")
+    parser.parsePlan("select * from t1 join t2 using (c1) join t3 using (c2)")
+    // Tests errors
+    // (1) Empty using clause
+    // (2) Qualified columns in using
+    // (3) Both on and using clause
+    var error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using ()"))
+    assert(error.message.contains("cannot recognize input near ')'"))
+    error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using (t1.c1)"))
+    assert(error.message.contains("mismatched input '.'"))
+    error = intercept[AnalysisException](parser.parsePlan("select * from t1" +
+      " join t2 using (c1) on t1.c1 = t2.c1"))
+    assert(error.message.contains("missing EOF at 'on' near ')'"))
+  }
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 57e8218f3b93a..acf6c583bbb58 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -236,8 +236,8 @@ public void printPerfMetrics() {
   /**
    * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]]
    *
-   * Note that the map will be reset for inserting new records, and the returned sorter can NOT be used
-   * to insert records.
+   * Note that the map will be reset for inserting new records, and the returned sorter can NOT be
+   * used to insert records.
    */
   public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
     return new UnsafeKVExternalSorter(
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 51e10b0e936b9..9e08675c3e669 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -198,7 +198,7 @@ private static final class KVComparator extends RecordComparator {
     private final UnsafeRow row2;
     private final int numKeyFields;
 
-    public KVComparator(BaseOrdering ordering, int numKeyFields) {
+    KVComparator(BaseOrdering ordering, int numKeyFields) {
       this.numKeyFields = numKeyFields;
       this.row1 = new UnsafeRow(numKeyFields);
       this.row2 = new UnsafeRow(numKeyFields);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
deleted file mode 100644
index 7234726633c36..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ /dev/null
@@ -1,946 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.parquet;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.List;
-
-import org.apache.commons.lang.NotImplementedException;
-import org.apache.hadoop.mapreduce.InputSplit;
-import org.apache.hadoop.mapreduce.TaskAttemptContext;
-import org.apache.parquet.Preconditions;
-import org.apache.parquet.bytes.BytesUtils;
-import org.apache.parquet.column.ColumnDescriptor;
-import org.apache.parquet.column.Dictionary;
-import org.apache.parquet.column.Encoding;
-import org.apache.parquet.column.page.*;
-import org.apache.parquet.column.values.ValuesReader;
-import org.apache.parquet.io.api.Binary;
-import org.apache.parquet.schema.OriginalType;
-import org.apache.parquet.schema.PrimitiveType;
-import org.apache.parquet.schema.Type;
-
-import org.apache.spark.memory.MemoryMode;
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
-import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
-import org.apache.spark.sql.execution.vectorized.ColumnVector;
-import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.DecimalType;
-import org.apache.spark.unsafe.types.UTF8String;
-
-import static org.apache.parquet.column.ValuesType.*;
-
-/**
- * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs.
- *
- * This is somewhat based on parquet-mr's ColumnReader.
- *
- * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch.
- * All of these can be handled efficiently and easily with codegen.
- *
- * This class can either return InternalRows or ColumnarBatches. With whole stage codegen
- * enabled, this class returns ColumnarBatches which offers significant performance gains.
- * TODO: make this always return ColumnarBatches.
- */
-public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBase {
-  /**
-   * Batch of unsafe rows that we assemble and the current index we've returned. Every time this
-   * batch is used up (batchIdx == numBatched), we populated the batch.
-   */
-  private UnsafeRow[] rows = new UnsafeRow[64];
-  private int batchIdx = 0;
-  private int numBatched = 0;
-
-  /**
-   * Used to write variable length columns. Same length as `rows`.
-   */
-  private UnsafeRowWriter[] rowWriters = null;
-  /**
-   * True if the row contains variable length fields.
-   */
-  private boolean containsVarLenFields;
-
-  /**
-   * For each request column, the reader to read this column.
-   * columnsReaders[i] populated the UnsafeRow's attribute at i.
-   */
-  private ColumnReader[] columnReaders;
-
-  /**
-   * The number of rows that have been returned.
-   */
-  private long rowsReturned;
-
-  /**
-   * The number of rows that have been reading, including the current in flight row group.
-   */
-  private long totalCountLoadedSoFar = 0;
-
-  /**
-   * For each column, the annotated original type.
-   */
-  private OriginalType[] originalTypes;
-
-  /**
-   * The default size for varlen columns. The row grows as necessary to accommodate the
-   * largest column.
-   */
-  private static final int DEFAULT_VAR_LEN_SIZE = 32;
-
-  /**
-   * columnBatch object that is used for batch decoding. This is created on first use and triggers
-   * batched decoding. It is not valid to interleave calls to the batched interface with the row
-   * by row RecordReader APIs.
-   * This is only enabled with additional flags for development. This is still a work in progress
-   * and currently unsupported cases will fail with potentially difficult to diagnose errors.
-   * This should be only turned on for development to work on this feature.
-   *
-   * When this is set, the code will branch early on in the RecordReader APIs. There is no shared
-   * code between the path that uses the MR decoders and the vectorized ones.
-   *
-   * TODOs:
-   *  - Implement v2 page formats (just make sure we create the correct decoders).
-   */
-  private ColumnarBatch columnarBatch;
-
-  /**
-   * If true, this class returns batches instead of rows.
-   */
-  private boolean returnColumnarBatch;
-
-  /**
-   * The default config on whether columnarBatch should be offheap.
-   */
-  private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP;
-
-  /**
-   * Tries to initialize the reader for this split. Returns true if this reader supports reading
-   * this split and false otherwise.
-   */
-  public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) {
-    try {
-      initialize(inputSplit, taskAttemptContext);
-      return true;
-    } catch (Exception e) {
-      return false;
-    }
-  }
-
-  /**
-   * Implementation of RecordReader API.
-   */
-  @Override
-  public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
-      throws IOException, InterruptedException {
-    super.initialize(inputSplit, taskAttemptContext);
-    initializeInternal();
-  }
-
-  /**
-   * Utility API that will read all the data in path. This circumvents the need to create Hadoop
-   * objects to use this class. `columns` can contain the list of columns to project.
-   */
-  @Override
-  public void initialize(String path, List columns) throws IOException {
-    super.initialize(path, columns);
-    initializeInternal();
-  }
-
-  @Override
-  public void close() throws IOException {
-    if (columnarBatch != null) {
-      columnarBatch.close();
-      columnarBatch = null;
-    }
-    super.close();
-  }
-
-  @Override
-  public boolean nextKeyValue() throws IOException, InterruptedException {
-    if (returnColumnarBatch) return nextBatch();
-
-    if (batchIdx >= numBatched) {
-      if (vectorizedDecode()) {
-        if (!nextBatch()) return false;
-      } else {
-        if (!loadBatch()) return false;
-      }
-    }
-    ++batchIdx;
-    return true;
-  }
-
-  @Override
-  public Object getCurrentValue() throws IOException, InterruptedException {
-    if (returnColumnarBatch) return columnarBatch;
-
-    if (vectorizedDecode()) {
-      return columnarBatch.getRow(batchIdx - 1);
-    } else {
-      return rows[batchIdx - 1];
-    }
-  }
-
-  @Override
-  public float getProgress() throws IOException, InterruptedException {
-    return (float) rowsReturned / totalRowCount;
-  }
-
-  /**
-   * Returns the ColumnarBatch object that will be used for all rows returned by this reader.
-   * This object is reused. Calling this enables the vectorized reader. This should be called
-   * before any calls to nextKeyValue/nextBatch.
-   */
-  public ColumnarBatch resultBatch() {
-    return resultBatch(DEFAULT_MEMORY_MODE);
-  }
-
-  public ColumnarBatch resultBatch(MemoryMode memMode) {
-    if (columnarBatch == null) {
-      columnarBatch = ColumnarBatch.allocate(sparkSchema, memMode);
-    }
-    return columnarBatch;
-  }
-
-  /**
-   * Can be called before any rows are returned to enable returning columnar batches directly.
-   */
-  public void enableReturningBatches() {
-    assert(vectorizedDecode());
-    returnColumnarBatch = true;
-  }
-
-  /**
-   * Advances to the next batch of rows. Returns false if there are no more.
-   */
-  public boolean nextBatch() throws IOException {
-    assert(vectorizedDecode());
-    columnarBatch.reset();
-    if (rowsReturned >= totalRowCount) return false;
-    checkEndOfRowGroup();
-
-    int num = (int)Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned);
-    for (int i = 0; i < columnReaders.length; ++i) {
-      columnReaders[i].readBatch(num, columnarBatch.column(i));
-    }
-    rowsReturned += num;
-    columnarBatch.setNumRows(num);
-    numBatched = num;
-    batchIdx = 0;
-    return true;
-  }
-
-  /**
-   * Returns true if we are doing a vectorized decode.
-   */
-  private boolean vectorizedDecode() { return columnarBatch != null; }
-
-  private void initializeInternal() throws IOException {
-    /**
-     * Check that the requested schema is supported.
-     */
-    int numVarLenFields = 0;
-    originalTypes = new OriginalType[requestedSchema.getFieldCount()];
-    for (int i = 0; i < requestedSchema.getFieldCount(); ++i) {
-      Type t = requestedSchema.getFields().get(i);
-      if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) {
-        throw new IOException("Complex types not supported.");
-      }
-      PrimitiveType primitiveType = t.asPrimitiveType();
-
-      originalTypes[i] = t.getOriginalType();
-
-      // TODO: Be extremely cautious in what is supported. Expand this.
-      if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL &&
-          originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE &&
-          originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) {
-        throw new IOException("Unsupported type: " + t);
-      }
-      if (originalTypes[i] == OriginalType.DECIMAL &&
-          primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) {
-        throw new IOException("Decimal with high precision is not supported.");
-      }
-      if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) {
-        throw new IOException("Int96 not supported.");
-      }
-      ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i));
-      if (!fd.equals(requestedSchema.getColumns().get(i))) {
-        throw new IOException("Schema evolution not supported.");
-      }
-
-      if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.BINARY) {
-        ++numVarLenFields;
-      }
-    }
-
-    /**
-     * Initialize rows and rowWriters. These objects are reused across all rows in the relation.
-     */
-    containsVarLenFields = numVarLenFields > 0;
-    rowWriters = new UnsafeRowWriter[rows.length];
-
-    for (int i = 0; i < rows.length; ++i) {
-      rows[i] = new UnsafeRow(requestedSchema.getFieldCount());
-      BufferHolder holder = new BufferHolder(rows[i], numVarLenFields * DEFAULT_VAR_LEN_SIZE);
-      rowWriters[i] = new UnsafeRowWriter(holder, requestedSchema.getFieldCount());
-    }
-  }
-
-  /**
-   * Decodes a batch of values into `rows`. This function is the hot path.
-   */
-  private boolean loadBatch() throws IOException {
-    // no more records left
-    if (rowsReturned >= totalRowCount) { return false; }
-    checkEndOfRowGroup();
-
-    int num = (int)Math.min(rows.length, totalCountLoadedSoFar - rowsReturned);
-    rowsReturned += num;
-
-    if (containsVarLenFields) {
-      for (int i = 0; i < rowWriters.length; ++i) {
-        rowWriters[i].holder().reset();
-      }
-    }
-
-    for (int i = 0; i < columnReaders.length; ++i) {
-      switch (columnReaders[i].descriptor.getType()) {
-        case BOOLEAN:
-          decodeBooleanBatch(i, num);
-          break;
-        case INT32:
-          if (originalTypes[i] == OriginalType.DECIMAL) {
-            decodeIntAsDecimalBatch(i, num);
-          } else {
-            decodeIntBatch(i, num);
-          }
-          break;
-        case INT64:
-          Preconditions.checkState(originalTypes[i] == null
-              || originalTypes[i] == OriginalType.DECIMAL,
-              "Unexpected original type: " + originalTypes[i]);
-          decodeLongBatch(i, num);
-          break;
-        case FLOAT:
-          decodeFloatBatch(i, num);
-          break;
-        case DOUBLE:
-          decodeDoubleBatch(i, num);
-          break;
-        case BINARY:
-          decodeBinaryBatch(i, num);
-          break;
-        case FIXED_LEN_BYTE_ARRAY:
-          Preconditions.checkState(originalTypes[i] == OriginalType.DECIMAL,
-              "Unexpected original type: " + originalTypes[i]);
-          decodeFixedLenArrayAsDecimalBatch(i, num);
-          break;
-        case INT96:
-          throw new IOException("Unsupported " + columnReaders[i].descriptor.getType());
-      }
-    }
-
-    numBatched = num;
-    batchIdx = 0;
-
-    // Update the total row lengths if the schema contained variable length. We did not maintain
-    // this as we populated the columns.
-    if (containsVarLenFields) {
-      for (int i = 0; i < numBatched; ++i) {
-        rows[i].setTotalSize(rowWriters[i].holder().totalSize());
-      }
-    }
-
-    return true;
-  }
-
-  private void decodeBooleanBatch(int col, int num) throws IOException {
-    for (int n = 0; n < num; ++n) {
-      if (columnReaders[col].next()) {
-        rows[n].setBoolean(col, columnReaders[col].nextBoolean());
-      } else {
-        rows[n].setNullAt(col);
-      }
-    }
-  }
-
-  private void decodeIntBatch(int col, int num) throws IOException {
-    for (int n = 0; n < num; ++n) {
-      if (columnReaders[col].next()) {
-        rows[n].setInt(col, columnReaders[col].nextInt());
-      } else {
-        rows[n].setNullAt(col);
-      }
-    }
-  }
-
-  private void decodeIntAsDecimalBatch(int col, int num) throws IOException {
-    for (int n = 0; n < num; ++n) {
-      if (columnReaders[col].next()) {
-        // Since this is stored as an INT, it is always a compact decimal. Just set it as a long.
-        rows[n].setLong(col, columnReaders[col].nextInt());
-      } else {
-        rows[n].setNullAt(col);
-      }
-    }
-  }
-
-  private void decodeLongBatch(int col, int num) throws IOException {
-    for (int n = 0; n < num; ++n) {
-      if (columnReaders[col].next()) {
-        rows[n].setLong(col, columnReaders[col].nextLong());
-      } else {
-        rows[n].setNullAt(col);
-      }
-    }
-  }
-
-  private void decodeFloatBatch(int col, int num) throws IOException {
-    for (int n = 0; n < num; ++n) {
-      if (columnReaders[col].next()) {
-        rows[n].setFloat(col, columnReaders[col].nextFloat());
-      } else {
-        rows[n].setNullAt(col);
-      }
-    }
-  }
-
-  private void decodeDoubleBatch(int col, int num) throws IOException {
-    for (int n = 0; n < num; ++n) {
-      if (columnReaders[col].next()) {
-        rows[n].setDouble(col, columnReaders[col].nextDouble());
-      } else {
-        rows[n].setNullAt(col);
-      }
-    }
-  }
-
-  private void decodeBinaryBatch(int col, int num) throws IOException {
-    for (int n = 0; n < num; ++n) {
-      if (columnReaders[col].next()) {
-        ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer();
-        int len = bytes.remaining();
-        if (originalTypes[col] == OriginalType.UTF8) {
-          UTF8String str =
-              UTF8String.fromBytes(bytes.array(), bytes.arrayOffset() + bytes.position(), len);
-          rowWriters[n].write(col, str);
-        } else {
-          rowWriters[n].write(col, bytes.array(), bytes.arrayOffset() + bytes.position(), len);
-        }
-        rows[n].setNotNullAt(col);
-      } else {
-        rows[n].setNullAt(col);
-      }
-    }
-  }
-
-  private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException {
-    PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType();
-    int precision = type.getDecimalMetadata().getPrecision();
-    int scale = type.getDecimalMetadata().getScale();
-    Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(),
-        "Unsupported precision.");
-
-    for (int n = 0; n < num; ++n) {
-      if (columnReaders[col].next()) {
-        Binary v = columnReaders[col].nextBinary();
-        // Constructs a `Decimal` with an unscaled `Long` value if possible.
-        long unscaled = CatalystRowConverter.binaryToUnscaledLong(v);
-        rows[n].setDecimal(col, Decimal.apply(unscaled, precision, scale), precision);
-      } else {
-        rows[n].setNullAt(col);
-      }
-    }
-  }
-
-  /**
-   *
-   * Decoder to return values from a single column.
-   */
-  private final class ColumnReader {
-    /**
-     * Total number of values read.
-     */
-    private long valuesRead;
-
-    /**
-     * value that indicates the end of the current page. That is,
-     * if valuesRead == endOfPageValueCount, we are at the end of the page.
-     */
-    private long endOfPageValueCount;
-
-    /**
-     * The dictionary, if this column has dictionary encoding.
-     */
-    private final Dictionary dictionary;
-
-    /**
-     * If true, the current page is dictionary encoded.
-     */
-    private boolean useDictionary;
-
-    /**
-     * Maximum definition level for this column.
-     */
-    private final int maxDefLevel;
-
-    /**
-     * Repetition/Definition/Value readers.
-     */
-    private IntIterator repetitionLevelColumn;
-    private IntIterator definitionLevelColumn;
-    private ValuesReader dataColumn;
-
-    // Only set if vectorized decoding is true. This is used instead of the row by row decoding
-    // with `definitionLevelColumn`.
-    private VectorizedRleValuesReader defColumn;
-
-    /**
-     * Total number of values in this column (in this row group).
-     */
-    private final long totalValueCount;
-
-    /**
-     * Total values in the current page.
-     */
-    private int pageValueCount;
-
-    private final PageReader pageReader;
-    private final ColumnDescriptor descriptor;
-
-    public ColumnReader(ColumnDescriptor descriptor, PageReader pageReader)
-        throws IOException {
-      this.descriptor = descriptor;
-      this.pageReader = pageReader;
-      this.maxDefLevel = descriptor.getMaxDefinitionLevel();
-
-      DictionaryPage dictionaryPage = pageReader.readDictionaryPage();
-      if (dictionaryPage != null) {
-        try {
-          this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage);
-          this.useDictionary = true;
-        } catch (IOException e) {
-          throw new IOException("could not decode the dictionary for " + descriptor, e);
-        }
-      } else {
-        this.dictionary = null;
-        this.useDictionary = false;
-      }
-      this.totalValueCount = pageReader.getTotalValueCount();
-      if (totalValueCount == 0) {
-        throw new IOException("totalValueCount == 0");
-      }
-    }
-
-    /**
-     * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned.
-     */
-    public boolean nextBoolean() {
-      if (!useDictionary) {
-        return dataColumn.readBoolean();
-      } else {
-        return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId());
-      }
-    }
-
-    public int nextInt() {
-      if (!useDictionary) {
-        return dataColumn.readInteger();
-      } else {
-        return dictionary.decodeToInt(dataColumn.readValueDictionaryId());
-      }
-    }
-
-    public long nextLong() {
-      if (!useDictionary) {
-        return dataColumn.readLong();
-      } else {
-        return dictionary.decodeToLong(dataColumn.readValueDictionaryId());
-      }
-    }
-
-    public float nextFloat() {
-      if (!useDictionary) {
-        return dataColumn.readFloat();
-      } else {
-        return dictionary.decodeToFloat(dataColumn.readValueDictionaryId());
-      }
-    }
-
-    public double nextDouble() {
-      if (!useDictionary) {
-        return dataColumn.readDouble();
-      } else {
-        return dictionary.decodeToDouble(dataColumn.readValueDictionaryId());
-      }
-    }
-
-    public Binary nextBinary() {
-      if (!useDictionary) {
-        return dataColumn.readBytes();
-      } else {
-        return dictionary.decodeToBinary(dataColumn.readValueDictionaryId());
-      }
-    }
-
-    /**
-     * Advances to the next value. Returns true if the value is non-null.
-     */
-    private boolean next() throws IOException {
-      if (valuesRead >= endOfPageValueCount) {
-        if (valuesRead >= totalValueCount) {
-          // How do we get here? Throw end of stream exception?
-          return false;
-        }
-        readPage();
-      }
-      ++valuesRead;
-      // TODO: Don't read for flat schemas
-      //repetitionLevel = repetitionLevelColumn.nextInt();
-      return definitionLevelColumn.nextInt() == maxDefLevel;
-    }
-
-    /**
-     * Reads `total` values from this columnReader into column.
-     */
-    private void readBatch(int total, ColumnVector column) throws IOException {
-      int rowId = 0;
-      while (total > 0) {
-        // Compute the number of values we want to read in this page.
-        int leftInPage = (int)(endOfPageValueCount - valuesRead);
-        if (leftInPage == 0) {
-          readPage();
-          leftInPage = (int)(endOfPageValueCount - valuesRead);
-        }
-        int num = Math.min(total, leftInPage);
-        if (useDictionary) {
-          // Read and decode dictionary ids.
-          ColumnVector dictionaryIds = column.reserveDictionaryIds(total);;
-          defColumn.readIntegers(
-              num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-          decodeDictionaryIds(rowId, num, column, dictionaryIds);
-        } else {
-          column.setDictionary(null);
-          switch (descriptor.getType()) {
-            case BOOLEAN:
-              readBooleanBatch(rowId, num, column);
-              break;
-            case INT32:
-              readIntBatch(rowId, num, column);
-              break;
-            case INT64:
-              readLongBatch(rowId, num, column);
-              break;
-            case FLOAT:
-              readFloatBatch(rowId, num, column);
-              break;
-            case DOUBLE:
-              readDoubleBatch(rowId, num, column);
-              break;
-            case BINARY:
-              readBinaryBatch(rowId, num, column);
-              break;
-            case FIXED_LEN_BYTE_ARRAY:
-              readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength());
-              break;
-            default:
-              throw new IOException("Unsupported type: " + descriptor.getType());
-          }
-        }
-
-        valuesRead += num;
-        rowId += num;
-        total -= num;
-      }
-    }
-
-    /**
-     * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
-     */
-    private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
-                                     ColumnVector dictionaryIds) {
-      switch (descriptor.getType()) {
-        case INT32:
-        case INT64:
-        case FLOAT:
-        case DOUBLE:
-        case BINARY:
-          column.setDictionary(dictionary);
-          break;
-
-        case FIXED_LEN_BYTE_ARRAY:
-          // DecimalType written in the legacy mode
-          if (DecimalType.is32BitDecimalType(column.dataType())) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
-              column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v));
-            }
-          } else if (DecimalType.is64BitDecimalType(column.dataType())) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
-              column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
-            }
-          } else {
-            throw new NotImplementedException();
-          }
-          break;
-
-        default:
-          throw new NotImplementedException("Unsupported type: " + descriptor.getType());
-      }
-    }
-
-    /**
-     * For all the read*Batch functions, reads `num` values from this columnReader into column. It
-     * is guaranteed that num is smaller than the number of values left in the current page.
-     */
-
-    private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
-      assert(column.dataType() == DataTypes.BooleanType);
-      defColumn.readBooleans(
-          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-    }
-
-    private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
-      // This is where we implement support for the valid type conversions.
-      // TODO: implement remaining type conversions
-      if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
-        DecimalType.is32BitDecimalType(column.dataType())) {
-        defColumn.readIntegers(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else if (column.dataType() == DataTypes.ByteType) {
-        defColumn.readBytes(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else if (column.dataType() == DataTypes.ShortType) {
-        defColumn.readShorts(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else {
-        throw new NotImplementedException("Unimplemented type: " + column.dataType());
-      }
-    }
-
-    private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
-      // This is where we implement support for the valid type conversions.
-      if (column.dataType() == DataTypes.LongType ||
-          DecimalType.is64BitDecimalType(column.dataType())) {
-        defColumn.readLongs(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else {
-        throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
-      }
-    }
-
-    private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException {
-      // This is where we implement support for the valid type conversions.
-      // TODO: support implicit cast to double?
-      if (column.dataType() == DataTypes.FloatType) {
-        defColumn.readFloats(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else {
-        throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
-      }
-    }
-
-    private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException {
-      // This is where we implement support for the valid type conversions.
-      // TODO: implement remaining type conversions
-      if (column.dataType() == DataTypes.DoubleType) {
-        defColumn.readDoubles(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else {
-        throw new NotImplementedException("Unimplemented type: " + column.dataType());
-      }
-    }
-
-    private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
-      // This is where we implement support for the valid type conversions.
-      // TODO: implement remaining type conversions
-      if (column.isArray()) {
-        defColumn.readBinarys(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else {
-        throw new NotImplementedException("Unimplemented type: " + column.dataType());
-      }
-    }
-
-    private void readFixedLenByteArrayBatch(int rowId, int num,
-                                            ColumnVector column, int arrayLen) throws IOException {
-      VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
-      // This is where we implement support for the valid type conversions.
-      // TODO: implement remaining type conversions
-      if (DecimalType.is32BitDecimalType(column.dataType())) {
-        for (int i = 0; i < num; i++) {
-          if (defColumn.readInteger() == maxDefLevel) {
-            column.putInt(rowId + i,
-              (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
-          } else {
-            column.putNull(rowId + i);
-          }
-        }
-      } else if (DecimalType.is64BitDecimalType(column.dataType())) {
-        for (int i = 0; i < num; i++) {
-          if (defColumn.readInteger() == maxDefLevel) {
-            column.putLong(rowId + i,
-                CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
-          } else {
-            column.putNull(rowId + i);
-          }
-        }
-      } else {
-        throw new NotImplementedException("Unimplemented type: " + column.dataType());
-      }
-    }
-
-    private void readPage() throws IOException {
-      DataPage page = pageReader.readPage();
-      // TODO: Why is this a visitor?
-      page.accept(new DataPage.Visitor() {
-        @Override
-        public Void visit(DataPageV1 dataPageV1) {
-          try {
-            readPageV1(dataPageV1);
-            return null;
-          } catch (IOException e) {
-            throw new RuntimeException(e);
-          }
-        }
-
-        @Override
-        public Void visit(DataPageV2 dataPageV2) {
-          try {
-            readPageV2(dataPageV2);
-            return null;
-          } catch (IOException e) {
-            throw new RuntimeException(e);
-          }
-        }
-      });
-    }
-
-    private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset)throws IOException {
-      this.endOfPageValueCount = valuesRead + pageValueCount;
-      if (dataEncoding.usesDictionary()) {
-        this.dataColumn = null;
-        if (dictionary == null) {
-          throw new IOException(
-              "could not read page in col " + descriptor +
-                  " as the dictionary was missing for encoding " + dataEncoding);
-        }
-        if (vectorizedDecode()) {
-          @SuppressWarnings("deprecation")
-          Encoding plainDict = Encoding.PLAIN_DICTIONARY; // var to allow warning suppression
-          if (dataEncoding != plainDict && dataEncoding != Encoding.RLE_DICTIONARY) {
-            throw new NotImplementedException("Unsupported encoding: " + dataEncoding);
-          }
-          this.dataColumn = new VectorizedRleValuesReader();
-        } else {
-          this.dataColumn = dataEncoding.getDictionaryBasedValuesReader(
-              descriptor, VALUES, dictionary);
-        }
-        this.useDictionary = true;
-      } else {
-        if (vectorizedDecode()) {
-          if (dataEncoding != Encoding.PLAIN) {
-            throw new NotImplementedException("Unsupported encoding: " + dataEncoding);
-          }
-          this.dataColumn = new VectorizedPlainValuesReader();
-        } else {
-          this.dataColumn = dataEncoding.getValuesReader(descriptor, VALUES);
-        }
-        this.useDictionary = false;
-      }
-
-      try {
-        dataColumn.initFromPage(pageValueCount, bytes, offset);
-      } catch (IOException e) {
-        throw new IOException("could not read page in col " + descriptor, e);
-      }
-    }
-
-    private void readPageV1(DataPageV1 page) throws IOException {
-      this.pageValueCount = page.getValueCount();
-      ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL);
-      ValuesReader dlReader;
-
-      // Initialize the decoders.
-      if (vectorizedDecode()) {
-        if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) {
-          throw new NotImplementedException("Unsupported encoding: " + page.getDlEncoding());
-        }
-        int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
-        this.defColumn = new VectorizedRleValuesReader(bitWidth);
-        dlReader = this.defColumn;
-      } else {
-        dlReader = page.getDlEncoding().getValuesReader(descriptor, DEFINITION_LEVEL);
-      }
-      this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader);
-      this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader);
-      try {
-        byte[] bytes = page.getBytes().toByteArray();
-        rlReader.initFromPage(pageValueCount, bytes, 0);
-        int next = rlReader.getNextOffset();
-        dlReader.initFromPage(pageValueCount, bytes, next);
-        next = dlReader.getNextOffset();
-        initDataReader(page.getValueEncoding(), bytes, next);
-      } catch (IOException e) {
-        throw new IOException("could not read page " + page + " in col " + descriptor, e);
-      }
-    }
-
-    private void readPageV2(DataPageV2 page) throws IOException {
-      this.pageValueCount = page.getValueCount();
-      this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(),
-          page.getRepetitionLevels(), descriptor);
-
-      if (vectorizedDecode()) {
-        int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
-        this.defColumn = new VectorizedRleValuesReader(bitWidth);
-        this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn);
-        this.defColumn.initFromBuffer(
-            this.pageValueCount, page.getDefinitionLevels().toByteArray());
-      } else {
-        this.definitionLevelColumn = createRLEIterator(descriptor.getMaxDefinitionLevel(),
-            page.getDefinitionLevels(), descriptor);
-      }
-      try {
-        initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0);
-      } catch (IOException e) {
-        throw new IOException("could not read page " + page + " in col " + descriptor, e);
-      }
-    }
-  }
-
-  private void checkEndOfRowGroup() throws IOException {
-    if (rowsReturned != totalCountLoadedSoFar) return;
-    PageReadStore pages = reader.readNextRowGroup();
-    if (pages == null) {
-      throw new IOException("expecting more rows but reached last block. Read "
-          + rowsReturned + " out of " + totalRowCount);
-    }
-    List columns = requestedSchema.getColumns();
-    columnReaders = new ColumnReader[columns.size()];
-    for (int i = 0; i < columns.size(); ++i) {
-      columnReaders[i] = new ColumnReader(columns.get(i), pages.getPageReader(columns.get(i)));
-    }
-    totalCountLoadedSoFar += pages.getRowCount();
-  }
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
new file mode 100644
index 0000000000000..46c84c5dd4d57
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -0,0 +1,475 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet;
+
+import java.io.IOException;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.parquet.bytes.BytesUtils;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.column.Dictionary;
+import org.apache.parquet.column.Encoding;
+import org.apache.parquet.column.page.*;
+import org.apache.parquet.column.values.ValuesReader;
+import org.apache.parquet.io.api.Binary;
+
+import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+
+import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL;
+import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ValuesReaderIntIterator;
+import static org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.createRLEIterator;
+
+/**
+ * Decoder to return values from a single column.
+ */
+public class VectorizedColumnReader {
+  /**
+   * Total number of values read.
+   */
+  private long valuesRead;
+
+  /**
+   * Value that indicates the end of the current page. That is,
+   * if valuesRead == endOfPageValueCount, we are at the end of the page.
+   */
+  private long endOfPageValueCount;
+
+  /**
+   * The dictionary, if this column has dictionary encoding.
+   */
+  private final Dictionary dictionary;
+
+  /**
+   * If true, the current page is dictionary encoded.
+   */
+  private boolean useDictionary;
+
+  /**
+   * Maximum definition level for this column.
+   */
+  private final int maxDefLevel;
+
+  /**
+   * Repetition/Definition/Value readers.
+   */
+  private SpecificParquetRecordReaderBase.IntIterator repetitionLevelColumn;
+  private SpecificParquetRecordReaderBase.IntIterator definitionLevelColumn;
+  private ValuesReader dataColumn;
+
+  // Only set if vectorized decoding is true. This is used instead of the row by row decoding
+  // with `definitionLevelColumn`.
+  private VectorizedRleValuesReader defColumn;
+
+  /**
+   * Total number of values in this column (in this row group).
+   */
+  private final long totalValueCount;
+
+  /**
+   * Total values in the current page.
+   */
+  private int pageValueCount;
+
+  private final PageReader pageReader;
+  private final ColumnDescriptor descriptor;
+
+  public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader)
+      throws IOException {
+    this.descriptor = descriptor;
+    this.pageReader = pageReader;
+    this.maxDefLevel = descriptor.getMaxDefinitionLevel();
+
+    DictionaryPage dictionaryPage = pageReader.readDictionaryPage();
+    if (dictionaryPage != null) {
+      try {
+        this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage);
+        this.useDictionary = true;
+      } catch (IOException e) {
+        throw new IOException("could not decode the dictionary for " + descriptor, e);
+      }
+    } else {
+      this.dictionary = null;
+      this.useDictionary = false;
+    }
+    this.totalValueCount = pageReader.getTotalValueCount();
+    if (totalValueCount == 0) {
+      throw new IOException("totalValueCount == 0");
+    }
+  }
+
+  /**
+   * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned.
+   */
+  public boolean nextBoolean() {
+    if (!useDictionary) {
+      return dataColumn.readBoolean();
+    } else {
+      return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId());
+    }
+  }
+
+  public int nextInt() {
+    if (!useDictionary) {
+      return dataColumn.readInteger();
+    } else {
+      return dictionary.decodeToInt(dataColumn.readValueDictionaryId());
+    }
+  }
+
+  public long nextLong() {
+    if (!useDictionary) {
+      return dataColumn.readLong();
+    } else {
+      return dictionary.decodeToLong(dataColumn.readValueDictionaryId());
+    }
+  }
+
+  public float nextFloat() {
+    if (!useDictionary) {
+      return dataColumn.readFloat();
+    } else {
+      return dictionary.decodeToFloat(dataColumn.readValueDictionaryId());
+    }
+  }
+
+  public double nextDouble() {
+    if (!useDictionary) {
+      return dataColumn.readDouble();
+    } else {
+      return dictionary.decodeToDouble(dataColumn.readValueDictionaryId());
+    }
+  }
+
+  public Binary nextBinary() {
+    if (!useDictionary) {
+      return dataColumn.readBytes();
+    } else {
+      return dictionary.decodeToBinary(dataColumn.readValueDictionaryId());
+    }
+  }
+
+  /**
+   * Advances to the next value. Returns true if the value is non-null.
+   */
+  private boolean next() throws IOException {
+    if (valuesRead >= endOfPageValueCount) {
+      if (valuesRead >= totalValueCount) {
+        // How do we get here? Throw end of stream exception?
+        return false;
+      }
+      readPage();
+    }
+    ++valuesRead;
+    // TODO: Don't read for flat schemas
+    //repetitionLevel = repetitionLevelColumn.nextInt();
+    return definitionLevelColumn.nextInt() == maxDefLevel;
+  }
+
+  /**
+   * Reads `total` values from this columnReader into column.
+   */
+  void readBatch(int total, ColumnVector column) throws IOException {
+    int rowId = 0;
+    while (total > 0) {
+      // Compute the number of values we want to read in this page.
+      int leftInPage = (int) (endOfPageValueCount - valuesRead);
+      if (leftInPage == 0) {
+        readPage();
+        leftInPage = (int) (endOfPageValueCount - valuesRead);
+      }
+      int num = Math.min(total, leftInPage);
+      if (useDictionary) {
+        // Read and decode dictionary ids.
+        ColumnVector dictionaryIds = column.reserveDictionaryIds(total);
+        defColumn.readIntegers(
+            num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+        decodeDictionaryIds(rowId, num, column, dictionaryIds);
+      } else {
+        column.setDictionary(null);
+        switch (descriptor.getType()) {
+          case BOOLEAN:
+            readBooleanBatch(rowId, num, column);
+            break;
+          case INT32:
+            readIntBatch(rowId, num, column);
+            break;
+          case INT64:
+            readLongBatch(rowId, num, column);
+            break;
+          case FLOAT:
+            readFloatBatch(rowId, num, column);
+            break;
+          case DOUBLE:
+            readDoubleBatch(rowId, num, column);
+            break;
+          case BINARY:
+            readBinaryBatch(rowId, num, column);
+            break;
+          case FIXED_LEN_BYTE_ARRAY:
+            readFixedLenByteArrayBatch(rowId, num, column, descriptor.getTypeLength());
+            break;
+          default:
+            throw new IOException("Unsupported type: " + descriptor.getType());
+        }
+      }
+
+      valuesRead += num;
+      rowId += num;
+      total -= num;
+    }
+  }
+
+  /**
+   * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
+   */
+  private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
+                                   ColumnVector dictionaryIds) {
+    switch (descriptor.getType()) {
+      case INT32:
+      case INT64:
+      case FLOAT:
+      case DOUBLE:
+      case BINARY:
+        column.setDictionary(dictionary);
+        break;
+
+      case FIXED_LEN_BYTE_ARRAY:
+        // DecimalType written in the legacy mode
+        if (DecimalType.is32BitDecimalType(column.dataType())) {
+          for (int i = rowId; i < rowId + num; ++i) {
+            Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+            column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v));
+          }
+        } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+          for (int i = rowId; i < rowId + num; ++i) {
+            Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+            column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
+          }
+        } else {
+          throw new NotImplementedException();
+        }
+        break;
+
+      default:
+        throw new NotImplementedException("Unsupported type: " + descriptor.getType());
+    }
+  }
+
+  /**
+   * For all the read*Batch functions, reads `num` values from this columnReader into column. It
+   * is guaranteed that num is no larger than the number of values left in the current page.
+   */
+
+  private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
+    assert(column.dataType() == DataTypes.BooleanType);
+    defColumn.readBooleans(
+        num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+  }
+
+  private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
+    // This is where we implement support for the valid type conversions.
+    // TODO: implement remaining type conversions
+    if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
+        DecimalType.is32BitDecimalType(column.dataType())) {
+      defColumn.readIntegers(
+          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+    } else if (column.dataType() == DataTypes.ByteType) {
+      defColumn.readBytes(
+          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+    } else if (column.dataType() == DataTypes.ShortType) {
+      defColumn.readShorts(
+          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+    } else {
+      throw new NotImplementedException("Unimplemented type: " + column.dataType());
+    }
+  }
+
+  private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
+    // This is where we implement support for the valid type conversions.
+    if (column.dataType() == DataTypes.LongType ||
+        DecimalType.is64BitDecimalType(column.dataType())) {
+      defColumn.readLongs(
+          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+    } else {
+      throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
+    }
+  }
+
+  private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException {
+    // This is where we implement support for the valid type conversions.
+    // TODO: support implicit cast to double?
+    if (column.dataType() == DataTypes.FloatType) {
+      defColumn.readFloats(
+          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+    } else {
+      throw new UnsupportedOperationException("Unsupported conversion to: " + column.dataType());
+    }
+  }
+
+  private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException {
+    // This is where we implement support for the valid type conversions.
+    // TODO: implement remaining type conversions
+    if (column.dataType() == DataTypes.DoubleType) {
+      defColumn.readDoubles(
+          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+    } else {
+      throw new NotImplementedException("Unimplemented type: " + column.dataType());
+    }
+  }
+
+  private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
+    // This is where we implement support for the valid type conversions.
+    // TODO: implement remaining type conversions
+    if (column.isArray()) {
+      defColumn.readBinarys(
+          num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
+    } else {
+      throw new NotImplementedException("Unimplemented type: " + column.dataType());
+    }
+  }
+
+  private void readFixedLenByteArrayBatch(int rowId, int num,
+                                          ColumnVector column, int arrayLen) throws IOException {
+    VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
+    // This is where we implement support for the valid type conversions.
+    // TODO: implement remaining type conversions
+    if (DecimalType.is32BitDecimalType(column.dataType())) {
+      for (int i = 0; i < num; i++) {
+        if (defColumn.readInteger() == maxDefLevel) {
+          column.putInt(rowId + i,
+              (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
+        } else {
+          column.putNull(rowId + i);
+        }
+      }
+    } else if (DecimalType.is64BitDecimalType(column.dataType())) {
+      for (int i = 0; i < num; i++) {
+        if (defColumn.readInteger() == maxDefLevel) {
+          column.putLong(rowId + i,
+              CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
+        } else {
+          column.putNull(rowId + i);
+        }
+      }
+    } else {
+      throw new NotImplementedException("Unimplemented type: " + column.dataType());
+    }
+  }
+
+  private void readPage() throws IOException {
+    DataPage page = pageReader.readPage();
+    // TODO: Why is this a visitor?
+    page.accept(new DataPage.Visitor() {
+      @Override
+      public Void visit(DataPageV1 dataPageV1) {
+        try {
+          readPageV1(dataPageV1);
+          return null;
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        }
+      }
+
+      @Override
+      public Void visit(DataPageV2 dataPageV2) {
+        try {
+          readPageV2(dataPageV2);
+          return null;
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        }
+      }
+    });
+  }
+
+  private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException {
+    this.endOfPageValueCount = valuesRead + pageValueCount;
+    if (dataEncoding.usesDictionary()) {
+      this.dataColumn = null;
+      if (dictionary == null) {
+        throw new IOException(
+            "could not read page in col " + descriptor +
+                " as the dictionary was missing for encoding " + dataEncoding);
+      }
+      @SuppressWarnings("deprecation")
+      Encoding plainDict = Encoding.PLAIN_DICTIONARY; // var to allow warning suppression
+      if (dataEncoding != plainDict && dataEncoding != Encoding.RLE_DICTIONARY) {
+        throw new NotImplementedException("Unsupported encoding: " + dataEncoding);
+      }
+      this.dataColumn = new VectorizedRleValuesReader();
+      this.useDictionary = true;
+    } else {
+      if (dataEncoding != Encoding.PLAIN) {
+        throw new NotImplementedException("Unsupported encoding: " + dataEncoding);
+      }
+      this.dataColumn = new VectorizedPlainValuesReader();
+      this.useDictionary = false;
+    }
+
+    try {
+      dataColumn.initFromPage(pageValueCount, bytes, offset);
+    } catch (IOException e) {
+      throw new IOException("could not read page in col " + descriptor, e);
+    }
+  }
+
+  private void readPageV1(DataPageV1 page) throws IOException {
+    this.pageValueCount = page.getValueCount();
+    ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL);
+    ValuesReader dlReader;
+
+    // Initialize the decoders.
+    if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) {
+      throw new NotImplementedException("Unsupported encoding: " + page.getDlEncoding());
+    }
+    int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
+    this.defColumn = new VectorizedRleValuesReader(bitWidth);
+    dlReader = this.defColumn;
+    this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader);
+    this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader);
+    try {
+      byte[] bytes = page.getBytes().toByteArray();
+      rlReader.initFromPage(pageValueCount, bytes, 0);
+      int next = rlReader.getNextOffset();
+      dlReader.initFromPage(pageValueCount, bytes, next);
+      next = dlReader.getNextOffset();
+      initDataReader(page.getValueEncoding(), bytes, next);
+    } catch (IOException e) {
+      throw new IOException("could not read page " + page + " in col " + descriptor, e);
+    }
+  }
+
+  private void readPageV2(DataPageV2 page) throws IOException {
+    this.pageValueCount = page.getValueCount();
+    this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(),
+        page.getRepetitionLevels(), descriptor);
+
+    int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
+    this.defColumn = new VectorizedRleValuesReader(bitWidth);
+    this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn);
+    this.defColumn.initFromBuffer(
+        this.pageValueCount, page.getDefinitionLevels().toByteArray());
+    try {
+      initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0);
+    } catch (IOException e) {
+      throw new IOException("could not read page " + page + " in col " + descriptor, e);
+    }
+  }
+}
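
A minimal usage sketch (not part of the patch) of how a caller could drive the new VectorizedColumnReader, assuming a PageReadStore obtained from ParquetFileReader.readNextRowGroup() and a pre-allocated ColumnarBatch; it mirrors the loop that VectorizedParquetRecordReader adds below. The helper class and method names are hypothetical, and the code has to sit in the same package because readBatch is package-private.

```java
// Hypothetical helper; must live in this package so it can call the
// package-private readBatch method added above.
package org.apache.spark.sql.execution.datasources.parquet;

import java.io.IOException;
import java.util.List;

import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.page.PageReadStore;

import org.apache.spark.sql.execution.vectorized.ColumnarBatch;

final class VectorizedReadSketch {
  /**
   * Decodes at most one batch worth of values from a single row group into `batch`.
   * `columns` would come from requestedSchema.getColumns() and `pages` from
   * ParquetFileReader.readNextRowGroup(), as in VectorizedParquetRecordReader.
   */
  static void readRowGroup(List<ColumnDescriptor> columns, PageReadStore pages,
      ColumnarBatch batch) throws IOException {
    // One VectorizedColumnReader per requested column, bound to that column's pages.
    VectorizedColumnReader[] readers = new VectorizedColumnReader[columns.size()];
    for (int i = 0; i < columns.size(); ++i) {
      readers[i] = new VectorizedColumnReader(columns.get(i),
          pages.getPageReader(columns.get(i)));
    }
    // Fill every column of the batch with the same number of values.
    int num = (int) Math.min((long) batch.capacity(), pages.getRowCount());
    for (int i = 0; i < readers.length; ++i) {
      readers[i].readBatch(num, batch.column(i));
    }
    batch.setNumRows(num);
  }
}
```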
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
new file mode 100644
index 0000000000000..ef44b62a8b17c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.column.page.PageReadStore;
+import org.apache.parquet.schema.OriginalType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
+
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.types.Decimal;
+
+/**
+ * A specialized RecordReader that reads into InternalRows or ColumnarBatches directly using the
+ * Parquet column APIs. This is somewhat based on parquet-mr's ColumnReader.
+ *
+ * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch.
+ * All of these can be handled efficiently and easily with codegen.
+ *
+ * This class can either return InternalRows or ColumnarBatches. With whole stage codegen
+ * enabled, this class returns ColumnarBatches, which offer significant performance gains.
+ * TODO: make this always return ColumnarBatches.
+ */
+public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase {
+  /**
+   * Batch of rows that we assemble and the current index we've returned. Every time this
+   * batch is used up (batchIdx == numBatched), we populate the next batch.
+   */
+  private int batchIdx = 0;
+  private int numBatched = 0;
+
+  /**
+   * For each requested column, the reader to read this column.
+   */
+  private VectorizedColumnReader[] columnReaders;
+
+  /**
+   * The number of rows that have been returned.
+   */
+  private long rowsReturned;
+
+  /**
+   * The number of rows that have been loaded so far, including the current in-flight row group.
+   */
+  private long totalCountLoadedSoFar = 0;
+
+  /**
+   * The columnarBatch object that is used for batch decoding. This is created on first use and
+   * triggers batched decoding. It is not valid to interleave calls to the batched interface with
+   * the row-by-row RecordReader APIs.
+   * This is only enabled with additional flags for development. This is still a work in progress
+   * and currently unsupported cases will fail with potentially difficult-to-diagnose errors.
+   * It should only be turned on during development to work on this feature.
+   *
+   * When this is set, the code will branch early on in the RecordReader APIs. There is no shared
+   * code between the path that uses the MR decoders and the vectorized ones.
+   *
+   * TODOs:
+   *  - Implement v2 page formats (just make sure we create the correct decoders).
+   */
+  private ColumnarBatch columnarBatch;
+
+  /**
+   * If true, this class returns batches instead of rows.
+   */
+  private boolean returnColumnarBatch;
+
+  /**
+   * The default config on whether columnarBatch should be offheap.
+   */
+  private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP;
+
+  /**
+   * Tries to initialize the reader for this split. Returns true if this reader supports reading
+   * this split and false otherwise.
+   */
+  public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) {
+    try {
+      initialize(inputSplit, taskAttemptContext);
+      return true;
+    } catch (Exception e) {
+      return false;
+    }
+  }
+
+  /**
+   * Implementation of RecordReader API.
+   */
+  @Override
+  public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
+      throws IOException, InterruptedException {
+    super.initialize(inputSplit, taskAttemptContext);
+    initializeInternal();
+  }
+
+  /**
+   * Utility API that will read all the data in path. This circumvents the need to create Hadoop
+   * objects to use this class. `columns` can contain the list of columns to project.
+   */
+  @Override
+  public void initialize(String path, List columns) throws IOException {
+    super.initialize(path, columns);
+    initializeInternal();
+  }
+
+  @Override
+  public void close() throws IOException {
+    if (columnarBatch != null) {
+      columnarBatch.close();
+      columnarBatch = null;
+    }
+    super.close();
+  }
+
+  @Override
+  public boolean nextKeyValue() throws IOException, InterruptedException {
+    resultBatch();
+
+    if (returnColumnarBatch) return nextBatch();
+
+    if (batchIdx >= numBatched) {
+      if (!nextBatch()) return false;
+    }
+    ++batchIdx;
+    return true;
+  }
+
+  @Override
+  public Object getCurrentValue() throws IOException, InterruptedException {
+    if (returnColumnarBatch) return columnarBatch;
+    return columnarBatch.getRow(batchIdx - 1);
+  }
+
+  @Override
+  public float getProgress() throws IOException, InterruptedException {
+    return (float) rowsReturned / totalRowCount;
+  }
+
+  /**
+   * Returns the ColumnarBatch object that will be used for all rows returned by this reader.
+   * This object is reused. Calling this enables the vectorized reader. This should be called
+   * before any calls to nextKeyValue/nextBatch.
+   */
+  public ColumnarBatch resultBatch() {
+    return resultBatch(DEFAULT_MEMORY_MODE);
+  }
+
+  public ColumnarBatch resultBatch(MemoryMode memMode) {
+    if (columnarBatch == null) {
+      columnarBatch = ColumnarBatch.allocate(sparkSchema, memMode);
+    }
+    return columnarBatch;
+  }
+
+  /**
+   * Can be called before any rows are returned to enable returning columnar batches directly.
+   */
+  public void enableReturningBatches() {
+    returnColumnarBatch = true;
+  }
+
+  /**
+   * Advances to the next batch of rows. Returns false if there are no more.
+   */
+  public boolean nextBatch() throws IOException {
+    columnarBatch.reset();
+    if (rowsReturned >= totalRowCount) return false;
+    checkEndOfRowGroup();
+
+    int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned);
+    for (int i = 0; i < columnReaders.length; ++i) {
+      columnReaders[i].readBatch(num, columnarBatch.column(i));
+    }
+    rowsReturned += num;
+    columnarBatch.setNumRows(num);
+    numBatched = num;
+    batchIdx = 0;
+    return true;
+  }
+
+  private void initializeInternal() throws IOException {
+    /*
+     * Check that the requested schema is supported.
+     */
+    OriginalType[] originalTypes = new OriginalType[requestedSchema.getFieldCount()];
+    for (int i = 0; i < requestedSchema.getFieldCount(); ++i) {
+      Type t = requestedSchema.getFields().get(i);
+      if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) {
+        throw new IOException("Complex types not supported.");
+      }
+      PrimitiveType primitiveType = t.asPrimitiveType();
+
+      originalTypes[i] = t.getOriginalType();
+
+      // TODO: Be extremely cautious in what is supported. Expand this.
+      if (originalTypes[i] != null && originalTypes[i] != OriginalType.DECIMAL &&
+          originalTypes[i] != OriginalType.UTF8 && originalTypes[i] != OriginalType.DATE &&
+          originalTypes[i] != OriginalType.INT_8 && originalTypes[i] != OriginalType.INT_16) {
+        throw new IOException("Unsupported type: " + t);
+      }
+      if (originalTypes[i] == OriginalType.DECIMAL &&
+          primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) {
+        throw new IOException("Decimal with high precision is not supported.");
+      }
+      if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) {
+        throw new IOException("Int96 not supported.");
+      }
+      ColumnDescriptor fd = fileSchema.getColumnDescription(requestedSchema.getPaths().get(i));
+      if (!fd.equals(requestedSchema.getColumns().get(i))) {
+        throw new IOException("Schema evolution not supported.");
+      }
+    }
+  }
+
+  private void checkEndOfRowGroup() throws IOException {
+    if (rowsReturned != totalCountLoadedSoFar) return;
+    PageReadStore pages = reader.readNextRowGroup();
+    if (pages == null) {
+      throw new IOException("expecting more rows but reached last block. Read "
+          + rowsReturned + " out of " + totalRowCount);
+    }
+    List columns = requestedSchema.getColumns();
+    columnReaders = new VectorizedColumnReader[columns.size()];
+    for (int i = 0; i < columns.size(); ++i) {
+      columnReaders[i] = new VectorizedColumnReader(columns.get(i),
+          pages.getPageReader(columns.get(i)));
+    }
+    totalCountLoadedSoFar += pages.getRowCount();
+  }
+}
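
A hedged end-to-end sketch of the new reader, assuming a Parquet file at a hypothetical local path whose schema contains the projected columns; it relies only on the methods added in this file (the String/List initialize overload described above as a utility API, resultBatch, enableReturningBatches, nextKeyValue, getCurrentValue, close), plus ColumnarBatch.numRows().

```java
import java.util.Arrays;

import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;

public class VectorizedReaderExample {
  public static void main(String[] args) throws Exception {
    VectorizedParquetRecordReader reader = new VectorizedParquetRecordReader();
    try {
      // Hypothetical path and projection; the utility initialize() avoids Hadoop splits.
      reader.initialize("/tmp/example.parquet", Arrays.asList("id", "name"));
      reader.resultBatch();            // allocates the reused batch and enables vectorized decode
      reader.enableReturningBatches(); // return whole ColumnarBatches instead of row views
      while (reader.nextKeyValue()) {
        ColumnarBatch batch = (ColumnarBatch) reader.getCurrentValue();
        System.out.println("read a batch of " + batch.numRows() + " rows");
      }
    } finally {
      reader.close();
    }
  }
}
```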
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index ffcc9c2ace54e..13bf4c5c77266 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -94,7 +94,7 @@ protected Array(ColumnVector data) {
     }
 
     @Override
-    public final int numElements() { return length; }
+    public int numElements() { return length; }
 
     @Override
     public ArrayData copy() {
@@ -109,62 +109,62 @@ public Object[] array() {
 
       if (dt instanceof BooleanType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = data.getBoolean(offset + i);
           }
         }
       } else if (dt instanceof ByteType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = data.getByte(offset + i);
           }
         }
       } else if (dt instanceof ShortType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = data.getShort(offset + i);
           }
         }
       } else if (dt instanceof IntegerType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = data.getInt(offset + i);
           }
         }
       } else if (dt instanceof FloatType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = data.getFloat(offset + i);
           }
         }
       } else if (dt instanceof DoubleType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = data.getDouble(offset + i);
           }
         }
       } else if (dt instanceof LongType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = data.getLong(offset + i);
           }
         }
       } else if (dt instanceof DecimalType) {
         DecimalType decType = (DecimalType)dt;
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = getDecimal(i, decType.precision(), decType.scale());
           }
         }
       } else if (dt instanceof StringType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = getUTF8String(i).toString();
           }
         }
       } else if (dt instanceof CalendarIntervalType) {
         for (int i = 0; i < length; i++) {
-          if (!data.getIsNull(offset + i)) {
+          if (!data.isNullAt(offset + i)) {
             list[i] = getInterval(i);
           }
         }
@@ -175,10 +175,10 @@ public Object[] array() {
     }
 
     @Override
-    public final boolean isNullAt(int ordinal) { return data.getIsNull(offset + ordinal); }
+    public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); }
 
     @Override
-    public final boolean getBoolean(int ordinal) {
+    public boolean getBoolean(int ordinal) {
       throw new NotImplementedException();
     }
 
@@ -314,7 +314,7 @@ public void reset() {
   /**
    * Returns whether the value at rowId is NULL.
    */
-  public abstract boolean getIsNull(int rowId);
+  public abstract boolean isNullAt(int rowId);
 
   /**
    * Sets the value at rowId to `value`.
@@ -500,6 +500,15 @@ public ColumnarBatch.Row getStruct(int rowId) {
     return resultStruct;
   }
 
+  /**
+   * Returns a utility object to get structs.
+   * Provided to keep API compatibility with InternalRow for code generation.
+   */
+  public ColumnarBatch.Row getStruct(int rowId, int size) {
+    resultStruct.rowId = rowId;
+    return resultStruct;
+  }
+
   /**
    * Returns the array at rowid.
    */
@@ -531,6 +540,13 @@ private Array getByteArray(int rowId) {
     return array;
   }
 
+  /**
+   * Returns the map for rowId.
+   */
+  public MapData getMap(int ordinal) {
+    throw new NotImplementedException();
+  }
+
   /**
    * Returns the decimal for rowId.
    */
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index b084eda6f84c1..2dc57dc50d691 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -105,7 +105,7 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) {
       int[] result = new int[array.length];
       ColumnVector data = array.data;
       for (int i = 0; i < result.length; i++) {
-        if (data.getIsNull(array.offset + i)) {
+        if (data.isNullAt(array.offset + i)) {
           throw new RuntimeException("Cannot handle NULL values.");
         }
         result[i] = data.getInt(array.offset + i);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index c462ab1a13bb3..7ab4cda5a4126 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -115,20 +115,20 @@ protected Row(ColumnVector[] columns) {
      * Marks this row as being filtered out. This means a subsequent iteration over the rows
      * in this batch will not include this row.
      */
-    public final void markFiltered() {
+    public void markFiltered() {
       parent.markFiltered(rowId);
     }
 
     public ColumnVector[] columns() { return columns; }
 
     @Override
-    public final int numFields() { return columns.length; }
+    public int numFields() { return columns.length; }
 
     @Override
     /**
      * Revisit this. This is expensive. This is currently only used in test paths.
      */
-    public final InternalRow copy() {
+    public InternalRow copy() {
       GenericMutableRow row = new GenericMutableRow(columns.length);
       for (int i = 0; i < numFields(); i++) {
         if (isNullAt(i)) {
@@ -163,73 +163,73 @@ public final InternalRow copy() {
     }
 
     @Override
-    public final boolean anyNull() {
+    public boolean anyNull() {
       throw new NotImplementedException();
     }
 
     @Override
-    public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); }
+    public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
 
     @Override
-    public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+    public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
 
     @Override
-    public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+    public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
 
     @Override
-    public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+    public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
 
     @Override
-    public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+    public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
 
     @Override
-    public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+    public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
 
     @Override
-    public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+    public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
 
     @Override
-    public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+    public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
 
     @Override
-    public final Decimal getDecimal(int ordinal, int precision, int scale) {
+    public Decimal getDecimal(int ordinal, int precision, int scale) {
       return columns[ordinal].getDecimal(rowId, precision, scale);
     }
 
     @Override
-    public final UTF8String getUTF8String(int ordinal) {
+    public UTF8String getUTF8String(int ordinal) {
       return columns[ordinal].getUTF8String(rowId);
     }
 
     @Override
-    public final byte[] getBinary(int ordinal) {
+    public byte[] getBinary(int ordinal) {
       return columns[ordinal].getBinary(rowId);
     }
 
     @Override
-    public final CalendarInterval getInterval(int ordinal) {
+    public CalendarInterval getInterval(int ordinal) {
       final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
       final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
       return new CalendarInterval(months, microseconds);
     }
 
     @Override
-    public final InternalRow getStruct(int ordinal, int numFields) {
+    public InternalRow getStruct(int ordinal, int numFields) {
       return columns[ordinal].getStruct(rowId);
     }
 
     @Override
-    public final ArrayData getArray(int ordinal) {
+    public ArrayData getArray(int ordinal) {
       return columns[ordinal].getArray(rowId);
     }
 
     @Override
-    public final MapData getMap(int ordinal) {
+    public MapData getMap(int ordinal) {
       throw new NotImplementedException();
     }
 
     @Override
-    public final Object get(int ordinal, DataType dataType) {
+    public Object get(int ordinal, DataType dataType) {
       throw new NotImplementedException();
     }
   }
@@ -295,7 +295,7 @@ public void setNumRows(int numRows) {
     for (int ordinal : nullFilteredColumns) {
       if (columns[ordinal].numNulls != 0) {
         for (int rowId = 0; rowId < numRows; rowId++) {
-          if (!filteredRows[rowId] && columns[ordinal].getIsNull(rowId)) {
+          if (!filteredRows[rowId] && columns[ordinal].isNullAt(rowId)) {
             filteredRows[rowId] = true;
             ++numRowsFiltered;
           }
@@ -357,7 +357,7 @@ public ColumnarBatch.Row getRow(int rowId) {
    * Marks this row as being filtered out. This means a subsequent iteration over the rows
    * in this batch will not include this row.
    */
-  public final void markFiltered(int rowId) {
+  public void markFiltered(int rowId) {
     assert(!filteredRows[rowId]);
     filteredRows[rowId] = true;
     ++numRowsFiltered;
@@ -367,7 +367,7 @@ public final void markFiltered(int rowId) {
    * Marks a given column as non-nullable. Any row that has a NULL value for the corresponding
    * attribute is filtered out.
    */
-  public final void filterNullsInColumn(int ordinal) {
+  public void filterNullsInColumn(int ordinal) {
     nullFilteredColumns.add(ordinal);
   }
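
A short illustrative sketch of the ColumnarBatch.Row accessors whose final modifiers are dropped above, assuming an on-heap batch built from a hand-written two-column schema; the class name and sample values are hypothetical.

```java
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class ColumnarBatchRowSketch {
  public static void main(String[] args) {
    // Hand-written schema; ColumnarBatch.allocate is the same call the
    // vectorized record reader uses above.
    StructType schema = new StructType()
        .add("id", DataTypes.IntegerType)
        .add("score", DataTypes.DoubleType);
    ColumnarBatch batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP);

    // Populate two rows, the second with a NULL id.
    batch.column(0).putInt(0, 42);
    batch.column(1).putDouble(0, 3.5);
    batch.column(0).putNull(1);
    batch.column(1).putDouble(1, 1.0);
    batch.setNumRows(2);

    // Read them back through the (no longer final) Row accessors.
    for (int i = 0; i < batch.numRows(); i++) {
      ColumnarBatch.Row row = batch.getRow(i);
      String id = row.isNullAt(0) ? "null" : Integer.toString(row.getInt(0));
      System.out.println(id + ", " + row.getDouble(1));
    }
  }
}
```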
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index b06b7f2457b54..689e6a2a6d82f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -52,7 +52,7 @@ protected OffHeapColumnVector(int capacity, DataType type) {
   }
 
   @Override
-  public final long valuesNativeAddress() {
+  public long valuesNativeAddress() {
     return data;
   }
 
@@ -62,7 +62,7 @@ public long nullsNativeAddress() {
   }
 
   @Override
-  public final void close() {
+  public void close() {
     Platform.freeMemory(nulls);
     Platform.freeMemory(data);
     Platform.freeMemory(lengthData);
@@ -78,19 +78,19 @@ public final void close() {
   //
 
   @Override
-  public final void putNotNull(int rowId) {
+  public void putNotNull(int rowId) {
     Platform.putByte(null, nulls + rowId, (byte) 0);
   }
 
   @Override
-  public final void putNull(int rowId) {
+  public void putNull(int rowId) {
     Platform.putByte(null, nulls + rowId, (byte) 1);
     ++numNulls;
     anyNullsSet = true;
   }
 
   @Override
-  public final void putNulls(int rowId, int count) {
+  public void putNulls(int rowId, int count) {
     long offset = nulls + rowId;
     for (int i = 0; i < count; ++i, ++offset) {
       Platform.putByte(null, offset, (byte) 1);
@@ -100,7 +100,7 @@ public final void putNulls(int rowId, int count) {
   }
 
   @Override
-  public final void putNotNulls(int rowId, int count) {
+  public void putNotNulls(int rowId, int count) {
     if (!anyNullsSet) return;
     long offset = nulls + rowId;
     for (int i = 0; i < count; ++i, ++offset) {
@@ -109,7 +109,7 @@ public final void putNotNulls(int rowId, int count) {
   }
 
   @Override
-  public final boolean getIsNull(int rowId) {
+  public boolean isNullAt(int rowId) {
     return Platform.getByte(null, nulls + rowId) == 1;
   }
 
@@ -118,12 +118,12 @@ public final boolean getIsNull(int rowId) {
   //
 
   @Override
-  public final void putBoolean(int rowId, boolean value) {
+  public void putBoolean(int rowId, boolean value) {
     Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0));
   }
 
   @Override
-  public final void putBooleans(int rowId, int count, boolean value) {
+  public void putBooleans(int rowId, int count, boolean value) {
     byte v = (byte)((value) ? 1 : 0);
     for (int i = 0; i < count; ++i) {
       Platform.putByte(null, data + rowId + i, v);
@@ -131,32 +131,32 @@ public final void putBooleans(int rowId, int count, boolean value) {
   }
 
   @Override
-  public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; }
+  public boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; }
 
   //
   // APIs dealing with Bytes
   //
 
   @Override
-  public final void putByte(int rowId, byte value) {
+  public void putByte(int rowId, byte value) {
     Platform.putByte(null, data + rowId, value);
 
   }
 
   @Override
-  public final void putBytes(int rowId, int count, byte value) {
+  public void putBytes(int rowId, int count, byte value) {
     for (int i = 0; i < count; ++i) {
       Platform.putByte(null, data + rowId + i, value);
     }
   }
 
   @Override
-  public final void putBytes(int rowId, int count, byte[] src, int srcIndex) {
+  public void putBytes(int rowId, int count, byte[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId, count);
   }
 
   @Override
-  public final byte getByte(int rowId) {
+  public byte getByte(int rowId) {
     if (dictionary == null) {
       return Platform.getByte(null, data + rowId);
     } else {
@@ -169,12 +169,12 @@ public final byte getByte(int rowId) {
   //
 
   @Override
-  public final void putShort(int rowId, short value) {
+  public void putShort(int rowId, short value) {
     Platform.putShort(null, data + 2 * rowId, value);
   }
 
   @Override
-  public final void putShorts(int rowId, int count, short value) {
+  public void putShorts(int rowId, int count, short value) {
     long offset = data + 2 * rowId;
     for (int i = 0; i < count; ++i, offset += 4) {
       Platform.putShort(null, offset, value);
@@ -182,13 +182,13 @@ public final void putShorts(int rowId, int count, short value) {
   }
 
   @Override
-  public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+  public void putShorts(int rowId, int count, short[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2,
         null, data + 2 * rowId, count * 2);
   }
 
   @Override
-  public final short getShort(int rowId) {
+  public short getShort(int rowId) {
     if (dictionary == null) {
       return Platform.getShort(null, data + 2 * rowId);
     } else {
@@ -201,12 +201,12 @@ public final short getShort(int rowId) {
   //
 
   @Override
-  public final void putInt(int rowId, int value) {
+  public void putInt(int rowId, int value) {
     Platform.putInt(null, data + 4 * rowId, value);
   }
 
   @Override
-  public final void putInts(int rowId, int count, int value) {
+  public void putInts(int rowId, int count, int value) {
     long offset = data + 4 * rowId;
     for (int i = 0; i < count; ++i, offset += 4) {
       Platform.putInt(null, offset, value);
@@ -214,19 +214,19 @@ public final void putInts(int rowId, int count, int value) {
   }
 
   @Override
-  public final void putInts(int rowId, int count, int[] src, int srcIndex) {
+  public void putInts(int rowId, int count, int[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.INT_ARRAY_OFFSET + srcIndex * 4,
         null, data + 4 * rowId, count * 4);
   }
 
   @Override
-  public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
+  public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
     Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET,
         null, data + 4 * rowId, count * 4);
   }
 
   @Override
-  public final int getInt(int rowId) {
+  public int getInt(int rowId) {
     if (dictionary == null) {
       return Platform.getInt(null, data + 4 * rowId);
     } else {
@@ -239,12 +239,12 @@ public final int getInt(int rowId) {
   //
 
   @Override
-  public final void putLong(int rowId, long value) {
+  public void putLong(int rowId, long value) {
     Platform.putLong(null, data + 8 * rowId, value);
   }
 
   @Override
-  public final void putLongs(int rowId, int count, long value) {
+  public void putLongs(int rowId, int count, long value) {
     long offset = data + 8 * rowId;
     for (int i = 0; i < count; ++i, offset += 8) {
       Platform.putLong(null, offset, value);
@@ -252,19 +252,19 @@ public final void putLongs(int rowId, int count, long value) {
   }
 
   @Override
-  public final void putLongs(int rowId, int count, long[] src, int srcIndex) {
+  public void putLongs(int rowId, int count, long[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.LONG_ARRAY_OFFSET + srcIndex * 8,
         null, data + 8 * rowId, count * 8);
   }
 
   @Override
-  public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
+  public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
     Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET,
         null, data + 8 * rowId, count * 8);
   }
 
   @Override
-  public final long getLong(int rowId) {
+  public long getLong(int rowId) {
     if (dictionary == null) {
       return Platform.getLong(null, data + 8 * rowId);
     } else {
@@ -277,12 +277,12 @@ public final long getLong(int rowId) {
   //
 
   @Override
-  public final void putFloat(int rowId, float value) {
+  public void putFloat(int rowId, float value) {
     Platform.putFloat(null, data + rowId * 4, value);
   }
 
   @Override
-  public final void putFloats(int rowId, int count, float value) {
+  public void putFloats(int rowId, int count, float value) {
     long offset = data + 4 * rowId;
     for (int i = 0; i < count; ++i, offset += 4) {
       Platform.putFloat(null, offset, value);
@@ -290,19 +290,19 @@ public final void putFloats(int rowId, int count, float value) {
   }
 
   @Override
-  public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+  public void putFloats(int rowId, int count, float[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4,
         null, data + 4 * rowId, count * 4);
   }
 
   @Override
-  public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+  public void putFloats(int rowId, int count, byte[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
         null, data + rowId * 4, count * 4);
   }
 
   @Override
-  public final float getFloat(int rowId) {
+  public float getFloat(int rowId) {
     if (dictionary == null) {
       return Platform.getFloat(null, data + rowId * 4);
     } else {
@@ -316,12 +316,12 @@ public final float getFloat(int rowId) {
   //
 
   @Override
-  public final void putDouble(int rowId, double value) {
+  public void putDouble(int rowId, double value) {
     Platform.putDouble(null, data + rowId * 8, value);
   }
 
   @Override
-  public final void putDoubles(int rowId, int count, double value) {
+  public void putDoubles(int rowId, int count, double value) {
     long offset = data + 8 * rowId;
     for (int i = 0; i < count; ++i, offset += 8) {
       Platform.putDouble(null, offset, value);
@@ -329,19 +329,19 @@ public final void putDoubles(int rowId, int count, double value) {
   }
 
   @Override
-  public final void putDoubles(int rowId, int count, double[] src, int srcIndex) {
+  public void putDoubles(int rowId, int count, double[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex * 8,
       null, data + 8 * rowId, count * 8);
   }
 
   @Override
-  public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
+  public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
         null, data + rowId * 8, count * 8);
   }
 
   @Override
-  public final double getDouble(int rowId) {
+  public double getDouble(int rowId) {
     if (dictionary == null) {
       return Platform.getDouble(null, data + rowId * 8);
     } else {
@@ -353,25 +353,25 @@ public final double getDouble(int rowId) {
   // APIs dealing with Arrays.
   //
   @Override
-  public final void putArray(int rowId, int offset, int length) {
+  public void putArray(int rowId, int offset, int length) {
     assert(offset >= 0 && offset + length <= childColumns[0].capacity);
     Platform.putInt(null, lengthData + 4 * rowId, length);
     Platform.putInt(null, offsetData + 4 * rowId, offset);
   }
 
   @Override
-  public final int getArrayLength(int rowId) {
+  public int getArrayLength(int rowId) {
     return Platform.getInt(null, lengthData + 4 * rowId);
   }
 
   @Override
-  public final int getArrayOffset(int rowId) {
+  public int getArrayOffset(int rowId) {
     return Platform.getInt(null, offsetData + 4 * rowId);
   }
 
   // APIs dealing with ByteArrays
   @Override
-  public final int putByteArray(int rowId, byte[] value, int offset, int length) {
+  public int putByteArray(int rowId, byte[] value, int offset, int length) {
     int result = arrayData().appendBytes(length, value, offset);
     Platform.putInt(null, lengthData + 4 * rowId, length);
     Platform.putInt(null, offsetData + 4 * rowId, result);
@@ -379,7 +379,7 @@ public final int putByteArray(int rowId, byte[] value, int offset, int length) {
   }
 
   @Override
-  public final void loadBytes(ColumnVector.Array array) {
+  public void loadBytes(ColumnVector.Array array) {
     if (array.tmpByteArray.length < array.length) array.tmpByteArray = new byte[array.length];
     Platform.copyMemory(
         null, data + array.offset, array.tmpByteArray, Platform.BYTE_ARRAY_OFFSET, array.length);
@@ -388,12 +388,12 @@ public final void loadBytes(ColumnVector.Array array) {
   }
 
   @Override
-  public final void reserve(int requiredCapacity) {
+  public void reserve(int requiredCapacity) {
     if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2);
   }
 
   // Split out the slow path.
-  private final void reserveInternal(int newCapacity) {
+  private void reserveInternal(int newCapacity) {
     if (this.resultArray != null) {
       this.lengthData =
           Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4);
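Aside for readers who have not worked with these vectorized buffers: each column vector is a flat buffer addressed by `rowId`, and `reserve` grows it by doubling capacity on a split-out slow path, as the hunk above shows. The snippet below is only an illustrative sketch of that pattern in plain Scala under those assumptions; it is not the class this patch modifies.

```
// Illustrative only: the put/get + capacity-doubling pattern used by the column vectors.
class IntColumn(initialCapacity: Int = 16) {
  private var data = new Array[Int](initialCapacity)
  private var capacity = initialCapacity

  def putInt(rowId: Int, value: Int): Unit = data(rowId) = value

  def getInt(rowId: Int): Int = data(rowId)

  def reserve(requiredCapacity: Int): Unit = {
    if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2)
  }

  // Slow path split out, mirroring reserveInternal in the patch.
  private def reserveInternal(newCapacity: Int): Unit = {
    data = java.util.Arrays.copyOf(data, newCapacity)
    capacity = newCapacity
  }
}
```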
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 03160d1ec36ce..f332e87016692 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -52,16 +52,16 @@ protected OnHeapColumnVector(int capacity, DataType type) {
   }
 
   @Override
-  public final long valuesNativeAddress() {
+  public long valuesNativeAddress() {
     throw new RuntimeException("Cannot get native address for on heap column");
   }
   @Override
-  public final long nullsNativeAddress() {
+  public long nullsNativeAddress() {
     throw new RuntimeException("Cannot get native address for on heap column");
   }
 
   @Override
-  public final void close() {
+  public void close() {
   }
 
   //
@@ -69,19 +69,19 @@ public final void close() {
   //
 
   @Override
-  public final void putNotNull(int rowId) {
+  public void putNotNull(int rowId) {
     nulls[rowId] = (byte)0;
   }
 
   @Override
-  public final void putNull(int rowId) {
+  public void putNull(int rowId) {
     nulls[rowId] = (byte)1;
     ++numNulls;
     anyNullsSet = true;
   }
 
   @Override
-  public final void putNulls(int rowId, int count) {
+  public void putNulls(int rowId, int count) {
     for (int i = 0; i < count; ++i) {
       nulls[rowId + i] = (byte)1;
     }
@@ -90,7 +90,7 @@ public final void putNulls(int rowId, int count) {
   }
 
   @Override
-  public final void putNotNulls(int rowId, int count) {
+  public void putNotNulls(int rowId, int count) {
     if (!anyNullsSet) return;
     for (int i = 0; i < count; ++i) {
       nulls[rowId + i] = (byte)0;
@@ -98,7 +98,7 @@ public final void putNotNulls(int rowId, int count) {
   }
 
   @Override
-  public final boolean getIsNull(int rowId) {
+  public boolean isNullAt(int rowId) {
     return nulls[rowId] == 1;
   }
 
@@ -107,12 +107,12 @@ public final boolean getIsNull(int rowId) {
   //
 
   @Override
-  public final void putBoolean(int rowId, boolean value) {
+  public void putBoolean(int rowId, boolean value) {
     byteData[rowId] = (byte)((value) ? 1 : 0);
   }
 
   @Override
-  public final void putBooleans(int rowId, int count, boolean value) {
+  public void putBooleans(int rowId, int count, boolean value) {
     byte v = (byte)((value) ? 1 : 0);
     for (int i = 0; i < count; ++i) {
       byteData[i + rowId] = v;
@@ -120,7 +120,7 @@ public final void putBooleans(int rowId, int count, boolean value) {
   }
 
   @Override
-  public final boolean getBoolean(int rowId) {
+  public boolean getBoolean(int rowId) {
     return byteData[rowId] == 1;
   }
 
@@ -131,24 +131,24 @@ public final boolean getBoolean(int rowId) {
   //
 
   @Override
-  public final void putByte(int rowId, byte value) {
+  public void putByte(int rowId, byte value) {
     byteData[rowId] = value;
   }
 
   @Override
-  public final void putBytes(int rowId, int count, byte value) {
+  public void putBytes(int rowId, int count, byte value) {
     for (int i = 0; i < count; ++i) {
       byteData[i + rowId] = value;
     }
   }
 
   @Override
-  public final void putBytes(int rowId, int count, byte[] src, int srcIndex) {
+  public void putBytes(int rowId, int count, byte[] src, int srcIndex) {
     System.arraycopy(src, srcIndex, byteData, rowId, count);
   }
 
   @Override
-  public final byte getByte(int rowId) {
+  public byte getByte(int rowId) {
     if (dictionary == null) {
       return byteData[rowId];
     } else {
@@ -161,24 +161,24 @@ public final byte getByte(int rowId) {
   //
 
   @Override
-  public final void putShort(int rowId, short value) {
+  public void putShort(int rowId, short value) {
     shortData[rowId] = value;
   }
 
   @Override
-  public final void putShorts(int rowId, int count, short value) {
+  public void putShorts(int rowId, int count, short value) {
     for (int i = 0; i < count; ++i) {
       shortData[i + rowId] = value;
     }
   }
 
   @Override
-  public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
+  public void putShorts(int rowId, int count, short[] src, int srcIndex) {
     System.arraycopy(src, srcIndex, shortData, rowId, count);
   }
 
   @Override
-  public final short getShort(int rowId) {
+  public short getShort(int rowId) {
     if (dictionary == null) {
       return shortData[rowId];
     } else {
@@ -192,24 +192,24 @@ public final short getShort(int rowId) {
   //
 
   @Override
-  public final void putInt(int rowId, int value) {
+  public void putInt(int rowId, int value) {
     intData[rowId] = value;
   }
 
   @Override
-  public final void putInts(int rowId, int count, int value) {
+  public void putInts(int rowId, int count, int value) {
     for (int i = 0; i < count; ++i) {
       intData[i + rowId] = value;
     }
   }
 
   @Override
-  public final void putInts(int rowId, int count, int[] src, int srcIndex) {
+  public void putInts(int rowId, int count, int[] src, int srcIndex) {
     System.arraycopy(src, srcIndex, intData, rowId, count);
   }
 
   @Override
-  public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
+  public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
     int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET;
     for (int i = 0; i < count; ++i) {
      intData[i + rowId] = Platform.getInt(src, srcOffset);
@@ -219,7 +219,7 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI
   }
 
   @Override
-  public final int getInt(int rowId) {
+  public int getInt(int rowId) {
     if (dictionary == null) {
       return intData[rowId];
     } else {
@@ -232,24 +232,24 @@ public final int getInt(int rowId) {
   //
 
   @Override
-  public final void putLong(int rowId, long value) {
+  public void putLong(int rowId, long value) {
     longData[rowId] = value;
   }
 
   @Override
-  public final void putLongs(int rowId, int count, long value) {
+  public void putLongs(int rowId, int count, long value) {
     for (int i = 0; i < count; ++i) {
       longData[i + rowId] = value;
     }
   }
 
   @Override
-  public final void putLongs(int rowId, int count, long[] src, int srcIndex) {
+  public void putLongs(int rowId, int count, long[] src, int srcIndex) {
     System.arraycopy(src, srcIndex, longData, rowId, count);
   }
 
   @Override
-  public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
+  public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
     int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET;
     for (int i = 0; i < count; ++i) {
       longData[i + rowId] = Platform.getLong(src, srcOffset);
@@ -259,7 +259,7 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src
   }
 
   @Override
-  public final long getLong(int rowId) {
+  public long getLong(int rowId) {
     if (dictionary == null) {
       return longData[rowId];
     } else {
@@ -272,26 +272,26 @@ public final long getLong(int rowId) {
   //
 
   @Override
-  public final void putFloat(int rowId, float value) { floatData[rowId] = value; }
+  public void putFloat(int rowId, float value) { floatData[rowId] = value; }
 
   @Override
-  public final void putFloats(int rowId, int count, float value) {
+  public void putFloats(int rowId, int count, float value) {
     Arrays.fill(floatData, rowId, rowId + count, value);
   }
 
   @Override
-  public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
+  public void putFloats(int rowId, int count, float[] src, int srcIndex) {
     System.arraycopy(src, srcIndex, floatData, rowId, count);
   }
 
   @Override
-  public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
+  public void putFloats(int rowId, int count, byte[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex,
         floatData, Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4);
   }
 
   @Override
-  public final float getFloat(int rowId) {
+  public float getFloat(int rowId) {
     if (dictionary == null) {
       return floatData[rowId];
     } else {
@@ -304,28 +304,28 @@ public final float getFloat(int rowId) {
   //
 
   @Override
-  public final void putDouble(int rowId, double value) {
+  public void putDouble(int rowId, double value) {
     doubleData[rowId] = value;
   }
 
   @Override
-  public final void putDoubles(int rowId, int count, double value) {
+  public void putDoubles(int rowId, int count, double value) {
     Arrays.fill(doubleData, rowId, rowId + count, value);
   }
 
   @Override
-  public final void putDoubles(int rowId, int count, double[] src, int srcIndex) {
+  public void putDoubles(int rowId, int count, double[] src, int srcIndex) {
     System.arraycopy(src, srcIndex, doubleData, rowId, count);
   }
 
   @Override
-  public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
+  public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
     Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData,
         Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8);
   }
 
   @Override
-  public final double getDouble(int rowId) {
+  public double getDouble(int rowId) {
     if (dictionary == null) {
       return doubleData[rowId];
     } else {
@@ -338,22 +338,22 @@ public final double getDouble(int rowId) {
   //
 
   @Override
-  public final int getArrayLength(int rowId) {
+  public int getArrayLength(int rowId) {
     return arrayLengths[rowId];
   }
   @Override
-  public final int getArrayOffset(int rowId) {
+  public int getArrayOffset(int rowId) {
     return arrayOffsets[rowId];
   }
 
   @Override
-  public final void putArray(int rowId, int offset, int length) {
+  public void putArray(int rowId, int offset, int length) {
     arrayOffsets[rowId] = offset;
     arrayLengths[rowId] = length;
   }
 
   @Override
-  public final void loadBytes(ColumnVector.Array array) {
+  public void loadBytes(ColumnVector.Array array) {
     array.byteArray = byteData;
     array.byteArrayOffset = array.offset;
   }
@@ -363,7 +363,7 @@ public final void loadBytes(ColumnVector.Array array) {
   //
 
   @Override
-  public final int putByteArray(int rowId, byte[] value, int offset, int length) {
+  public int putByteArray(int rowId, byte[] value, int offset, int length) {
     int result = arrayData().appendBytes(length, value, offset);
     arrayOffsets[rowId] = result;
     arrayLengths[rowId] = length;
@@ -371,12 +371,12 @@ public final int putByteArray(int rowId, byte[] value, int offset, int length) {
   }
 
   @Override
-  public final void reserve(int requiredCapacity) {
+  public void reserve(int requiredCapacity) {
     if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2);
   }
 
  // Split this function out since it is the slow path.
-  private final void reserveInternal(int newCapacity) {
+  private void reserveInternal(int newCapacity) {
     if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) {
       int[] newLengths = new int[newCapacity];
       int[] newOffsets = new int[newCapacity];
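Aside: the `putIntsLittleEndian`/`putLongsLittleEndian` overloads above unpack packed little-endian bytes into the typed arrays element by element. The sketch below is a rough standard-library equivalent of that decoding step, assuming a little-endian source buffer; it is illustrative only and not the code path the patch touches.

```
import java.nio.{ByteBuffer, ByteOrder}

object LittleEndianDecode {
  // Decode `count` little-endian ints from `src` (starting at `srcIndex`) into `dst` at `rowId`.
  def putIntsLittleEndian(dst: Array[Int], rowId: Int, count: Int,
                          src: Array[Byte], srcIndex: Int): Unit = {
    val buf = ByteBuffer.wrap(src, srcIndex, count * 4).order(ByteOrder.LITTLE_ENDIAN)
    var i = 0
    while (i < count) {
      dst(rowId + i) = buf.getInt()
      i += 1
    }
  }
}
```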
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
deleted file mode 100644
index 3b30337f1f877..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements.  See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License.  You may obtain a copy of the License at
-*
-*    http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-/**
- * A container for a [[DataFrame]], used for implicit conversions.
- *
- * To use this, import implicit conversions in SQL:
- * {{{
- *   import sqlContext.implicits._
- * }}}
- *
- * @since 1.3.0
- */
-case class DataFrameHolder private[sql](private val df: DataFrame) {
-
-  // This is declared with parentheses to prevent the Scala compiler from treating
-  // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
-  def toDF(): DataFrame = df
-
-  def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*)
-}
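Aside: the removed holder existed only to back the `rdd.toDF(...)` implicit conversion described in its scaladoc. For context, that conversion is used as sketched below, assuming `sc` and `sqlContext` are in scope (e.g. in spark-shell).

```
// Usage of the implicit conversion this holder enabled.
import sqlContext.implicits._

val rdd = sc.parallelize(Seq((1, "alice"), (2, "bob")))
rdd.toDF()             // columns named _1, _2
rdd.toDF("id", "name") // columns renamed
```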
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 1b5a4999a8ef1..0dc0d44d6cdcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -289,6 +289,15 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    *
    * `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
    * (e.g. 00012)
+   * `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
+   * during parsing.
+   *   `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the
+   *   malformed string into a new field configured by `spark.sql.columnNameOfCorruptRecord`. When
+   *   a schema is set by user, it sets `null` for extra fields.
+   *   `DROPMALFORMED` : ignores the whole corrupted records.
+   *   `FAILFAST` : throws an exception when it meets corrupted records.
+   *
    *
    * @since 1.4.0
    */
@@ -313,6 +322,15 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
    * (e.g. 00012)
    *
    * `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all
    * character using backslash quoting mechanism
+   * `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
+   * during parsing.
+   *   `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the
+   *   malformed string into a new field configured by `spark.sql.columnNameOfCorruptRecord`. When
+   *   a schema is set by user, it sets `null` for extra fields.
+   *   `DROPMALFORMED` : ignores the whole corrupted records.
+   *   `FAILFAST` : throws an exception when it meets corrupted records.
+   *
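Aside: the new parse modes documented above are passed through the reader's `option` call. A minimal usage sketch, with a hypothetical file path and `sqlContext` assumed to be in scope:

```
// Drop records that cannot be parsed instead of keeping them as corrupt rows.
val people = sqlContext.read
  .option("mode", "DROPMALFORMED")   // alternatives: "PERMISSIVE" (default), "FAILFAST"
  .json("path/to/people.json")       // hypothetical path
```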
    *
    * @since 1.6.0
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 9951f0fabff15..7ed1c51360f0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -138,7 +138,16 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   /**
    * Partitions the output by the given columns on the file system. If specified, the output is
-   * laid out on the file system similar to Hive's partitioning scheme.
+   * laid out on the file system similar to Hive's partitioning scheme. As an example, when we
+   * partition a dataset by year and then month, the directory layout would look like:
+   *
+   *   - year=2016/month=01/
+   *   - year=2016/month=02/
+   *
+   * Partitioning is one of the most widely used techniques to optimize physical data layout.
+   * It provides a coarse-grained index for skipping unnecessary data reads when queries have
+   * predicates on the partitioned columns. In order for partitioning to work well, the number
+   * of distinct values in each column should typically be less than tens of thousands.
    *
    * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ac2ca3c5a35d7..be0dfe7c3344a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -61,20 +61,48 @@ private[sql] object Dataset {
 }
 
 /**
- * :: Experimental ::
- * A distributed collection of data organized into named columns.
+ * A [[Dataset]] is a strongly typed collection of domain-specific objects that can be transformed
+ * in parallel using functional or relational operations. Each Dataset also has an untyped view
+ * called a [[DataFrame]], which is a Dataset of [[Row]].
  *
- * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates
- * a [[DataFrame]] by pointing Spark SQL to a Parquet data set.
+ * Operations available on Datasets are divided into transformations and actions. Transformations
+ * are the ones that produce new Datasets, and actions are the ones that trigger computation and
+ * return results. Example transformations include map, filter, select, and aggregate (`groupBy`).
+ * Example actions include count, show, and writing data out to file systems.
+ *
+ * Datasets are "lazy", i.e. computations are only triggered when an action is invoked. Internally,
+ * a Dataset represents a logical plan that describes the computation required to produce the data.
+ * When an action is invoked, Spark's query optimizer optimizes the logical plan and generates a
+ * physical plan for efficient execution in a parallel and distributed manner. To explore the
+ * logical plan as well as optimized physical plan, use the `explain` function.
+ *
+ * To efficiently support domain-specific objects, an [[Encoder]] is required. The encoder maps
+ * the domain-specific type `T` to Spark's internal type system. For example, given a class `Person`
+ * with two fields, `name` (string) and `age` (int), an encoder is used to tell Spark to generate
+ * code at runtime to serialize the `Person` object into a binary structure. This binary structure
+ * often has a much lower memory footprint and is optimized for efficiency in data processing
+ * (e.g.
in a columnar format). To understand the internal binary representation for data, use the + * `schema` function. + * + * There are typically two ways to create a Dataset. The most common way is by pointing Spark + * to some files on storage systems, using the `read` function available on a `SparkSession`. + * {{{ + * val people = session.read.parquet("...").as[Person] // Scala + * Dataset people = session.read().parquet("...").as(Encoders.bean(Person.class) // Java + * }}} + * + * Datasets can also be created through transformations available on existing Datasets. For example, + * the following creates a new Dataset by applying a filter on the existing one: * {{{ - * val people = sqlContext.read.parquet("...") // in Scala - * DataFrame people = sqlContext.read().parquet("...") // in Java + * val names = people.map(_.name) // in Scala; names is a Dataset[String] + * Dataset names = people.map((Person p) -> p.name, Encoders.STRING) // in Java 8 * }}} * - * Once created, it can be manipulated using the various domain-specific-language (DSL) functions - * defined in: [[DataFrame]] (this class), [[Column]], and [[functions]]. + * Dataset operations can also be untyped, through various domain-specific-language (DSL) + * functions defined in: [[Dataset]] (this class), [[Column]], and [[functions]]. These operations + * are very similar to the operations available in the data frame abstraction in R or Python. * - * To select a column from the data frame, use `apply` method in Scala and `col` in Java. + * To select a column from the Dataset, use `apply` method in Scala and `col` in Java. * {{{ * val ageCol = people("age") // in Scala * Column ageCol = people.col("age") // in Java @@ -89,9 +117,9 @@ private[sql] object Dataset { * * A more concrete example in Scala: * {{{ - * // To create DataFrame using SQLContext - * val people = sqlContext.read.parquet("...") - * val department = sqlContext.read.parquet("...") + * // To create Dataset[Row] using SQLContext + * val people = session.read.parquet("...") + * val department = session.read.parquet("...") * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) @@ -101,9 +129,9 @@ private[sql] object Dataset { * * and in Java: * {{{ - * // To create DataFrame using SQLContext - * DataFrame people = sqlContext.read().parquet("..."); - * DataFrame department = sqlContext.read().parquet("..."); + * // To create Dataset using SQLContext + * Dataset people = session.read().parquet("..."); + * Dataset department = session.read().parquet("..."); * * people.filter("age".gt(30)) * .join(department, people.col("deptId").equalTo(department("id"))) @@ -111,14 +139,16 @@ private[sql] object Dataset { * .agg(avg(people.col("salary")), max(people.col("age"))); * }}} * - * @groupname basic Basic DataFrame functions - * @groupname dfops Language Integrated Queries + * @groupname basic Basic Dataset functions + * @groupname action Actions + * @groupname untypedrel Untyped Language Integrated Relational Queries + * @groupname typedrel Typed Language Integrated Relational Queries + * @groupname func Functional Transformations * @groupname rdd RDD Operations * @groupname output Output Operations - * @groupname action Actions - * @since 1.3.0 + * + * @since 1.6.0 */ -@Experimental class Dataset[T] private[sql]( @transient override val sqlContext: SQLContext, @DeveloperApi @transient override val queryExecution: QueryExecution, @@ -127,7 +157,7 @@ class Dataset[T] private[sql]( queryExecution.assertAnalyzed() - // Note for Spark contributors: 
if adding or updating any action in `DataFrame`, please make sure + // Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { @@ -190,6 +220,7 @@ class Dataset[T] private[sql]( /** * Compose the string representing rows for output + * * @param _numRows Number of rows to show * @param truncate Whether truncate long strings and align cells right */ @@ -222,18 +253,31 @@ class Dataset[T] private[sql]( } /** - * Returns the object itself. + * Converts this strongly typed collection of data to generic Dataframe. In contrast to the + * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] + * objects that allow fields to be accessed by ordinal or name. + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = new Dataset[Row](sqlContext, queryExecution, RowEncoder(schema)) /** * :: Experimental :: - * Converts this [[DataFrame]] to a strongly-typed [[Dataset]] containing objects of the - * specified type, `U`. + * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The + * method used to map columns depend on the type of `U`: + * - When `U` is a class, fields for the class will be mapped to columns of the same name + * (case sensitivity is determined by `spark.sql.caseSensitive`) + * - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will + * be assigned to `_1`). + * - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the + * [[DataFrame]] will be used. + * + * If the schema of the [[Dataset]] does not match the desired `U` type, you can use `select` + * along with `alias` or `as` to rearrange or rename as required. + * * @group basic * @since 1.6.0 */ @@ -241,15 +285,17 @@ class Dataset[T] private[sql]( def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan) /** - * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion - * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: + * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. + * This can be quite convenient in conversion from a RDD of tuples into a [[DataFrame]] with + * meaningful names. For example: * {{{ * val rdd: RDD[(Int, String)] = ... - * rdd.toDF() // this implicit conversion creates a DataFrame with column name _1 and _2 + * rdd.toDF() // this implicit conversion creates a DataFrame with column name `_1` and `_2` * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name" * }}} + * * @group basic - * @since 1.3.0 + * @since 2.0.0 */ @scala.annotation.varargs def toDF(colNames: String*): DataFrame = { @@ -265,16 +311,18 @@ class Dataset[T] private[sql]( } /** - * Returns the schema of this [[DataFrame]]. + * Returns the schema of this [[Dataset]]. + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def schema: StructType = queryExecution.analyzed.schema /** * Prints the schema to the console in a nice tree format. 
+ * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ // scalastyle:off println override def printSchema(): Unit = println(schema.treeString) @@ -282,8 +330,9 @@ class Dataset[T] private[sql]( /** * Prints the plans (logical and physical) to the console for debugging purposes. + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ override def explain(extended: Boolean): Unit = { val explain = ExplainCommand(queryExecution.logical, extended = extended) @@ -296,14 +345,17 @@ class Dataset[T] private[sql]( /** * Prints the physical plan to the console for debugging purposes. - * @since 1.3.0 + * + * @group basic + * @since 1.6.0 */ override def explain(): Unit = explain(extended = false) /** * Returns all column names and their data types as an array. + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def dtypes: Array[(String, String)] = schema.fields.map { field => (field.name, field.dataType.toString) @@ -311,22 +363,24 @@ class Dataset[T] private[sql]( /** * Returns all column names as an array. + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def columns: Array[String] = schema.fields.map(_.name) /** * Returns true if the `collect` and `take` methods can be run locally * (without any Spark executors). + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] /** - * Displays the [[DataFrame]] in a tabular form. Strings more than 20 characters will be - * truncated, and all cells will be aligned right. For example: + * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated, + * and all cells will be aligned right. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -335,34 +389,36 @@ class Dataset[T] private[sql]( * 1983 03 0.410516 0.442194 * 1984 04 0.450090 0.483521 * }}} + * * @param numRows Number of rows to show * * @group action - * @since 1.3.0 + * @since 1.6.0 */ def show(numRows: Int): Unit = show(numRows, truncate = true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters * will be truncated, and all cells will be aligned right. + * * @group action - * @since 1.3.0 + * @since 1.6.0 */ def show(): Unit = show(20) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[Dataset]] in a tabular form. * * @param truncate Whether truncate long strings. If true, strings more than 20 characters will - * be truncated and all cells will be aligned right + * be truncated and all cells will be aligned right * * @group action - * @since 1.5.0 + * @since 1.6.0 */ def show(truncate: Boolean): Unit = show(20, truncate) /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[Dataset]] in a tabular form. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -376,7 +432,7 @@ class Dataset[T] private[sql]( * be truncated and all cells will be aligned right * * @group action - * @since 1.5.0 + * @since 1.6.0 */ // scalastyle:off println def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) @@ -386,11 +442,11 @@ class Dataset[T] private[sql]( * Returns a [[DataFrameNaFunctions]] for working with missing data. * {{{ * // Dropping rows containing any null values. 
- * df.na.drop() + * ds.na.drop() * }}} * - * @group dfops - * @since 1.3.1 + * @group untypedrel + * @since 1.6.0 */ def na: DataFrameNaFunctions = new DataFrameNaFunctions(toDF()) @@ -398,11 +454,11 @@ class Dataset[T] private[sql]( * Returns a [[DataFrameStatFunctions]] for working statistic functions support. * {{{ * // Finding frequent items in column with name 'a'. - * df.stat.freqItems(Seq("a")) + * ds.stat.freqItems(Seq("a")) * }}} * - * @group dfops - * @since 1.4.0 + * @group untypedrel + * @since 1.6.0 */ def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF()) @@ -412,8 +468,9 @@ class Dataset[T] private[sql]( * Note that cartesian joins are very expensive without an extra filter that can be pushed down. * * @param right Right side of the join operation. - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def join(right: DataFrame): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Inner, None) @@ -436,8 +493,9 @@ class Dataset[T] private[sql]( * * @param right Right side of the join operation. * @param usingColumn Name of the column to join on. This column must exist on both sides. - * @group dfops - * @since 1.4.0 + * + * @group untypedrel + * @since 2.0.0 */ def join(right: DataFrame, usingColumn: String): DataFrame = { join(right, Seq(usingColumn)) @@ -460,8 +518,9 @@ class Dataset[T] private[sql]( * * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. - * @group dfops - * @since 1.4.0 + * + * @group untypedrel + * @since 2.0.0 */ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { join(right, usingColumns, "inner") @@ -480,8 +539,9 @@ class Dataset[T] private[sql]( * @param right Right side of the join operation. * @param usingColumns Names of the columns to join on. This columns must exist on both sides. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. - * @group dfops - * @since 1.6.0 + * + * @group untypedrel + * @since 2.0.0 */ def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right @@ -490,41 +550,12 @@ class Dataset[T] private[sql]( Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] - val condition = usingColumns.map { col => - catalyst.expressions.EqualTo( - withPlan(joined.left).resolve(col), - withPlan(joined.right).resolve(col)) - }.reduceLeftOption[catalyst.expressions.BinaryExpression] { (cond, eqTo) => - catalyst.expressions.And(cond, eqTo) - } - - // Project only one of the join columns. 
- val joinedCols = JoinType(joinType) match { - case Inner | LeftOuter | LeftSemi => - usingColumns.map(col => withPlan(joined.left).resolve(col)) - case RightOuter => - usingColumns.map(col => withPlan(joined.right).resolve(col)) - case FullOuter => - usingColumns.map { col => - val leftCol = withPlan(joined.left).resolve(col).toAttribute.withNullability(true) - val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true) - Alias(Coalesce(Seq(leftCol, rightCol)), col)() - } - case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.") - } - // The nullability of output of joined could be different than original column, - // so we can only compare them by exprId - val joinRefs = AttributeSet(condition.toSeq.flatMap(_.references)) - val resultCols = joinedCols ++ joined.output.filterNot(joinRefs.contains(_)) withPlan { - Project( - resultCols, - Join( - joined.left, - joined.right, - joinType = JoinType(joinType), - condition) - ) + Join( + joined.left, + joined.right, + UsingJoin(JoinType(joinType), usingColumns.map(UnresolvedAttribute(_))), + None) } } @@ -536,8 +567,9 @@ class Dataset[T] private[sql]( * df1.join(df2, $"df1Key" === $"df2Key") * df1.join(df2).where($"df1Key" === $"df2Key") * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def join(right: DataFrame, joinExprs: Column): DataFrame = join(right, joinExprs, "inner") @@ -558,8 +590,9 @@ class Dataset[T] private[sql]( * @param right Right side of the join. * @param joinExprs Join expression. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { // Note that in this function, we introduce a hack in the case of self-join to automatically @@ -605,6 +638,7 @@ class Dataset[T] private[sql]( } /** + * :: Experimental :: * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to * true. * @@ -619,8 +653,11 @@ class Dataset[T] private[sql]( * @param other Right side of the join. * @param condition Join expression. * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * + * @group typedrel * @since 1.6.0 */ + @Experimental def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { val left = this.logicalPlan val right = other.logicalPlan @@ -649,24 +686,28 @@ class Dataset[T] private[sql]( } /** + * :: Experimental :: * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair * where `condition` evaluates to true. * * @param other Right side of the join. * @param condition Join expression. + * + * @group typedrel * @since 1.6.0 */ + @Experimental def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } /** - * Returns a new [[DataFrame]] with each partition sorted by the given expressions. + * Returns a new [[Dataset]] with each partition sorted by the given expressions. * * This is the same operation as "SORT BY" in SQL (Hive QL). * - * @group dfops - * @since 1.6.0 + * @group typedrel + * @since 2.0.0 */ @scala.annotation.varargs def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = { @@ -674,12 +715,12 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] with each partition sorted by the given expressions. 
+ * Returns a new [[Dataset]] with each partition sorted by the given expressions. * * This is the same operation as "SORT BY" in SQL (Hive QL). * - * @group dfops - * @since 1.6.0 + * @group typedrel + * @since 2.0.0 */ @scala.annotation.varargs def sortWithinPartitions(sortExprs: Column*): Dataset[T] = { @@ -687,15 +728,16 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order. + * Returns a new [[Dataset]] sorted by the specified column, all in ascending order. * {{{ * // The following 3 are equivalent - * df.sort("sortcol") - * df.sort($"sortcol") - * df.sort($"sortcol".asc) + * ds.sort("sortcol") + * ds.sort($"sortcol") + * ds.sort($"sortcol".asc) * }}} - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 2.0.0 */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): Dataset[T] = { @@ -703,12 +745,13 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] sorted by the given expressions. For example: + * Returns a new [[Dataset]] sorted by the given expressions. For example: * {{{ - * df.sort($"col1", $"col2".desc) + * ds.sort($"col1", $"col2".desc) * }}} - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 2.0.0 */ @scala.annotation.varargs def sort(sortExprs: Column*): Dataset[T] = { @@ -716,19 +759,21 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] sorted by the given expressions. + * Returns a new [[Dataset]] sorted by the given expressions. * This is an alias of the `sort` function. - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 2.0.0 */ @scala.annotation.varargs def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*) /** - * Returns a new [[DataFrame]] sorted by the given expressions. + * Returns a new [[Dataset]] sorted by the given expressions. * This is an alias of the `sort` function. - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 2.0.0 */ @scala.annotation.varargs def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*) @@ -736,16 +781,18 @@ class Dataset[T] private[sql]( /** * Selects column based on the column name and return it as a [[Column]]. * Note that the column name can also reference to a nested column like `a.b`. - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def apply(colName: String): Column = col(colName) /** * Selects column based on the column name and return it as a [[Column]]. * Note that the column name can also reference to a nested column like `a.b`. - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def col(colName: String): Column = colName match { case "*" => @@ -756,42 +803,47 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] with an alias set. - * @group dfops - * @since 1.3.0 + * Returns a new [[Dataset]] with an alias set. + * + * @group typedrel + * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { SubqueryAlias(alias, logicalPlan) } /** - * (Scala-specific) Returns a new [[DataFrame]] with an alias set. - * @group dfops - * @since 1.3.0 + * (Scala-specific) Returns a new [[Dataset]] with an alias set. + * + * @group typedrel + * @since 2.0.0 */ def as(alias: Symbol): Dataset[T] = as(alias.name) /** - * Returns a new [[DataFrame]] with an alias set. Same as `as`. - * @group dfops - * @since 1.6.0 + * Returns a new [[Dataset]] with an alias set. Same as `as`. 
+ * + * @group typedrel + * @since 2.0.0 */ def alias(alias: String): Dataset[T] = as(alias) /** - * (Scala-specific) Returns a new [[DataFrame]] with an alias set. Same as `as`. - * @group dfops - * @since 1.6.0 + * (Scala-specific) Returns a new [[Dataset]] with an alias set. Same as `as`. + * + * @group typedrel + * @since 2.0.0 */ def alias(alias: Symbol): Dataset[T] = as(alias) /** * Selects a set of column based expressions. * {{{ - * df.select($"colA", $"colB" + 1) + * ds.select($"colA", $"colB" + 1) * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { @@ -804,11 +856,12 @@ class Dataset[T] private[sql]( * * {{{ * // The following two are equivalent: - * df.select("colA", "colB") - * df.select($"colA", $"colB") + * ds.select("colA", "colB") + * ds.select($"colA", $"colB") * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) @@ -819,11 +872,12 @@ class Dataset[T] private[sql]( * * {{{ * // The following are equivalent: - * df.selectExpr("colA", "colB as newName", "abs(colC)") - * df.select(expr("colA"), expr("colB as newName"), expr("abs(colC)")) + * ds.selectExpr("colA", "colB as newName", "abs(colC)") + * ds.select(expr("colA"), expr("colB as newName"), expr("abs(colC)")) * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { @@ -833,14 +887,18 @@ class Dataset[T] private[sql]( } /** + * :: Experimental :: * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. * * {{{ * val ds = Seq(1, 2, 3).toDS() * val newDS = ds.select(expr("value + 1").as[Int]) * }}} + * + * @group typedrel * @since 1.6.0 */ + @Experimental def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { new Dataset[U1]( sqlContext, @@ -867,16 +925,24 @@ class Dataset[T] private[sql]( } /** + * :: Experimental :: * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * + * @group typedrel * @since 1.6.0 */ + @Experimental def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** + * :: Experimental :: * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * + * @group typedrel * @since 1.6.0 */ + @Experimental def select[U1, U2, U3]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -884,9 +950,13 @@ class Dataset[T] private[sql]( selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** + * :: Experimental :: * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * + * @group typedrel * @since 1.6.0 */ + @Experimental def select[U1, U2, U3, U4]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -895,9 +965,13 @@ class Dataset[T] private[sql]( selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** + * :: Experimental :: * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * + * @group typedrel * @since 1.6.0 */ + @Experimental def select[U1, U2, U3, U4, U5]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], @@ -910,11 +984,12 @@ class Dataset[T] private[sql]( * Filters rows using the given condition. 
* {{{ * // The following are equivalent: - * peopleDf.filter($"age" > 15) - * peopleDf.where($"age" > 15) + * peopleDs.filter($"age" > 15) + * peopleDs.where($"age" > 15) * }}} - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 1.6.0 */ def filter(condition: Column): Dataset[T] = withTypedPlan { Filter(condition.expr, logicalPlan) @@ -923,10 +998,11 @@ class Dataset[T] private[sql]( /** * Filters rows using the given SQL expression. * {{{ - * peopleDf.filter("age > 15") + * peopleDs.filter("age > 15") * }}} - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 1.6.0 */ def filter(conditionExpr: String): Dataset[T] = { filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) @@ -936,145 +1012,165 @@ class Dataset[T] private[sql]( * Filters rows using the given condition. This is an alias for `filter`. * {{{ * // The following are equivalent: - * peopleDf.filter($"age" > 15) - * peopleDf.where($"age" > 15) + * peopleDs.filter($"age" > 15) + * peopleDs.where($"age" > 15) * }}} - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 1.6.0 */ def where(condition: Column): Dataset[T] = filter(condition) /** * Filters rows using the given SQL expression. * {{{ - * peopleDf.where("age > 15") + * peopleDs.where("age > 15") * }}} - * @group dfops - * @since 1.5.0 + * + * @group typedrel + * @since 1.6.0 */ def where(conditionExpr: String): Dataset[T] = { filter(Column(sqlContext.sessionState.sqlParser.parseExpression(conditionExpr))) } /** - * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * Groups the [[Dataset]] using the specified columns, so we can run aggregation on them. See + * [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns grouped by department. - * df.groupBy($"department").avg() + * ds.groupBy($"department").avg() * * // Compute the max age and average salary, grouped by department and gender. - * df.groupBy($"department", $"gender").agg(Map( + * ds.groupBy($"department", $"gender").agg(Map( * "salary" -> "avg", * "age" -> "max" * )) * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): GroupedData = { - GroupedData(toDF(), cols.map(_.expr), GroupedData.GroupByType) + def groupBy(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) } /** - * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, + * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns rolluped by department and group. - * df.rollup($"department", $"group").avg() + * ds.rollup($"department", $"group").avg() * * // Compute the max age and average salary, rolluped by department and gender. 
- * df.rollup($"department", $"gender").agg(Map( + * ds.rollup($"department", $"gender").agg(Map( * "salary" -> "avg", * "age" -> "max" * )) * }}} - * @group dfops - * @since 1.4.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs - def rollup(cols: Column*): GroupedData = { - GroupedData(toDF(), cols.map(_.expr), GroupedData.RollupType) + def rollup(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.RollupType) } /** - * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, + * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ * // Compute the average for all numeric columns cubed by department and group. - * df.cube($"department", $"group").avg() + * ds.cube($"department", $"group").avg() * * // Compute the max age and average salary, cubed by department and gender. - * df.cube($"department", $"gender").agg(Map( + * ds.cube($"department", $"gender").agg(Map( * "salary" -> "avg", * "age" -> "max" * )) * }}} - * @group dfops - * @since 1.4.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs - def cube(cols: Column*): GroupedData = GroupedData(toDF(), cols.map(_.expr), GroupedData.CubeType) + def cube(cols: Column*): RelationalGroupedDataset = { + RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType) + } /** - * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * Groups the [[Dataset]] using the specified columns, so that we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of groupBy that can only group by existing columns using column names * (i.e. cannot construct expressions). * * {{{ * // Compute the average for all numeric columns grouped by department. - * df.groupBy("department").avg() + * ds.groupBy("department").avg() * * // Compute the max age and average salary, grouped by department and gender. - * df.groupBy($"department", $"gender").agg(Map( + * ds.groupBy($"department", $"gender").agg(Map( * "salary" -> "avg", * "age" -> "max" * )) * }}} - * @group dfops - * @since 1.3.0 + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(col1: String, cols: String*): GroupedData = { + def groupBy(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.GroupByType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.GroupByType) } /** + * :: Experimental :: * (Scala-specific) * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. + * + * @group action * @since 1.6.0 */ + @Experimental def reduce(func: (T, T) => T): T = rdd.reduce(func) /** + * :: Experimental :: * (Java-specific) * Reduces the elements of this Dataset using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. 
+ * + * @group action * @since 1.6.0 */ + @Experimental def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** + * :: Experimental :: * (Scala-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. - * @since 1.6.0 + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. + * + * @group typedrel + * @since 2.0.0 */ - def groupByKey[K: Encoder](func: T => K): GroupedDataset[K, T] = { + @Experimental + def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) - new GroupedDataset( + new KeyValueGroupedDataset( encoderFor[K], encoderFor[T], executed, @@ -1083,11 +1179,16 @@ class Dataset[T] private[sql]( } /** - * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions. - * @since 1.6.0 + * :: Experimental :: + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given [[Column]] + * expressions. + * + * @group typedrel + * @since 2.0.0 */ + @Experimental @scala.annotation.varargs - def groupByKey(cols: Column*): GroupedDataset[Row, T] = { + def groupByKey(cols: Column*): KeyValueGroupedDataset[Row, T] = { val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) val withKey = Project(withKeyColumns, logicalPlan) val executed = sqlContext.executePlan(withKey) @@ -1095,7 +1196,7 @@ class Dataset[T] private[sql]( val dataAttributes = executed.analyzed.output.dropRight(cols.size) val keyAttributes = executed.analyzed.output.takeRight(cols.size) - new GroupedDataset( + new KeyValueGroupedDataset( RowEncoder(keyAttributes.toStructType), encoderFor[T], executed, @@ -1104,133 +1205,150 @@ class Dataset[T] private[sql]( } /** + * :: Experimental :: * (Java-specific) - * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`. - * @since 1.6.0 + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. + * + * @group typedrel + * @since 2.0.0 */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + @Experimental + def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) /** - * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, + * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of rollup that can only group by existing columns using column names * (i.e. cannot construct expressions). * * {{{ * // Compute the average for all numeric columns rolluped by department and group. - * df.rollup("department", "group").avg() + * ds.rollup("department", "group").avg() * * // Compute the max age and average salary, rolluped by department and gender. 
- * df.rollup($"department", $"gender").agg(Map( + * ds.rollup($"department", $"gender").agg(Map( * "salary" -> "avg", * "age" -> "max" * )) * }}} - * @group dfops - * @since 1.4.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs - def rollup(col1: String, cols: String*): GroupedData = { + def rollup(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.RollupType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.RollupType) } /** - * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, + * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, * so we can run aggregation on them. - * See [[GroupedData]] for all the available aggregate functions. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of cube that can only group by existing columns using column names * (i.e. cannot construct expressions). * * {{{ * // Compute the average for all numeric columns cubed by department and group. - * df.cube("department", "group").avg() + * ds.cube("department", "group").avg() * * // Compute the max age and average salary, cubed by department and gender. - * df.cube($"department", $"gender").agg(Map( + * ds.cube($"department", $"gender").agg(Map( * "salary" -> "avg", * "age" -> "max" * )) * }}} - * @group dfops - * @since 1.4.0 + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs - def cube(col1: String, cols: String*): GroupedData = { + def cube(col1: String, cols: String*): RelationalGroupedDataset = { val colNames: Seq[String] = col1 +: cols - GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.CubeType) + RelationalGroupedDataset( + toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.CubeType) } /** - * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. + * (Scala-specific) Aggregates on the entire [[Dataset]] without groups. * {{{ - * // df.agg(...) is a shorthand for df.groupBy().agg(...) - * df.agg("age" -> "max", "salary" -> "avg") - * df.groupBy().agg("age" -> "max", "salary" -> "avg") + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg("age" -> "max", "salary" -> "avg") + * ds.groupBy().agg("age" -> "max", "salary" -> "avg") * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { groupBy().agg(aggExpr, aggExprs : _*) } /** - * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. + * (Scala-specific) Aggregates on the entire [[Dataset]] without groups. * {{{ - * // df.agg(...) is a shorthand for df.groupBy().agg(...) - * df.agg(Map("age" -> "max", "salary" -> "avg")) - * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(Map("age" -> "max", "salary" -> "avg")) + * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) /** - * (Java-specific) Aggregates on the entire [[DataFrame]] without groups. + * (Java-specific) Aggregates on the entire [[Dataset]] without groups. * {{{ - * // df.agg(...) 
is a shorthand for df.groupBy().agg(...) - * df.agg(Map("age" -> "max", "salary" -> "avg")) - * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(Map("age" -> "max", "salary" -> "avg")) + * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) /** - * Aggregates on the entire [[DataFrame]] without groups. + * Aggregates on the entire [[Dataset]] without groups. * {{{ - * // df.agg(...) is a shorthand for df.groupBy().agg(...) - * df.agg(max($"age"), avg($"salary")) - * df.groupBy().agg(max($"age"), avg($"salary")) + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(max($"age"), avg($"salary")) + * ds.groupBy().agg(max($"age"), avg($"salary")) * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) /** - * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function - * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. - * @group dfops - * @since 1.3.0 + * Returns a new [[Dataset]] by taking the first `n` rows. The difference between this function + * and `head` is that `head` is an action and returns an array (by triggering query execution) + * while `limit` returns a new [[Dataset]]. + * + * @group typedrel + * @since 2.0.0 */ def limit(n: Int): Dataset[T] = withTypedPlan { Limit(Literal(n), logicalPlan) } /** - * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. + * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset. * This is equivalent to `UNION ALL` in SQL. - * @group dfops - * @since 1.3.0 + * + * To do a SQL-style set union (that does deduplication of elements), use this function followed + * by a [[distinct]]. + * + * @group typedrel + * @since 2.0.0 */ def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan { // This breaks caching, but it's usually ok because it addresses a very specific use case: @@ -1238,62 +1356,90 @@ class Dataset[T] private[sql]( CombineUnions(Union(logicalPlan, other.logicalPlan)) } + /** + * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset. + * This is equivalent to `UNION ALL` in SQL. + * + * @group typedrel + * @since 2.0.0 + */ def union(other: Dataset[T]): Dataset[T] = unionAll(other) /** - * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. + * Returns a new [[Dataset]] containing rows only in both this Dataset and another Dataset. * This is equivalent to `INTERSECT` in SQL. - * @group dfops - * @since 1.3.0 + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * @group typedrel + * @since 1.6.0 */ def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan { Intersect(logicalPlan, other.logicalPlan) } /** - * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. + * Returns a new [[Dataset]] containing rows in this Dataset but not in another Dataset. * This is equivalent to `EXCEPT` in SQL. 
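For the union/intersect semantics spelled out above, a minimal runnable sketch (the `local[2]` master, the data, and the column name `n` are invented; any SQLContext behaves the same way):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("dataset-sketch").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val a = Seq(1, 2, 2, 3).toDS()
val b = Seq(3, 4).toDS()

a.union(b).show()             // UNION ALL semantics: duplicates survive
a.union(b).distinct().show()  // SQL-style UNION: follow with distinct() to deduplicate
a.intersect(b).show()         // INTERSECT

// ds.agg(...) is shorthand for ds.groupBy().agg(...)
a.toDF("n").agg("n" -> "max", "n" -> "avg").show()
```

The sketches further down reuse `sc`, `sqlContext` and these imported implicits.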
- * @group dfops - * @since 1.3.0 + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * @group typedrel + * @since 2.0.0 */ def except(other: Dataset[T]): Dataset[T] = withTypedPlan { Except(logicalPlan, other.logicalPlan) } + /** + * Returns a new [[Dataset]] containing rows in this Dataset but not in another Dataset. + * This is equivalent to `EXCEPT` in SQL. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * @group typedrel + * @since 2.0.0 + */ def subtract(other: Dataset[T]): Dataset[T] = except(other) /** - * Returns a new [[DataFrame]] by sampling a fraction of rows. + * Returns a new [[Dataset]] by sampling a fraction of rows. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. * @param seed Seed for sampling. - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 1.6.0 */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan { Sample(0.0, fraction, withReplacement, seed, logicalPlan)() } /** - * Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed. + * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. - * @group dfops - * @since 1.3.0 + * + * @group typedrel + * @since 1.6.0 */ def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { sample(withReplacement, fraction, Utils.random.nextLong) } /** - * Randomly splits this [[DataFrame]] with the provided weights. + * Randomly splits this [[Dataset]] with the provided weights. * * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. - * @group dfops - * @since 1.4.0 + * + * @group typedrel + * @since 2.0.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its @@ -1310,29 +1456,29 @@ class Dataset[T] private[sql]( } /** - * Randomly splits this [[DataFrame]] with the provided weights. + * Randomly splits this [[Dataset]] with the provided weights. * * @param weights weights for splits, will be normalized if they don't sum to 1. - * @group dfops - * @since 1.4.0 + * @group typedrel + * @since 2.0.0 */ def randomSplit(weights: Array[Double]): Array[Dataset[T]] = { randomSplit(weights, Utils.random.nextLong) } /** - * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api. + * Randomly splits this [[Dataset]] with the provided weights. Provided for the Python Api. * * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. - * @group dfops */ private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = { randomSplit(weights.toArray, seed) } /** - * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more + * :: Experimental :: + * (Scala-specific) Returns a new [[Dataset]] where each row has been expanded to zero or more * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. 
The columns of * the input row are implicitly joined with each row that is output by the function. * @@ -1341,18 +1487,20 @@ class Dataset[T] private[sql]( * * {{{ * case class Book(title: String, words: String) - * val df: RDD[Book] + * val ds: Dataset[Book] * * case class Word(word: String) - * val allWords = df.explode('words) { + * val allWords = ds.explode('words) { * case Row(words: String) => words.split(" ").map(Word(_)) * } * * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ + @Experimental def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] @@ -1372,16 +1520,19 @@ class Dataset[T] private[sql]( } /** - * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero + * :: Experimental :: + * (Scala-specific) Returns a new [[Dataset]] where a single column has been expanded to zero * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All * columns of the input row are implicitly joined with each value that is output by the function. * * {{{ - * df.explode("words", "word") {words: String => words.split(" ")} + * ds.explode("words", "word") {words: String => words.split(" ")} * }}} - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ + @Experimental def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType @@ -1401,13 +1552,12 @@ class Dataset[T] private[sql]( } } - ///////////////////////////////////////////////////////////////////////////// - /** - * Returns a new [[DataFrame]] by adding a column or replacing the existing column that has + * Returns a new [[Dataset]] by adding a column or replacing the existing column that has * the same name. - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def withColumn(colName: String, col: Column): DataFrame = { val resolver = sqlContext.sessionState.analyzer.resolver @@ -1428,7 +1578,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] by adding a column with metadata. + * Returns a new [[Dataset]] by adding a column with metadata. */ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { val resolver = sqlContext.sessionState.analyzer.resolver @@ -1449,10 +1599,11 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] with a column renamed. + * Returns a new [[Dataset]] with a column renamed. * This is a no-op if schema doesn't contain existingName. - * @group dfops - * @since 1.3.0 + * + * @group untypedrel + * @since 2.0.0 */ def withColumnRenamed(existingName: String, newName: String): DataFrame = { val resolver = sqlContext.sessionState.analyzer.resolver @@ -1473,20 +1624,22 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] with a column dropped. + * Returns a new [[Dataset]] with a column dropped. * This is a no-op if schema doesn't contain column name. - * @group dfops - * @since 1.4.0 + * + * @group untypedrel + * @since 2.0.0 */ def drop(colName: String): DataFrame = { drop(Seq(colName) : _*) } /** - * Returns a new [[DataFrame]] with columns dropped. + * Returns a new [[Dataset]] with columns dropped. * This is a no-op if schema doesn't contain column name(s). 
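A sketch of the untyped column operations above; the `Person` case class and the column names are made up, and `sqlContext` with its implicits from the first sketch is assumed:

```scala
import org.apache.spark.sql.functions.col

case class Person(name: String, age: Int)
val people = Seq(Person("alice", 30), Person("bob", 17)).toDS()

val reshaped = people
  .withColumn("is_adult", col("age") >= 18)   // add, or replace, a column
  .withColumnRenamed("name", "first_name")    // no-op if "name" were absent
  .drop("age")                                // no-op if "age" were absent

reshaped.printSchema()   // first_name, is_adult; the result is a DataFrame, not a Dataset[Person]
```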
- * @group dfops - * @since 1.6.0 + * + * @group untypedrel + * @since 2.0.0 */ @scala.annotation.varargs def drop(colNames: String*): DataFrame = { @@ -1501,12 +1654,13 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] with a column dropped. + * Returns a new [[Dataset]] with a column dropped. * This version of drop accepts a Column rather than a name. - * This is a no-op if the DataFrame doesn't have a column + * This is a no-op if the Datasetdoesn't have a column * with an equivalent expression. - * @group dfops - * @since 1.4.1 + * + * @group untypedrel + * @since 2.0.0 */ def drop(col: Column): DataFrame = { val expression = col match { @@ -1523,19 +1677,20 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. + * Returns a new [[Dataset]] that contains only the unique rows from this [[Dataset]]. * This is an alias for `distinct`. - * @group dfops - * @since 1.4.0 + * + * @group typedrel + * @since 2.0.0 */ def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns) /** - * (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only + * (Scala-specific) Returns a new [[Dataset]] with duplicate rows removed, considering only * the subset of columns. * - * @group dfops - * @since 1.4.0 + * @group typedrel + * @since 2.0.0 */ def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { val groupCols = colNames.map(resolve) @@ -1551,11 +1706,11 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] with duplicate rows removed, considering only + * Returns a new [[Dataset]] with duplicate rows removed, considering only * the subset of columns. * - * @group dfops - * @since 1.4.0 + * @group typedrel + * @since 2.0.0 */ def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq) @@ -1564,11 +1719,11 @@ class Dataset[T] private[sql]( * If no columns are given, this function computes statistics for all numerical columns. * * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to + * backward compatibility of the schema of the resulting [[Dataset]]. If you want to * programmatically compute summary statistics, use the `agg` function instead. * * {{{ - * df.describe("age", "height").show() + * ds.describe("age", "height").show() * * // output: * // summary age height @@ -1580,7 +1735,7 @@ class Dataset[T] private[sql]( * }}} * * @group action - * @since 1.3.1 + * @since 1.6.0 */ @scala.annotation.varargs def describe(cols: String*): DataFrame = withPlan { @@ -1625,7 +1780,7 @@ class Dataset[T] private[sql]( * all the data is loaded into the driver's memory. * * @group action - * @since 1.3.0 + * @since 1.6.0 */ def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df => df.collect(needCallback = false) @@ -1634,64 +1789,86 @@ class Dataset[T] private[sql]( /** * Returns the first row. * @group action - * @since 1.3.0 + * @since 1.6.0 */ def head(): T = head(1).head /** * Returns the first row. Alias for head(). * @group action - * @since 1.3.0 + * @since 1.6.0 */ def first(): T = head() /** * Concise syntax for chaining custom transformations. * {{{ - * def featurize(ds: DataFrame) = ... + * def featurize(ds: Dataset[T]): Dataset[U] = ... * - * df + * ds * .transform(featurize) * .transform(...) 
* }}} + * + * @group func * @since 1.6.0 */ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) /** + * :: Experimental :: * (Scala-specific) * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * + * @group func * @since 1.6.0 */ + @Experimental def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) /** + * :: Experimental :: * (Java-specific) * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * + * @group func * @since 1.6.0 */ + @Experimental def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) /** + * :: Experimental :: * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * + * @group func * @since 1.6.0 */ + @Experimental def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) /** + * :: Experimental :: * (Java-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * + * @group func * @since 1.6.0 */ + @Experimental def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = map(t => func.call(t))(encoder) /** + * :: Experimental :: * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * + * @group func * @since 1.6.0 */ + @Experimental def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { new Dataset[U]( sqlContext, @@ -1700,30 +1877,42 @@ class Dataset[T] private[sql]( } /** + * :: Experimental :: * (Java-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * + * @group func * @since 1.6.0 */ + @Experimental def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala mapPartitions(func)(encoder) } /** + * :: Experimental :: * (Scala-specific) * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], * and then flattening the results. + * + * @group func * @since 1.6.0 */ + @Experimental def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) /** + * :: Experimental :: * (Java-specific) * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], * and then flattening the results. + * + * @group func * @since 1.6.0 */ + @Experimental def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { val func: (T) => Iterator[U] = x => f.call(x).asScala flatMap(func)(encoder) @@ -1731,8 +1920,9 @@ class Dataset[T] private[sql]( /** * Applies a function `f` to all rows. - * @group rdd - * @since 1.3.0 + * + * @group action + * @since 1.6.0 */ def foreach(f: T => Unit): Unit = withNewExecutionId { rdd.foreach(f) @@ -1741,14 +1931,17 @@ class Dataset[T] private[sql]( /** * (Java-specific) * Runs `func` on each element of this [[Dataset]]. + * + * @group action * @since 1.6.0 */ def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) /** - * Applies a function f to each partition of this [[DataFrame]]. - * @group rdd - * @since 1.3.0 + * Applies a function f to each partition of this [[Dataset]]. + * + * @group action + * @since 1.6.0 */ def foreachPartition(f: Iterator[T] => Unit): Unit = withNewExecutionId { rdd.foreachPartition(f) @@ -1757,24 +1950,26 @@ class Dataset[T] private[sql]( /** * (Java-specific) * Runs `func` on each partition of this [[Dataset]]. 
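For the experimental typed operators above, a small word-count style sketch (the data and the `tokenize` helper are invented; same `sqlContext`/implicits as before):

```scala
import org.apache.spark.sql.Dataset

val lines = Seq("spark sql", "structured data processing").toDS()

def tokenize(ds: Dataset[String]): Dataset[String] = ds.flatMap(_.split(" "))

val counted = lines
  .transform(tokenize)      // transform() keeps custom steps chainable
  .filter(_.nonEmpty)
  .map(w => (w, 1))         // Dataset[(String, Int)]
  .groupByKey(_._1)
  .count()                  // Dataset[(String, Long)]

counted.show()
```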
+ * + * @group action * @since 1.6.0 */ def foreachPartition(func: ForeachPartitionFunction[T]): Unit = foreachPartition(it => func.call(it.asJava)) /** - * Returns the first `n` rows in the [[DataFrame]]. + * Returns the first `n` rows in the [[Dataset]]. * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. * * @group action - * @since 1.3.0 + * @since 1.6.0 */ def take(n: Int): Array[T] = head(n) /** - * Returns the first `n` rows in the [[DataFrame]] as a list. + * Returns the first `n` rows in the [[Dataset]] as a list. * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. @@ -1785,7 +1980,7 @@ class Dataset[T] private[sql]( def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*) /** - * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. + * Returns an array that contains all of [[Row]]s in this [[Dataset]]. * * Running collect requires moving all the data into the application's driver process, and * doing so on a very large dataset can crash the driver process with OutOfMemoryError. @@ -1793,18 +1988,18 @@ class Dataset[T] private[sql]( * For Java API, use [[collectAsList]]. * * @group action - * @since 1.3.0 + * @since 1.6.0 */ def collect(): Array[T] = collect(needCallback = true) /** - * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. + * Returns a Java list that contains all of [[Row]]s in this [[Dataset]]. * * Running collect requires moving all the data into the application's driver process, and * doing so on a very large dataset can crash the driver process with OutOfMemoryError. * * @group action - * @since 1.3.0 + * @since 1.6.0 */ def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ => withNewExecutionId { @@ -1826,31 +2021,32 @@ class Dataset[T] private[sql]( } /** - * Returns the number of rows in the [[DataFrame]]. + * Returns the number of rows in the [[Dataset]]. * @group action - * @since 1.3.0 + * @since 1.6.0 */ def count(): Long = withCallback("count", groupBy().count()) { df => df.collect(needCallback = false).head.getLong(0) } /** - * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. - * @group dfops - * @since 1.3.0 + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * + * @group typedrel + * @since 1.6.0 */ def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { Repartition(numPartitions, shuffle = true, logicalPlan) } /** - * Returns a new [[DataFrame]] partitioned by the given partitioning expressions into - * `numPartitions`. The resulting DataFrame is hash partitioned. + * Returns a new [[Dataset]] partitioned by the given partitioning expressions into + * `numPartitions`. The resulting Datasetis hash partitioned. * * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). * - * @group dfops - * @since 1.6.0 + * @group typedrel + * @since 2.0.0 */ @scala.annotation.varargs def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan { @@ -1858,13 +2054,13 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] partitioned by the given partitioning expressions preserving - * the existing number of partitions. The resulting DataFrame is hash partitioned. 
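A sketch of the partitioning and driver-side actions above, reusing the hypothetical `people` Dataset from the earlier sketch:

```scala
// Hash-partition by a column with exactly 8 partitions, i.e. DISTRIBUTE BY name
val byName = people.repartition(8, $"name")
println(byName.rdd.getNumPartitions)   // 8

val firstOne = people.take(1)          // Array[Person], materialized on the driver
println(people.count())                // triggers a job and returns a Long
```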
+ * Returns a new [[Dataset]] partitioned by the given partitioning expressions preserving + * the existing number of partitions. The resulting Datasetis hash partitioned. * * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). * - * @group dfops - * @since 1.6.0 + * @group typedrel + * @since 2.0.0 */ @scala.annotation.varargs def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan { @@ -1872,29 +2068,35 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. + * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of * the 100 new partitions will claim 10 of the current partitions. + * * @group rdd - * @since 1.4.0 + * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { Repartition(numPartitions, shuffle = false, logicalPlan) } /** - * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. + * Returns a new [[Dataset]] that contains only the unique rows from this [[Dataset]]. * This is an alias for `dropDuplicates`. - * @group dfops - * @since 1.3.0 + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * @group typedrel + * @since 2.0.0 */ def distinct(): Dataset[T] = dropDuplicates() /** - * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def persist(): this.type = { sqlContext.cacheManager.cacheQuery(this) @@ -1902,19 +2104,21 @@ class Dataset[T] private[sql]( } /** - * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def cache(): this.type = persist() /** - * Persist this [[DataFrame]] with the given storage level. + * Persist this [[Dataset]] with the given storage level. * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, * `MEMORY_AND_DISK_2`, etc. + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def persist(newLevel: StorageLevel): this.type = { sqlContext.cacheManager.cacheQuery(this, None, newLevel) @@ -1922,10 +2126,12 @@ class Dataset[T] private[sql]( } /** - * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * * @param blocking Whether to block until all blocks are deleted. + * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { sqlContext.cacheManager.tryUncacheQuery(this, blocking) @@ -1933,51 +2139,47 @@ class Dataset[T] private[sql]( } /** - * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. 
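Likewise for the caching and partition-count methods, again on the hypothetical `people` Dataset:

```scala
import org.apache.spark.storage.StorageLevel

people.persist(StorageLevel.MEMORY_AND_DISK_SER)
people.count()                    // first action populates the cache
people.count()                    // answered from the cached data

val single = people.coalesce(1)   // narrow dependency, no shuffle
people.unpersist(blocking = false)
```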
+ * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def unpersist(): this.type = unpersist(blocking = false) - ///////////////////////////////////////////////////////////////////////////// - // I/O - ///////////////////////////////////////////////////////////////////////////// - /** - * Represents the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. Note that the RDD is + * Represents the content of the [[Dataset]] as an [[RDD]] of [[Row]]s. Note that the RDD is * memoized. Once called, it won't change even if you change any query planning related Spark SQL * configurations (e.g. `spark.sql.shuffle.partitions`). + * * @group rdd - * @since 1.3.0 + * @since 1.6.0 */ lazy val rdd: RDD[T] = { - // use a local variable to make sure the map closure doesn't capture the whole DataFrame - val schema = this.schema queryExecution.toRdd.mapPartitions { rows => rows.map(boundTEncoder.fromRow) } } /** - * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. + * Returns the content of the [[Dataset]] as a [[JavaRDD]] of [[Row]]s. * @group rdd - * @since 1.3.0 + * @since 1.6.0 */ def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD() /** - * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. + * Returns the content of the [[Dataset]] as a [[JavaRDD]] of [[Row]]s. * @group rdd - * @since 1.3.0 + * @since 1.6.0 */ def javaRDD: JavaRDD[T] = toJavaRDD /** - * Registers this [[DataFrame]] as a temporary table using the given name. The lifetime of this - * temporary table is tied to the [[SQLContext]] that was used to create this DataFrame. + * Registers this [[Dataset]] as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SQLContext]] that was used to create this Dataset. * * @group basic - * @since 1.3.0 + * @since 1.6.0 */ def registerTempTable(tableName: String): Unit = { sqlContext.registerDataFrameAsTable(toDF(), tableName) @@ -1985,22 +2187,21 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Interface for saving the content of the [[DataFrame]] out into external storage or streams. + * Interface for saving the content of the [[Dataset]] out into external storage or streams. * * @group output - * @since 1.4.0 + * @since 1.6.0 */ @Experimental def write: DataFrameWriter = new DataFrameWriter(toDF()) /** - * Returns the content of the [[DataFrame]] as a RDD of JSON strings. - * @group rdd - * @since 1.3.0 + * Returns the content of the [[Dataset]] as a Dataset of JSON strings. + * @since 2.0.0 */ def toJSON: Dataset[String] = { val rowSchema = this.schema - val rdd = queryExecution.toRdd.mapPartitions { iter => + val rdd: RDD[String] = queryExecution.toRdd.mapPartitions { iter => val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) @@ -2022,14 +2223,17 @@ class Dataset[T] private[sql]( } } } - import sqlContext.implicits._ - rdd.toDS + import sqlContext.implicits.newStringEncoder + sqlContext.createDataset(rdd) } /** - * Returns a best-effort snapshot of the files that compose this DataFrame. This method simply + * Returns a best-effort snapshot of the files that compose this Dataset. This method simply * asks each constituent BaseRelation for its respective files and takes the union of all results. * Depending on the source relations, this may not find all input files. Duplicates are removed. 
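A sketch of the interchange methods above; the temp-table name and the commented-out output path are hypothetical, and `people`/`sqlContext` come from the earlier sketches:

```scala
people.registerTempTable("people_tmp")
sqlContext.sql("SELECT name FROM people_tmp WHERE age >= 18").show()

people.toJSON.show()                  // now a Dataset[String] rather than an RDD[String]
val asRdd = people.rdd                // memoized: later accesses reuse the same RDD
// people.write.json("/tmp/people")   // hypothetical path; goes through DataFrameWriter
```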
+ * + * @group basic + * @since 2.0.0 */ def inputFiles: Array[String] = { val files: Seq[String] = logicalPlan.collect { @@ -2042,7 +2246,7 @@ class Dataset[T] private[sql]( } //////////////////////////////////////////////////////////////////////////// - // for Python API + // For Python API //////////////////////////////////////////////////////////////////////////// /** @@ -2060,8 +2264,12 @@ class Dataset[T] private[sql]( } } + //////////////////////////////////////////////////////////////////////////// + // Private Helpers + //////////////////////////////////////////////////////////////////////////// + /** - * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with + * Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with * an execution. */ private[sql] def withNewExecutionId[U](body: => U): U = { @@ -2069,7 +2277,7 @@ class Dataset[T] private[sql]( } /** - * Wrap a DataFrame action to track the QueryExecution and time cost, then report to the + * Wrap a Dataset action to track the QueryExecution and time cost, then report to the * user-registered callback functions. */ private def withCallback[U](name: String, df: DataFrame)(action: DataFrame => U) = { @@ -2125,7 +2333,7 @@ class Dataset[T] private[sql]( Dataset.newDataFrame(sqlContext, logicalPlan) } - /** A convenient function to wrap a logical plan and produce a DataFrame. */ + /** A convenient function to wrap a logical plan and produce a Dataset. */ @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = { new Dataset[T](sqlContext, logicalPlan, encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 08097e9f02084..47b81c17a31dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql /** - * A container for a [[Dataset]], used for implicit conversions. + * A container for a [[Dataset]], used for implicit conversions in Scala. * * To use this, import implicit conversions in SQL: * {{{ @@ -32,4 +32,10 @@ case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. def toDS(): Dataset[T] = ds + + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDF(): DataFrame = ds.toDF() + + def toDF(colNames: String*): DataFrame = ds.toDF(colNames : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala rename to sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index a8700de135ce4..f0f96825e2683 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -29,18 +29,13 @@ import org.apache.spark.sql.execution.QueryExecution /** * :: Experimental :: * A [[Dataset]] has been logically grouped by a user specified grouping key. 
Users should not - * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing + * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupBy` on an existing * [[Dataset]]. * - * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset)]] extend `GroupedData`. However, - * making this change to the class hierarchy would break some function signatures. As such, this - * class should be considered a preview of the final API. Changes will be made to the interface - * after Spark 1.6. - * - * @since 1.6.0 + * @since 2.0.0 */ @Experimental -class GroupedDataset[K, V] private[sql]( +class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], val queryExecution: QueryExecution, @@ -62,18 +57,22 @@ class GroupedDataset[K, V] private[sql]( private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext - private def groupedData = - new GroupedData( - Dataset.newDataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType) + private def groupedData = { + new RelationalGroupedDataset( + Dataset.newDataFrame(sqlContext, logicalPlan), + groupingAttributes, + RelationalGroupedDataset.GroupByType) + } /** - * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified - * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. + * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the + * specified type. The mapping of key columns to the type follows the same rules as `as` on + * [[Dataset]]. * * @since 1.6.0 */ - def keyAs[L : Encoder]: GroupedDataset[L, V] = - new GroupedDataset( + def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = + new KeyValueGroupedDataset( encoderFor[L], unresolvedVEncoder, queryExecution, @@ -294,7 +293,7 @@ class GroupedDataset[K, V] private[sql]( * * @since 1.6.0 */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long])) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]())) /** * Applies the given function to each cogrouped data. 
For each unique group, the function will @@ -305,7 +304,7 @@ class GroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def cogroup[U, R : Encoder]( - other: GroupedDataset[K, U])( + other: KeyValueGroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit val uEncoder = other.unresolvedVEncoder Dataset[R]( @@ -329,7 +328,7 @@ class GroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def cogroup[U, R]( - other: GroupedDataset[K, U], + other: KeyValueGroupedDataset[K, U], f: CoGroupFunction[K, V, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala rename to sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 04d277bed20f2..521032a8b3a83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions -import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -30,19 +29,17 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.NumericType /** - * :: Experimental :: * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]]. * * The main method is the agg function, which has multiple variants. This class also contains * convenience some first order statistics such as mean, sum for convenience. * - * @since 1.3.0 + * @since 2.0.0 */ -@Experimental -class GroupedData protected[sql]( +class RelationalGroupedDataset protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], - groupType: GroupedData.GroupType) { + groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { @@ -54,16 +51,16 @@ class GroupedData protected[sql]( val aliasedAgg = aggregates.map(alias) groupType match { - case GroupedData.GroupByType => + case RelationalGroupedDataset.GroupByType => Dataset.newDataFrame( df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) - case GroupedData.RollupType => + case RelationalGroupedDataset.RollupType => Dataset.newDataFrame( df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) - case GroupedData.CubeType => + case RelationalGroupedDataset.CubeType => Dataset.newDataFrame( df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) - case GroupedData.PivotType(pivotCol, values) => + case RelationalGroupedDataset.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) Dataset.newDataFrame( df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) @@ -299,7 +296,7 @@ class GroupedData protected[sql]( * @param pivotColumn Name of the column to pivot. 
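Since GroupedDataset is renamed to KeyValueGroupedDataset above, a sketch of how the renamed class is reached (via groupByKey) and used; the user ids and amounts are invented, and the `sqlContext` implicits from the first sketch are assumed:

```scala
val clicks = Seq(("u1", "home"), ("u1", "search"), ("u2", "home")).toDS()
val purchases = Seq(("u1", 9.99), ("u3", 4.50)).toDS()

val clicksByUser = clicks.groupByKey(_._1)      // KeyValueGroupedDataset[String, (String, String)]
val purchasesByUser = purchases.groupByKey(_._1)

clicksByUser.count().show()                     // Dataset[(String, Long)]

// cogroup: for every distinct key, both sides' values are handed to the function together
val summary = clicksByUser.cogroup(purchasesByUser) { (user, cs, ps) =>
  Iterator((user, cs.size, ps.map(_._2).sum))
}
summary.show()
```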
* @since 1.6.0 */ - def pivot(pivotColumn: String): GroupedData = { + def pivot(pivotColumn: String): RelationalGroupedDataset = { // This is to prevent unintended OOM errors when the number of distinct values is large val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) // Get the distinct values of the column and sort them so its consistent @@ -340,14 +337,14 @@ class GroupedData protected[sql]( * @param values List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 */ - def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = { + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { groupType match { - case GroupedData.GroupByType => - new GroupedData( + case RelationalGroupedDataset.GroupByType => + new RelationalGroupedDataset( df, groupingExprs, - GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) - case _: GroupedData.PivotType => + RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply))) + case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => throw new UnsupportedOperationException("pivot is only supported after a groupBy") @@ -372,7 +369,7 @@ class GroupedData protected[sql]( * @param values List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 */ - def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = { + def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { pivot(pivotColumn, values.asScala) } } @@ -381,13 +378,13 @@ class GroupedData protected[sql]( /** * Companion object for GroupedData. */ -private[sql] object GroupedData { +private[sql] object RelationalGroupedDataset { def apply( df: DataFrame, groupingExprs: Seq[Expression], - groupType: GroupType): GroupedData = { - new GroupedData(df, groupingExprs, groupType: GroupType) + groupType: GroupType): RelationalGroupedDataset = { + new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index e23d5e1261c39..fd814e0f28e97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -147,75 +147,4 @@ abstract class SQLImplicits { */ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - /** - * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { - DataFrameHolder(_sqlContext.createDataFrame(rdd)) - } - - /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = - { - DataFrameHolder(_sqlContext.createDataFrame(data)) - } - - // Do NOT add more implicit conversions for primitive types. - // They are likely to break source compatibility by making existing implicit conversions - // ambiguous. In particular, RDD[Double] is dangerous because of [[DoubleRDDFunctions]]. - - /** - * Creates a single column DataFrame from an RDD[Int]. 
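And a sketch of the pivot path through the renamed RelationalGroupedDataset; the course/year figures are invented:

```scala
case class Sale(year: Int, course: String, earnings: Double)
val sales = Seq(
  Sale(2012, "dotNET", 10000), Sale(2012, "Java", 20000),
  Sale(2013, "dotNET", 5000),  Sale(2013, "Java", 30000)
).toDS()

// Passing the values explicitly skips the distinct-value scan that pivot(column) alone performs
sales.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings").show()
```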
- * @since 1.3.0 - */ - implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { - val dataType = IntegerType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setInt(0, v) - row: InternalRow - } - } - DataFrameHolder( - _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[Long]. - * @since 1.3.0 - */ - implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { - val dataType = LongType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setLong(0, v) - row: InternalRow - } - } - DataFrameHolder( - _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[String]. - * @since 1.3.0 - */ - implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { - val dataType = StringType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.update(0, UTF8String.fromString(v)) - row: InternalRow - } - } - DataFrameHolder( - _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index 1d1d7edb240dd..dbea8521be206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -34,6 +34,7 @@ public abstract class BufferedRowIterator { protected LinkedList currentRows = new LinkedList<>(); // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); + private long startTimeNs = System.nanoTime(); public boolean hasNext() throws IOException { if (currentRows.isEmpty()) { @@ -46,6 +47,14 @@ public InternalRow next() { return currentRows.remove(); } + /** + * Returns the elapsed time since this object is created. This object represents a pipeline so + * this is a measure of how long the pipeline has been running. + */ + public long durationMs() { + return (System.nanoTime() - startTimeNs) / (1000 * 1000); + } + /** * Initializes from array of iterators of InternalRow. 
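With the primitive DataFrameHolder conversions removed above, the DatasetHolder route (which this patch also gives a `toDF`) presumably becomes the way to build a single-column DataFrame from such RDDs; a sketch reusing `sc`/`sqlContext` from the first block:

```scala
val strings = sc.parallelize(Seq("a", "b", "c"))

// Previously: stringRddToDataFrameHolder supplied strings.toDF() directly.
// Now: go through the String encoder instead.
val df1 = strings.toDS().toDF("_1")                        // keep the old single-column name
val df2 = sqlContext.createDataset(strings).toDF("value")  // equivalent, without the implicit
df1.show()
```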
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index e97c6be7f177a..b4348d39c2b4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -22,9 +22,10 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf @@ -195,19 +196,42 @@ private[sql] case class DataSourceScan( rdd :: Nil } + private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String, + dataType: DataType, nullable: Boolean): ExprCode = { + val javaType = ctx.javaType(dataType) + val value = ctx.getValue(columnVar, dataType, ordinal) + val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } + val valueVar = ctx.freshName("value") + val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" + val code = s"/* ${toCommentSafeString(str)} */\n" + (if (nullable) { + s""" + boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal); + $javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value); + """ + } else { + s"$javaType ${valueVar} = $value;" + }).trim + ExprCode(code, isNullVar, valueVar) + } + // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen // never requires UnsafeRow as input. override protected def doProduce(ctx: CodegenContext): String = { val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" + val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector" val input = ctx.freshName("input") val idx = ctx.freshName("batchIdx") + val rowidx = ctx.freshName("rowIdx") val batch = ctx.freshName("batch") // PhysicalRDD always just has one input ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") ctx.addMutableState("int", idx, s"$idx = 0;") + val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) + val columnAssigns = colVars.zipWithIndex.map { case (name, i) => + ctx.addMutableState(columnVectorClz, name, s"$name = null;") + s"$name = ${batch}.column($i);" } - val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) val row = ctx.freshName("row") val numOutputRows = metricTerm(ctx, "numOutputRows") @@ -217,19 +241,22 @@ private[sql] case class DataSourceScan( // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know // here which path to use. Fix this. 
- ctx.INPUT_ROW = row ctx.currentVars = null - val columns1 = exprs.map(_.gen(ctx)) + val columns1 = (output zip colVars).map { case (attr, colVar) => + genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } val scanBatches = ctx.freshName("processBatches") ctx.addNewFunction(scanBatches, s""" | private void $scanBatches() throws java.io.IOException { | while (true) { | int numRows = $batch.numRows(); - | if ($idx == 0) $numOutputRows.add(numRows); + | if ($idx == 0) { + | ${columnAssigns.mkString("", "\n", "\n")} + | $numOutputRows.add(numRows); + | } | | while (!shouldStop() && $idx < numRows) { - | InternalRow $row = $batch.getRow($idx++); + | int $rowidx = $idx++; | ${consume(ctx, columns1).trim} | } | if (shouldStop()) return; @@ -243,9 +270,10 @@ private[sql] case class DataSourceScan( | } | }""".stripMargin) + val exprRows = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true)) ctx.INPUT_ROW = row ctx.currentVars = null - val columns2 = exprs.map(_.gen(ctx)) + val columns2 = exprRows.map(_.gen(ctx)) val inputRow = if (outputUnsafeRows) row else null val scanRows = ctx.freshName("processRows") ctx.addNewFunction(scanRows, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a392b5341244f..010ed7f5008eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -219,48 +219,62 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** - * Runs this query returning the result as an array. + * Packing the UnsafeRows into byte array for faster serialization. + * The byte arrays are in the following format: + * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] + * + * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also + * compressed. */ - def executeCollect(): Array[InternalRow] = { - // Packing the UnsafeRows into byte array for faster serialization. - // The byte arrays are in the following format: - // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1] - // - // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also - // compressed. - val byteArrayRdd = execute().mapPartitionsInternal { iter => + private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = { + execute().mapPartitionsInternal { iter => + var count = 0 val buffer = new Array[Byte](4 << 10) // 4K val codec = CompressionCodec.createCodec(SparkEnv.get.conf) val bos = new ByteArrayOutputStream() val out = new DataOutputStream(codec.compressedOutputStream(bos)) - while (iter.hasNext) { + while (iter.hasNext && (n < 0 || count < n)) { val row = iter.next().asInstanceOf[UnsafeRow] out.writeInt(row.getSizeInBytes) row.writeToStream(out, buffer) + count += 1 } out.writeInt(-1) out.flush() out.close() Iterator(bos.toByteArray) } + } - // Collect the byte arrays back to driver, then decode them as UnsafeRows. + /** + * Decode the byte arrays back to UnsafeRows and put them into buffer. 
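The length-prefixed framing described above ([size] [bytes] ... [-1]) is easy to get wrong at the boundaries, so here is a standalone sketch of the idea, using GZIP purely as a stand-in for Spark's pluggable CompressionCodec:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.ArrayBuffer

def encode(rows: Iterator[Array[Byte]]): Array[Byte] = {
  val bos = new ByteArrayOutputStream()
  val out = new DataOutputStream(new GZIPOutputStream(bos))
  rows.foreach { bytes =>
    out.writeInt(bytes.length)   // length prefix for each row
    out.write(bytes)
  }
  out.writeInt(-1)               // end-of-stream marker
  out.close()
  bos.toByteArray
}

def decode(blob: Array[Byte]): Seq[Array[Byte]] = {
  val in = new DataInputStream(new GZIPInputStream(new ByteArrayInputStream(blob)))
  val buf = ArrayBuffer[Array[Byte]]()
  var size = in.readInt()
  while (size >= 0) {            // stop at the -1 marker
    val bytes = new Array[Byte](size)
    in.readFully(bytes)
    buf += bytes
    size = in.readInt()
  }
  buf
}

val roundTripped = decode(encode(Iterator("a".getBytes("UTF-8"), "bc".getBytes("UTF-8"))))
assert(roundTripped.map(new String(_, "UTF-8")) == Seq("a", "bc"))
```

In the real operator the per-row payload is an UnsafeRow's backing bytes, which is why the stream compresses well.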
+ */ + private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = { val nFields = schema.length - val results = ArrayBuffer[InternalRow]() + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(codec.compressedInputStream(bis)) + var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(nFields) + row.pointTo(bs, sizeOfNextRow) + buffer += row + sizeOfNextRow = ins.readInt() + } + } + + /** + * Runs this query returning the result as an array. + */ + def executeCollect(): Array[InternalRow] = { + val byteArrayRdd = getByteArrayRdd() + + val results = ArrayBuffer[InternalRow]() byteArrayRdd.collect().foreach { bytes => - val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val bis = new ByteArrayInputStream(bytes) - val ins = new DataInputStream(codec.compressedInputStream(bis)) - var sizeOfNextRow = ins.readInt() - while (sizeOfNextRow >= 0) { - val bs = new Array[Byte](sizeOfNextRow) - ins.readFully(bs) - val row = new UnsafeRow(nFields) - row.pointTo(bs, sizeOfNextRow) - results += row - sizeOfNextRow = ins.readInt() - } + decodeUnsafeRows(bytes, results) } results.toArray } @@ -283,7 +297,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ return new Array[InternalRow](0) } - val childRDD = execute().map(_.copy()) + val childRDD = getByteArrayRdd(n) val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length @@ -307,13 +321,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val left = n - buf.size val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val sc = sqlContext.sparkContext - val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) + val res = sc.runJob(childRDD, + (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p) + + res.foreach { r => + decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf) + } - res.foreach(buf ++= _.take(n - buf.size)) partsScanned += p.size } - buf.toArray + if (buf.size > n) { + buf.take(n).toArray + } else { + buf.toArray + } } private[this] def isTesting: Boolean = sys.props.contains("spark.testing") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 121b6d9e97d19..7841ff01f93c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -29,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _} -import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} +import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _} +import 
org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.internal.SQLConf @@ -69,8 +69,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right))) // Find left semi joins where at least some predicates can be evaluated by matching join keys case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => - joins.LeftSemiJoinHash( - leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil + Seq(joins.ShuffledHashJoin( + leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right))) case _ => Nil } } @@ -80,8 +80,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object CanBroadcast { def unapply(plan: LogicalPlan): Option[LogicalPlan] = { - if (conf.autoBroadcastJoinThreshold > 0 && - plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) { + if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) { Some(plan) } else { None @@ -101,10 +100,45 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side * of the join will be broadcasted and the other side will be streamed, with no shuffling * performed. If both sides of the join are eligible to be broadcasted then the + * - Shuffle hash join: if the average size of a single partition is small enough to build a hash + * table. * - Sort merge: if the matching join keys are sortable. */ object EquiJoinSelection extends Strategy with PredicateHelper { + /** + * Matches a plan whose single partition should be small enough to build a hash table. + * + * Note: this assume that the number of partition is fixed, requires addtional work if it's + * dynamic. + */ + def canBuildHashMap(plan: LogicalPlan): Boolean = { + plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + } + + /** + * Returns whether plan a is much smaller (3X) than plan b. + * + * The cost to build hash map is higher than sorting, we should only build hash map on a table + * that is much smaller than other one. Since we does not have the statistic for number of rows, + * use the size of bytes here as estimation. + */ + private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { + a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes + } + + /** + * Returns whether we should use shuffle hash join or not. + * + * We should only use shuffle hash join when: + * 1) any single partition of a small table could fit in memory. + * 2) the smaller table is much smaller (3X) than the other one. 
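A standalone sketch of this selection heuristic with made-up sizes; 10 MB for the broadcast threshold and 200 shuffle partitions are the usual defaults, but both are configurable:

```scala
case class Side(sizeInBytes: BigInt)   // stand-in for LogicalPlan statistics

val autoBroadcastJoinThreshold = 10L * 1024 * 1024   // spark.sql.autoBroadcastJoinThreshold
val numShufflePartitions = 200                       // spark.sql.shuffle.partitions

def canBuildHashMap(p: Side): Boolean =
  p.sizeInBytes < BigInt(autoBroadcastJoinThreshold) * numShufflePartitions

def muchSmaller(a: Side, b: Side): Boolean =
  a.sizeInBytes * 3 <= b.sizeInBytes

def shouldShuffleHashJoin(left: Side, right: Side): Boolean =
  (canBuildHashMap(left) && muchSmaller(left, right)) ||
    (canBuildHashMap(right) && muchSmaller(right, left))

val small = Side(BigInt(50L) * 1024 * 1024)          // ~50 MB
val big   = Side(BigInt(10L) * 1024 * 1024 * 1024)   // ~10 GB
println(shouldShuffleHashJoin(small, big))           // true: hash the small side, stream the big one
```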
+ */ + private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = { + canBuildHashMap(left) && muchSmaller(left, right) || + canBuildHashMap(right) && muchSmaller(right, left) + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // --- Inner joins -------------------------------------------------------------------------- @@ -117,6 +151,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Seq(joins.BroadcastHashJoin( leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right))) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) || + !RowOrdering.isOrderable(leftKeys) => + val buildSide = + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { + BuildRight + } else { + BuildLeft + } + Seq(joins.ShuffledHashJoin( + leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right))) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => joins.SortMergeJoin( @@ -134,6 +180,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { Seq(joins.BroadcastHashJoin( leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right))) + case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right) + if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) || + !RowOrdering.isOrderable(leftKeys) => + Seq(joins.ShuffledHashJoin( + leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right))) + + case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right) + if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) || + !RowOrdering.isOrderable(leftKeys) => + Seq(joins.ShuffledHashJoin( + leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right))) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => joins.SortMergeJoin( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 67aef72dedaf2..e3c7d7209af18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} -import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.internal.SQLConf /** @@ -264,6 +264,10 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport override def treeChildren: Seq[SparkPlan] = Nil } +object WholeStageCodegen { + val PIPELINE_DURATION_METRIC = "duration" +} + /** * WholeStageCodegen compile a subtree of plans that support codegen together into single Java * function. 
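The pipeline duration metric added around WholeStageCodegen boils down to timing an iterator from its creation until it is drained; a standalone sketch of that wrapping pattern (the names here are illustrative, not Spark APIs):

```scala
class TimedIterator[T](underlying: Iterator[T], record: Long => Unit) extends Iterator[T] {
  private val startTimeNs = System.nanoTime()

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    // Once the pipeline is drained, report elapsed wall-clock time in milliseconds.
    if (!more) record((System.nanoTime() - startTimeNs) / (1000 * 1000))
    more
  }

  override def next(): T = underlying.next()
}

// Spark accumulates the value into the "pipelineTime" SQLMetric; a plain var works for the sketch.
var pipelineMs = 0L
val it = new TimedIterator(Iterator(1, 2, 3), pipelineMs += _)
it.foreach(_ => Thread.sleep(1))
println(s"pipeline stayed open for ~$pipelineMs ms")
```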
@@ -301,6 +305,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override private[sql] lazy val metrics = Map( + "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, + WholeStageCodegen.PIPELINE_DURATION_METRIC)) + override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) @@ -339,6 +347,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup logDebug(s"${CodeFormatter.format(cleanedSource)}") CodeGenerator.compile(cleanedSource) + val durationMs = longMetric("pipelineTime") + val rdds = child.asInstanceOf[CodegenSupport].upstreams() assert(rdds.size <= 2, "Up to two upstream RDDs can be supported") if (rdds.length == 1) { @@ -347,7 +357,11 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(Array(iter)) new Iterator[InternalRow] { - override def hasNext: Boolean = buffer.hasNext + override def hasNext: Boolean = { + val v = buffer.hasNext + if (!v) durationMs += buffer.durationMs() + v + } override def next: InternalRow = buffer.next() } } @@ -358,7 +372,11 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(Array(leftIter, rightIter)) new Iterator[InternalRow] { - override def hasNext: Boolean = buffer.hasNext + override def hasNext: Boolean = { + val v = buffer.hasNext + if (!v) durationMs += buffer.durationMs() + v + } override def next: InternalRow = buffer.next() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 2abfd14916e7b..cd769d013786a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -160,6 +160,15 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED} is " + + s"deprecated and will be ignored. Vectorized parquet reader will be used instead.") + Seq(Row(SQLConf.PARQUET_VECTORIZED_READER_ENABLED, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala new file mode 100644 index 0000000000000..468228053c964 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ParseModes.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +private[datasources] object ParseModes { + val PERMISSIVE_MODE = "PERMISSIVE" + val DROP_MALFORMED_MODE = "DROPMALFORMED" + val FAIL_FAST_MODE = "FAILFAST" + + val DEFAULT = PERMISSIVE_MODE + + def isValidMode(mode: String): Boolean = { + mode.toUpperCase match { + case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true + case _ => false + } + } + + def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE + def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE + def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { + mode.toUpperCase == PERMISSIVE_MODE + } else { + true // We default to permissive is the mode string is not valid + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index e848f423eb118..f3514cd14ce8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -33,9 +33,9 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.Logging -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader +import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -99,8 +99,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( // If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). - protected val enableUnsafeRowParquetReader: Boolean = - sqlContext.getConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key).toBoolean protected val enableVectorizedParquetReader: Boolean = sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean protected val enableWholestageCodegen: Boolean = @@ -174,19 +172,17 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( * fails (for example, unsupported schema), try with the normal reader. * TODO: plumb this through a different way? 
*/ - if (enableUnsafeRowParquetReader && + if (enableVectorizedParquetReader && format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { - val parquetReader: UnsafeRowParquetRecordReader = new UnsafeRowParquetRecordReader() + val parquetReader: VectorizedParquetRecordReader = new VectorizedParquetRecordReader() if (!parquetReader.tryInitialize( split.serializableHadoopSplit.value, hadoopAttemptContext)) { parquetReader.close() } else { reader = parquetReader.asInstanceOf[RecordReader[Void, V]] - if (enableVectorizedParquetReader) { - parquetReader.resultBatch() - // Whole stage codegen (PhysicalRDD) is able to deal with batches directly - if (enableWholestageCodegen) parquetReader.enableReturningBatches(); - } + parquetReader.resultBatch() + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + if (enableWholestageCodegen) parquetReader.enableReturningBatches() } } @@ -203,7 +199,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( private[this] var finished = false override def hasNext: Boolean = { - if (context.isInterrupted) { + if (context.isInterrupted()) { throw new TaskKilledException } if (!finished && !havePair) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index e009a37f2de72..95de02cf5c182 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.datasources.CompressionCodecs +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} private[sql] class CSVOptions( @transient private val parameters: Map[String, String]) @@ -62,7 +62,7 @@ private[sql] class CSVOptions( val delimiter = CSVTypeCast.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) - val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val charset = parameters.getOrElse("encoding", parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) @@ -101,26 +101,3 @@ private[sql] class CSVOptions( val rowSeparator = "\n" } - -private[csv] object ParseModes { - val PERMISSIVE_MODE = "PERMISSIVE" - val DROP_MALFORMED_MODE = "DROPMALFORMED" - val FAIL_FAST_MODE = "FAILFAST" - - val DEFAULT = PERMISSIVE_MODE - - def isValidMode(mode: String): Boolean = { - mode.toUpperCase match { - case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true - case _ => false - } - } - - def isDropMalformedMode(mode: String): Boolean = mode.toUpperCase == DROP_MALFORMED_MODE - def isFailFastMode(mode: String): Boolean = mode.toUpperCase == FAIL_FAST_MODE - def isPermissiveMode(mode: String): Boolean = if (isValidMode(mode)) { - mode.toUpperCase == PERMISSIVE_MODE - } else { - true // We default to permissive is the mode string is not valid - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 0937a213c984f..945ed2c2113d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -40,6 +40,7 @@ private[sql] object InferSchema { configOptions: JSONOptions): StructType = { require(configOptions.samplingRatio > 0, s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") + val shouldHandleCorruptRecord = configOptions.permissive val schemaData = if (configOptions.samplingRatio > 0.99) { json } else { @@ -50,21 +51,23 @@ private[sql] object InferSchema { val rootType = schemaData.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) - iter.map { row => + iter.flatMap { row => try { Utils.tryWithResource(factory.createParser(row)) { parser => parser.nextToken() - inferField(parser, configOptions) + Some(inferField(parser, configOptions)) } } catch { + case _: JsonParseException if shouldHandleCorruptRecord => + Some(StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))) case _: JsonParseException => - StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) + None } } }.treeAggregate[DataType]( StructType(Seq()))( - compatibleRootType(columnNameOfCorruptRecords), - compatibleRootType(columnNameOfCorruptRecords)) + compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord), + compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -194,18 +197,21 @@ private[sql] object InferSchema { * Remove top-level ArrayType wrappers and merge the remaining schemas */ private def compatibleRootType( - columnNameOfCorruptRecords: String): (DataType, DataType) => DataType = { + columnNameOfCorruptRecords: String, + shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. - case (ArrayType(ty1, _), ty2) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2) - case (ty1, ArrayType(ty2, _)) => compatibleRootType(columnNameOfCorruptRecords)(ty1, ty2) + case (ArrayType(ty1, _), ty2) => + compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) + case (ty1, ArrayType(ty2, _)) => + compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) // If we see any other data type at the root level, we get records that cannot be // parsed. So, we use the struct as the data type and add the corrupt field to the schema. case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct - case (struct: StructType, o) if !o.isInstanceOf[StructType] => + case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => withCorruptField(struct, columnNameOfCorruptRecords) - case (o, struct: StructType) if !o.isInstanceOf[StructType] => + case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord => withCorruptField(struct, columnNameOfCorruptRecords) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. 
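The schema-inference change above only adds the corrupt-record column when the reader is in permissive mode; together with the JSONOptions and JacksonParser changes that follow, this gives the JSON source the same PERMISSIVE / DROPMALFORMED / FAILFAST modes the CSV source already supports. A rough spark-shell style sketch of the resulting behaviour is below; the sqlContext value, the sample records, and the /tmp path are placeholders for illustration, not anything defined in this patch.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

// One well-formed and one malformed JSON line.
val lines = Seq("""{"name": "alice", "age": 30}""", """{"name": "bob", "age":""")
Files.write(Paths.get("/tmp/people.json"),
  lines.mkString("\n").getBytes(StandardCharsets.UTF_8))

// PERMISSIVE (default): the malformed line is kept in the _corrupt_record column.
sqlContext.read.option("mode", "PERMISSIVE").json("/tmp/people.json").show()

// DROPMALFORMED: the malformed line is logged ("Dropping malformed line: ...") and skipped.
sqlContext.read.option("mode", "DROPMALFORMED").json("/tmp/people.json").show()

// FAILFAST: the first malformed line fails the query with
// "Malformed line in FAILFAST mode: ...".
// sqlContext.read.option("mode", "FAILFAST").json("/tmp/people.json").count()
```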
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index e59dbd6b3d438..93c3d47c1dcf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonFactory, JsonParser} -import org.apache.spark.sql.execution.datasources.CompressionCodecs +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} /** * Options for the JSON data source. @@ -28,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.CompressionCodecs */ private[sql] class JSONOptions( @transient private val parameters: Map[String, String]) - extends Serializable { + extends Logging with Serializable { val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) @@ -49,6 +50,16 @@ private[sql] class JSONOptions( val allowBackslashEscapingAnyCharacter = parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) + private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + + // Parse mode flags + if (!ParseModes.isValidMode(parseMode)) { + logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") + } + + val failFast = ParseModes.isFailFastMode(parseMode) + val dropMalformed = ParseModes.isDropMalformedMode(parseMode) + val permissive = ParseModes.isPermissiveMode(parseMode) /** Sets config options on a Jackson [[JsonFactory]]. 
*/ def setJacksonOptions(factory: JsonFactory): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 3252b6c77f888..00c14adf0704b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import com.fasterxml.jackson.core._ +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -34,7 +35,7 @@ import org.apache.spark.util.Utils private[json] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg) -object JacksonParser { +object JacksonParser extends Logging { def parse( input: RDD[String], @@ -257,13 +258,20 @@ object JacksonParser { def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present - val row = new GenericMutableRow(schema.length) - for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) { - require(schema(corruptIndex).dataType == StringType) - row.update(corruptIndex, UTF8String.fromString(record)) + if (configOptions.failFast) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: $record") + } + if (configOptions.dropMalformed) { + logWarning(s"Dropping malformed line: $record") + Nil + } else { + val row = new GenericMutableRow(schema.length) + for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) { + require(schema(corruptIndex).dataType == StringType) + row.update(corruptIndex, UTF8String.fromString(record)) + } + Seq(row) } - - Seq(row) } val factory = new JsonFactory() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala index 1a5c6a66c484e..102a9356df311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala @@ -23,9 +23,8 @@ import scala.concurrent.duration._ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} -import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryNode} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.util.ThreadUtils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 9eaadea1b11ff..df7ad48812051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -30,7 +30,11 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType /** - * An interface for exchanges. + * Base class for operators that exchange data among multiple threads or processes. + * + * Exchanges are the key class of operators that enable parallelism. 
Although the implementation + * differs significantly, the concept is similar to the exchange operator described in + * "Volcano -- An Extensible and Parallel Query Evaluation System" by Goetz Graefe. */ abstract class Exchange extends UnaryNode { override def output: Seq[Attribute] = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0b0f59c3e4634..8cc352863902c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -109,51 +109,6 @@ private[execution] trait UniqueHashedRelation extends HashedRelation { } } -/** - * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. - */ -private[joins] class GeneralHashedRelation( - private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) - extends HashedRelation with Externalizable { - - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) - - override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) - - override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) - } - - override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) - } -} - - -/** - * A specialized [[HashedRelation]] that maps key into a single value. This implementation - * assumes the key is unique. - */ -private[joins] class UniqueKeyHashedRelation( - private var hashTable: JavaHashMap[InternalRow, InternalRow]) - extends UniqueHashedRelation with Externalizable { - - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) - - override def getValue(key: InternalRow): InternalRow = hashTable.get(key) - - override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) - } - - override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) - } -} - - private[execution] object HashedRelation { /** @@ -162,51 +117,16 @@ private[execution] object HashedRelation { * Note: The caller should make sure that these InternalRow are different objects. */ def apply( + canJoinKeyFitWithinLong: Boolean, input: Iterator[InternalRow], keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { - if (keyGenerator.isInstanceOf[UnsafeProjection]) { - return UnsafeHashedRelation( - input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) - } - - // TODO: Use Spark's HashMap implementation. - val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate) - var currentRow: InternalRow = null - - // Whether the join key is unique. If the key is unique, we can convert the underlying - // hash map into one specialized for this. 
- var keyIsUnique = true - - // Create a mapping of buildKeys -> rows - while (input.hasNext) { - currentRow = input.next() - val rowKey = keyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[InternalRow]() - hashTable.put(rowKey.copy(), newMatchList) - newMatchList - } else { - keyIsUnique = false - existingMatchList - } - matchList += currentRow - } - } - - if (keyIsUnique) { - val uniqHashTable = new JavaHashMap[InternalRow, InternalRow](hashTable.size) - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - uniqHashTable.put(entry.getKey, entry.getValue()(0)) - } - new UniqueKeyHashedRelation(uniqHashTable) + if (canJoinKeyFitWithinLong) { + LongHashedRelation(input, keyGenerator, sizeEstimate) } else { - new GeneralHashedRelation(hashTable) + UnsafeHashedRelation( + input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } } } @@ -428,6 +348,7 @@ private[joins] object UnsafeHashedRelation { sizeEstimate: Int): HashedRelation = { // Use a Java hash table here because unsafe maps expect fixed size records + // TODO: Use BytesToBytesMap for memory efficiency val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of buildKeys -> rows @@ -683,11 +604,7 @@ private[execution] case class HashedRelationBroadcastMode( override def transform(rows: Array[InternalRow]): HashedRelation = { val generator = UnsafeProjection.create(keys, attributes) - if (canJoinKeyFitWithinLong) { - LongHashedRelation(rows.iterator, generator, rows.length) - } else { - HashedRelation(rows.iterator, generator, rows.length) - } + HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length) } private lazy val canonicalizedKeys: Seq[Expression] = { @@ -703,4 +620,3 @@ private[execution] case class HashedRelationBroadcastMode( case _ => false } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala deleted file mode 100644 index fa549b4d51336..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.LeftSemi -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Build the right table's join keys into a HashedRelation, and iteratively go through the left - * table, to find if the join keys are in the HashedRelation. - */ -case class LeftSemiJoinHash( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - left: SparkPlan, - right: SparkPlan, - condition: Option[Expression]) extends BinaryNode with HashJoin { - - override val joinType = LeftSemi - override val buildSide = BuildRight - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = left.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => - val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator) - hashSemiJoin(streamIter, hashRelation, numOutputRows) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala new file mode 100644 index 0000000000000..5c4f1ef60fd08 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.memory.MemoryMode +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics + +/** + * Performs a hash join of two child relations by first shuffling the data using the join keys. 
+ */ +case class ShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) + extends BinaryNode with HashJoin { + + override private[sql] lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + + override def outputPartitioning: Partitioning = joinType match { + case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftSemi => left.outputPartitioning + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case x => + throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType") + } + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { + // try to acquire some memory for the hash table, it could trigger other operator to free some + // memory. The memory acquired here will mostly be used until the end of task. + val context = TaskContext.get() + val memoryManager = context.taskMemoryManager() + var acquired = 0L + var used = 0L + context.addTaskCompletionListener((t: TaskContext) => + memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null) + ) + + val copiedIter = iter.map { row => + // It's hard to guess what's exactly memory will be used, we have a rough guess here. + // TODO: use BytesToBytesMap instead of HashMap for memory efficiency + // Each pair in HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers + val needed = 150 + row.getSizeInBytes + if (needed > acquired - used) { + val got = memoryManager.acquireExecutionMemory( + Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null) + if (got < needed) { + throw new SparkException("Can't acquire enough memory to build hash map in shuffled" + + "hash join, please use sort merge join by setting " + + "spark.sql.join.preferSortMergeJoin=true") + } + acquired += got + } + used += needed + // HashedRelation requires that the UnsafeRow should be separate objects. 
+ row.copy() + } + + HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator) + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => + val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) + val joinedRow = new JoinedRow + joinType match { + case Inner => + hashJoin(streamIter, hashed, numOutputRows) + + case LeftSemi => + hashSemiJoin(streamIter, hashed, numOutputRows) + + case LeftOuter => + val keyGenerator = streamSideKeyGenerator + val resultProj = createResultProjection + streamIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) + }) + + case RightOuter => + val keyGenerator = streamSideKeyGenerator + val resultProj = createResultProjection + streamIter.flatMap(currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) + }) + + case x => + throw new IllegalArgumentException( + s"ShuffledHashJoin should not take $x as the JoinType") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 807b39ace6266..60bd8ea39a7af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -665,11 +665,11 @@ private[joins] class SortMergeJoinScanner( * An iterator for outputting rows in left outer join. */ private class LeftOuterIterator( - smjScanner: SortMergeJoinScanner, - rightNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) + smjScanner: SortMergeJoinScanner, + rightNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) extends OneSideOuterIterator( smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) { @@ -681,13 +681,12 @@ private class LeftOuterIterator( * An iterator for outputting rows in right outer join. 
*/ private class RightOuterIterator( - smjScanner: SortMergeJoinScanner, - leftNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) - extends OneSideOuterIterator( - smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { + smjScanner: SortMergeJoinScanner, + leftNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) + extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) { protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row) protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row) @@ -710,11 +709,11 @@ private class RightOuterIterator( * @param numOutputRows an accumulator metric for the number of rows output */ private abstract class OneSideOuterIterator( - smjScanner: SortMergeJoinScanner, - bufferedSideNullRow: InternalRow, - boundCondition: InternalRow => Boolean, - resultProj: InternalRow => InternalRow, - numOutputRows: LongSQLMetric) extends RowIterator { + smjScanner: SortMergeJoinScanner, + bufferedSideNullRow: InternalRow, + boundCondition: InternalRow => Boolean, + resultProj: InternalRow => InternalRow, + numOutputRows: LongSQLMetric) extends RowIterator { // A row to store the joined result, reused many times protected[this] val joinedRow: JoinedRow = new JoinedRow() @@ -777,14 +776,14 @@ private abstract class OneSideOuterIterator( } private class SortMergeFullOuterJoinScanner( - leftKeyGenerator: Projection, - rightKeyGenerator: Projection, - keyOrdering: Ordering[InternalRow], - leftIter: RowIterator, - rightIter: RowIterator, - boundCondition: InternalRow => Boolean, - leftNullRow: InternalRow, - rightNullRow: InternalRow) { + leftKeyGenerator: Projection, + rightKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + leftIter: RowIterator, + rightIter: RowIterator, + boundCondition: InternalRow => Boolean, + leftNullRow: InternalRow, + rightNullRow: InternalRow) { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var leftRow: InternalRow = _ private[this] var leftRowKey: InternalRow = _ @@ -950,10 +949,9 @@ private class SortMergeFullOuterJoinScanner( } private class FullOuterIterator( - smjScanner: SortMergeFullOuterJoinScanner, - resultProj: InternalRow => InternalRow, - numRows: LongSQLMetric -) extends RowIterator { + smjScanner: SortMergeFullOuterJoinScanner, + resultProj: InternalRow => InternalRow, + numRows: LongSQLMetric) extends RowIterator { private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow() override def advanceNext(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 6b43d273fefde..7fa13907295b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -122,7 +122,7 @@ private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialVa private object LongSQLMetricParam extends LongSQLMetricParam(_.sum.toString, 0L) -private object StaticsLongSQLMetricParam extends LongSQLMetricParam( +private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam( (values: Seq[Long]) => { // This is a workaround for SPARK-11013. 
// We use -1 as initial value of the accumulator, if the accumulator is valid, we will update @@ -140,6 +140,24 @@ private object StaticsLongSQLMetricParam extends LongSQLMetricParam( s"\n$sum ($min, $med, $max)" }, -1L) +private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam( + (values: Seq[Long]) => { + // This is a workaround for SPARK-11013. + // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update + // it at the end of task and the value will be at least 0. + val validValues = values.filter(_ >= 0) + val Seq(sum, min, med, max) = { + val metric = if (validValues.length == 0) { + Seq.fill(4)(0L) + } else { + val sorted = validValues.sorted + Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1)) + } + metric.map(Utils.msDurationToString) + } + s"\n$sum ($min, $med, $max)" + }, -1L) + private[sql] object SQLMetrics { // Identifier for distinguishing SQL metrics from other accumulators @@ -168,15 +186,24 @@ private[sql] object SQLMetrics { // The final result of this metric in physical operator UI may looks like: // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) - createLongMetric(sc, s"$name total (min, med, max)", StaticsLongSQLMetricParam) + createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam) + } + + def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = { + // The final result of this metric in physical operator UI may looks like: + // duration(min, med, max): + // 5s (800ms, 1s, 2s) + createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam) } def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = { val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam) - val staticsSQLMetricParam = Utils.getFormattedClassName(StaticsLongSQLMetricParam) + val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam) + val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam) val metricParam = metricParamName match { case `longSQLMetricParam` => LongSQLMetricParam - case `staticsSQLMetricParam` => StaticsLongSQLMetricParam + case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam + case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam } metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 9c3145637d980..4f1b83715892f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -32,9 +32,7 @@ case class PythonUDF( children: Seq[Expression]) extends Expression with Unevaluable with NonSQLExpression with Logging { - override def toString: String = s"PythonUDF#$name(${children.mkString(", ")})" + override def toString: String = s"$name(${children.mkString(", ")})" override def nullable: Boolean = true - - override def sql: String = s"$name(${children.mkString(", ")})" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 8a36d3224003a..24a01f5be1771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.sql.execution.SparkPlanInfo +import org.apache.spark.sql.execution.{SparkPlanInfo, WholeStageCodegen} import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -79,12 +79,19 @@ private[sql] object SparkPlanGraph { exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = { planInfo.nodeName match { case "WholeStageCodegen" => + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) + } + val cluster = new SparkPlanGraphCluster( nodeIdGenerator.getAndIncrement(), planInfo.nodeName, planInfo.simpleString, - mutable.ArrayBuffer[SparkPlanGraphNode]()) + mutable.ArrayBuffer[SparkPlanGraphNode](), + metrics) nodes += cluster + buildSparkPlanGraphNode( planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges) case "InputAdapter" => @@ -166,13 +173,26 @@ private[ui] class SparkPlanGraphCluster( id: Long, name: String, desc: String, - val nodes: mutable.ArrayBuffer[SparkPlanGraphNode]) - extends SparkPlanGraphNode(id, name, desc, Map.empty, Nil) { + val nodes: mutable.ArrayBuffer[SparkPlanGraphNode], + metrics: Seq[SQLPlanMetric]) + extends SparkPlanGraphNode(id, name, desc, Map.empty, metrics) { override def makeDotNode(metricsValue: Map[Long, String]): String = { + val duration = metrics.filter(_.name.startsWith(WholeStageCodegen.PIPELINE_DURATION_METRIC)) + val labelStr = if (duration.nonEmpty) { + require(duration.length == 1) + val id = duration(0).accumulatorId + if (metricsValue.contains(duration(0).accumulatorId)) { + name + "\n\n" + metricsValue.get(id).get + } else { + name + } + } else { + name + } s""" | subgraph cluster${id} { - | label="${StringEscapeUtils.escapeJava(name)}"; + | label="${StringEscapeUtils.escapeJava(labelStr)}"; | ${nodes.map(_.makeDotNode(metricsValue)).mkString(" \n")} | } """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9aabe2d0abe1c..3d1d5b120a783 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -210,11 +210,11 @@ object SQLConf { val ALLOW_MULTIPLE_CONTEXTS = booleanConf("spark.sql.allowMultipleContexts", defaultValue = Some(true), - doc = "When set to true, creating multiple SQLContexts/HiveContexts is allowed." + + doc = "When set to true, creating multiple SQLContexts/HiveContexts is allowed. " + "When set to false, only one SQLContext/HiveContext is allowed to be created " + "through the constructor (new SQLContexts/HiveContexts created through newSession " + - "method is allowed). Please note that this conf needs to be set in Spark Conf. Once" + - "a SQLContext/HiveContext has been created, changing the value of this conf will not" + + "method is allowed). Please note that this conf needs to be set in Spark Conf. 
Once " + + "a SQLContext/HiveContext has been created, changing the value of this conf will not " + "have effect.", isPublic = true) @@ -236,6 +236,11 @@ object SQLConf { doc = "When true, enable partition pruning for in-memory columnar tables.", isPublic = false) + val PREFER_SORTMERGEJOIN = booleanConf("spark.sql.join.preferSortMergeJoin", + defaultValue = Some(true), + doc = "When true, prefer sort merge join over shuffle hash join.", + isPublic = false) + val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold", defaultValue = Some(10 * 1024 * 1024), doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " + @@ -247,8 +252,8 @@ object SQLConf { "spark.sql.defaultSizeInBytes", doc = "The default table size used in query planning. By default, it is set to a larger " + "value than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. That is to say " + - "by default the optimizer will not choose to broadcast a table unless it knows for sure its" + - "size is small enough.", + "by default the optimizer will not choose to broadcast a table unless it knows for sure " + + "its size is small enough.", isPublic = false) val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions", @@ -270,7 +275,7 @@ object SQLConf { doc = "The advisory minimal number of post-shuffle partitions provided to " + "ExchangeCoordinator. This setting is used in our test to make sure we " + "have enough parallelism to expose issues that will not be exposed with a " + - "single partition. When the value is a non-positive value, this setting will" + + "single partition. When the value is a non-positive value, this setting will " + "not be provided to ExchangeCoordinator.", isPublic = false) @@ -340,11 +345,6 @@ object SQLConf { "option must be set in Hadoop Configuration. 2. This option overrides " + "\"spark.sql.sources.outputCommitterClass\".") - val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = booleanConf( - key = "spark.sql.parquet.enableUnsafeRowRecordReader", - defaultValue = Some(true), - doc = "Enables using the custom ParquetUnsafeRowRecordReader.") - val PARQUET_VECTORIZED_READER_ENABLED = booleanConf( key = "spark.sql.parquet.enableVectorizedReader", defaultValue = Some(true), @@ -391,7 +391,7 @@ object SQLConf { // This is only used for the thriftserver val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", - doc = "Set a Fair Scheduler pool for a JDBC client session") + doc = "Set a Fair Scheduler pool for a JDBC client session.") val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements", defaultValue = Some(200), @@ -433,7 +433,12 @@ object SQLConf { val BUCKETING_ENABLED = booleanConf("spark.sql.sources.bucketing.enabled", defaultValue = Some(true), - doc = "When false, we will treat bucketed table as normal table") + doc = "When false, we will treat bucketed table as normal table.") + + val ORDER_BY_ORDINAL = booleanConf("spark.sql.orderByOrdinal", + defaultValue = Some(true), + doc = "When true, the ordinal numbers are treated as the position in the select list. " + + "When false, the ordinal numbers in order/sort By clause are ignored.") // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. 
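As a quick usage note for the two switches added to SQLConf in this hunk (spark.sql.join.preferSortMergeJoin and spark.sql.orderByOrdinal), the snippet below shows how they would be toggled from a SQLContext. The sqlContext value and the testData2 temporary table (borrowed from the test suites later in this patch) are assumptions for illustration.

```scala
// Let the planner fall back to ShuffledHashJoin when the size heuristics allow it.
sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false")

// With the default spark.sql.orderByOrdinal=true, ORDER BY 1 refers to the first
// column of the select list; set it to false to have ordinals ignored instead.
sqlContext.setConf("spark.sql.orderByOrdinal", "true")
sqlContext.sql("SELECT a, b FROM testData2 ORDER BY 1 DESC").show()
```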
@@ -482,7 +487,7 @@ object SQLConf { val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", defaultValue = Some(true), isPublic = false, - doc = "When true, we could use `datasource`.`path` as table in SQL query" + doc = "When true, we could use `datasource`.`path` as table in SQL query." ) val PARSER_SUPPORT_QUOTEDID = booleanConf("spark.sql.parser.supportQuotedIdentifiers", @@ -501,7 +506,7 @@ object SQLConf { val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage", defaultValue = Some(true), doc = "When true, the whole stage (of multiple operators) will be compiled into single java" + - " method", + " method.", isPublic = false) val FILES_MAX_PARTITION_BYTES = longConf("spark.sql.files.maxPartitionBytes", @@ -511,7 +516,7 @@ object SQLConf { val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", defaultValue = Some(true), - doc = "When true, the planner will try to find out duplicated exchanges and re-use them", + doc = "When true, the planner will try to find out duplicated exchanges and re-use them.", isPublic = false) object Deprecated { @@ -522,6 +527,7 @@ object SQLConf { val CODEGEN_ENABLED = "spark.sql.codegen" val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" + val PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED = "spark.sql.parquet.enableUnsafeRowRecordReader" } } @@ -586,6 +592,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) + def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) + def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) @@ -631,6 +639,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin def supportSQL11ReservedKeywords: Boolean = getConf(PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS) + override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index ae9c8cc1ba9ff..189cc3972c9ba 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -145,12 +145,13 @@ public Row call(Person person) { Dataset df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); - List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - @Override - public String call(Row row) { - return row.getString(0) + "_" + row.get(1); - } - }).collect(); + List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD() + .map(new Function() { + @Override + public String call(Row row) { + return row.getString(0) + "_" + row.get(1); + } + }).collect(); List expected = new ArrayList<>(2); expected.add("Michael_29"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index f3c5a86e20320..cf764c645f9ee 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -220,7 +220,8 @@ public void testCreateStructTypeFromList(){ StructType schema1 = StructType$.MODULE$.apply(fields1); Assert.assertEquals(0, schema1.fieldIndex("id")); - List fields2 = Arrays.asList(new StructField("id", DataTypes.StringType, true, Metadata.empty())); + List fields2 = + Arrays.asList(new StructField("id", DataTypes.StringType, true, Metadata.empty())); StructType schema2 = StructType$.MODULE$.apply(fields2); Assert.assertEquals(0, schema2.fieldIndex("id")); } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 79b6e6176714f..4b8b0d9d4f8aa 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -169,7 +169,7 @@ public Integer call(Integer v1, Integer v2) throws Exception { public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, Encoders.STRING()); - GroupedDataset grouped = ds.groupByKey(new MapFunction() { + KeyValueGroupedDataset grouped = ds.groupByKey(new MapFunction() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -217,7 +217,7 @@ public String call(String v1, String v2) throws Exception { List data2 = Arrays.asList(2, 6, 10); Dataset ds2 = context.createDataset(data2, Encoders.INT()); - GroupedDataset grouped2 = ds2.groupByKey(new MapFunction() { + KeyValueGroupedDataset grouped2 = ds2.groupByKey(new MapFunction() { @Override public Integer call(Integer v) throws Exception { return v / 2; @@ -249,7 +249,7 @@ public Iterator call(Integer key, Iterator left, Iterator data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, Encoders.STRING()); - GroupedDataset grouped = + KeyValueGroupedDataset grouped = ds.groupByKey(length(col("value"))).keyAs(Encoders.INT()); Dataset mapped = grouped.mapGroups( @@ -410,7 +410,7 @@ public void testTypedAggregation() { Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); Dataset> ds = context.createDataset(data, encoder); - GroupedDataset> grouped = ds.groupByKey( + 
KeyValueGroupedDataset> grouped = ds.groupByKey( new MapFunction, String>() { @Override public String call(Tuple2 value) throws Exception { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 199e138abfdc2..d03597ee5dcad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -166,22 +166,43 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("SPARK-8930: explode should fail with a meaningful message if it takes a star") { + test("Star Expansion - CreateStruct and CreateArray") { + val structDf = testData2.select("a", "b").as("record") + // CreateStruct and CreateArray in aggregateExpressions + assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) + assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1))) + + // CreateStruct and CreateArray in project list (unresolved alias) + assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) + assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1)) + + // CreateStruct and CreateArray in project list (alias) + assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1))) + assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1)) + } + + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") val e = intercept[AnalysisException] { df.explode($"*") { case Row(prefix: String, csv: String) => csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq }.queryExecution.assertAnalyzed() } - assert(e.getMessage.contains( - "Cannot explode *, explode can only be applied on a specific column.")) + assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) - df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }.queryExecution.assertAnalyzed() + checkAnswer( + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }, + Row("1", "1,2", "1:1") :: + Row("1", "1,2", "1:2") :: + Row("2", "4", "2:4") :: + Row("3", "7,8,9", "3:7") :: + Row("3", "7,8,9", "3:8") :: + Row("3", "7,8,9", "3:9") :: Nil) } - test("explode alias and star") { + test("Star Expansion - explode alias and star") { val df = Seq((Array("a"), 1)).toDF("a", "b") checkAnswer( @@ -612,7 +633,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val longString = Array.fill(21)("1").mkString val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ - ||_1 | + ||value | |+---------------------+ ||1 | ||111111111111111111111| @@ -620,7 +641,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { |""".stripMargin assert(df.showString(10, false) === expectedAnswerForFalse) val expectedAnswerForTrue = """+--------------------+ - || _1| + || value| |+--------------------+ || 1| ||11111111111111111...| diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d7fa23651bcee..04d3a25fcb4f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -case class OtherTuple(_1: String, _2: Int) - class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -636,8 +634,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(OuterObject.InnerClass("foo")).toDS(), OuterObject.InnerClass("foo")) } + + test("SPARK-14000: case class with tuple type field") { + checkDataset( + Seq(TupleClass((1, "a"))).toDS(), + TupleClass(1, "a") + ) + } } +case class OtherTuple(_1: String, _2: Int) + +case class TupleClass(data: (Int, String)) + class OuterClass extends Serializable { case class InnerClass(a: String) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 03d6df8c28bef..dfffa4bc8b1c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -45,8 +45,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { - case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j + case j: ShuffledHashJoin => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: SortMergeJoin => j @@ -63,7 +63,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), @@ -434,7 +434,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } @@ -460,7 +460,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[LeftSemiJoinHash]), + classOf[ShuffledHashJoin]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData JOIN testData2", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3efe984c09eb8..9f2233d5d821b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,8 @@ import java.math.MathContext import java.sql.Timestamp import org.apache.spark.AccumulatorSuite +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, 
SortMergeJoin} import org.apache.spark.sql.functions._ @@ -1619,15 +1621,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10215 Div of Decimal returns null") { - val d = Decimal(1.12321) + val d = Decimal(1.12321).toBigDecimal val df = Seq((d, 1)).toDF("a", "b") checkAnswer( df.selectExpr("b * a / b"), - Seq(Row(d.toBigDecimal))) + Seq(Row(d))) checkAnswer( df.selectExpr("b * a / b / b"), - Seq(Row(d.toBigDecimal))) + Seq(Row(d))) checkAnswer( df.selectExpr("b * a + b"), Seq(Row(BigDecimal(2.12321)))) @@ -1636,7 +1638,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal(0.12321)))) checkAnswer( df.selectExpr("b * a * b"), - Seq(Row(d.toBigDecimal))) + Seq(Row(d))) } test("precision smaller than scale") { @@ -2156,6 +2158,47 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("order by ordinal number") { + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 DESC"), + sql("SELECT * FROM testData2 ORDER BY a DESC")) + // If the position is not an integer, ignore it. + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY b ASC")) + + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC")) + checkAnswer( + sql("SELECT * FROM testData2 SORT BY 1 DESC, 2"), + sql("SELECT * FROM testData2 SORT BY a DESC, b ASC")) + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"), + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) + } + + test("order by ordinal number - negative cases") { + intercept[UnresolvedException[SortOrder]] { + sql("SELECT * FROM testData2 ORDER BY 0") + } + intercept[UnresolvedException[SortOrder]] { + sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC") + } + intercept[UnresolvedException[SortOrder]] { + sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC") + } + } + + test("order by ordinal number with conf spark.sql.orderByOrdinal=false") { + withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") { + // If spark.sql.orderByOrdinal=false, ignore the position number. 
+ checkAnswer( + sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), + sql("SELECT * FROM testData2 ORDER BY b ASC")) + } + } + test("natural join") { val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1") val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2") @@ -2179,4 +2222,68 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(4) :: Nil) } } + + test("join with using clause") { + val df1 = Seq(("r1c1", "r1c2", "t1r1c3"), + ("r2c1", "r2c2", "t1r2c3"), ("r3c1x", "r3c2", "t1r3c3")).toDF("c1", "c2", "c3") + val df2 = Seq(("r1c1", "r1c2", "t2r1c3"), + ("r2c1", "r2c2", "t2r2c3"), ("r3c1y", "r3c2", "t2r3c3")).toDF("c1", "c2", "c3") + val df3 = Seq((null, "r1c2", "t3r1c3"), + ("r2c1", "r2c2", "t3r2c3"), ("r3c1y", "r3c2", "t3r3c3")).toDF("c1", "c2", "c3") + withTempTable("t1", "t2", "t3") { + df1.registerTempTable("t1") + df2.registerTempTable("t2") + df3.registerTempTable("t3") + // inner join with one using column + checkAnswer( + sql("SELECT * FROM t1 join t2 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: Nil) + + // inner join with two using columns + checkAnswer( + sql("SELECT * FROM t1 join t2 using (c1, c2)"), + Row("r1c1", "r1c2", "t1r1c3", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "t2r2c3") :: Nil) + + // Left outer join with one using column. + checkAnswer( + sql("SELECT * FROM t1 left join t2 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: + Row("r3c1x", "r3c2", "t1r3c3", null, null) :: Nil) + + // Right outer join with one using column. + checkAnswer( + sql("SELECT * FROM t1 right join t2 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: + Row("r3c1y", null, null, "r3c2", "t2r3c3") :: Nil) + + // Full outer join with one using column. + checkAnswer( + sql("SELECT * FROM t1 full outer join t2 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t2r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t2r2c3") :: + Row("r3c1x", "r3c2", "t1r3c3", null, null) :: + Row("r3c1y", null, + null, "r3c2", "t2r3c3") :: Nil) + + // Full outer join with null value in join column. + checkAnswer( + sql("SELECT * FROM t1 full outer join t3 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", null, null) :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t3r2c3") :: + Row("r3c1x", "r3c2", "t1r3c3", null, null) :: + Row("r3c1y", null, null, "r3c2", "t3r3c3") :: + Row(null, null, null, "r1c2", "t3r1c3") :: Nil) + + // Self join with using columns. 
+ checkAnswer( + sql("SELECT * FROM t1 join t1 using (c1)"), + Row("r1c1", "r1c2", "t1r1c3", "r1c2", "t1r1c3") :: + Row("r2c1", "r2c2", "t1r2c3", "r2c2", "t1r2c3") :: + Row("r3c1x", "r3c2", "t1r3c3", "r3c2", "t1r3c3") :: Nil) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index b6051b07c8093..0b1cb90186929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -42,7 +42,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { lazy val sc = SparkContext.getOrCreate(conf) lazy val sqlContext = SQLContext.getOrCreate(sc) - def runBenchmark(name: String, values: Int)(f: => Unit): Unit = { + def runBenchmark(name: String, values: Long)(f: => Unit): Unit = { val benchmark = new Benchmark(name, values) Seq(false, true).foreach { enabled => @@ -57,7 +57,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { // These benchmark are skipped in normal build ignore("range/filter/sum") { - val N = 500 << 20 + val N = 500L << 20 runBenchmark("rang/filter/sum", N) { sqlContext.range(N).filter("(id & 1) = 1").groupBy().sum().collect() } @@ -71,7 +71,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("range/limit/sum") { - val N = 500 << 20 + val N = 500L << 20 runBenchmark("range/limit/sum", N) { sqlContext.range(N).limit(1000000).groupBy().sum().collect() } @@ -85,7 +85,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("stat functions") { - val N = 100 << 20 + val N = 100L << 20 runBenchmark("stddev", N) { sqlContext.range(N).groupBy().agg("id" -> "stddev").collect() @@ -247,7 +247,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } - ignore("rube") { + ignore("shuffle hash join") { + val N = 4 << 20 + sqlContext.setConf("spark.sql.shuffle.partitions", "2") + sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000") + sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false") + runBenchmark("shuffle hash join", N) { + val df1 = sqlContext.range(N).selectExpr(s"id as k1") + val df2 = sqlContext.range(N / 5).selectExpr(s"id * 3 as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X + shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X + */ + } + + ignore("cube") { val N = 5 << 20 runBenchmark("cube", N) { @@ -465,4 +485,25 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { collect 4 millions 3193 / 3895 0.3 3044.7 0.1X */ } + + ignore("collect limit") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect limit", N) + benchmark.addCase("collect limit 1 million") { iter => + sqlContext.range(N * 4).limit(N).collect() + } + benchmark.addCase("collect limit 2 millions") { iter => + sqlContext.range(N * 4).limit(N * 2).collect() + } + benchmark.run() + + /** + model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) + collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + collect limit 1 million 833 / 1284 1.3 794.4 1.0X + collect limit 
2 millions 3348 / 4005 0.3 3193.3 0.2X + */ + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 9f159d1e1e8a8..9680f3a008a59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder + import testImplicits._ test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 9cd50abda6f00..e9b65539b0d62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -143,7 +143,7 @@ class PlannerSuite extends SharedSQLContext { val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(sortMergeJoins.isEmpty, "Should not use sort merge join") + assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join") sqlContext.clearCache() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index cb6d68dc3ac46..778477660e169 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -30,7 +30,8 @@ import org.apache.spark.sql.types._ * sorted by a reference implementation ([[ReferenceSort]]). */ class SortSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder + import testImplicits.newProductEncoder + import testImplicits.localSeqToDatasetHolder test("basic sorting using ExternalSort") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 0940878e383df..9e04caf8ba7d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("decimal type") { // Casting is required here because ScalaReflection can't capture decimal precision information. 
val df = (1 to 10) - .map(i => Tuple1(Decimal(i, 15, 10))) + .map(i => Tuple1(Decimal(i, 15, 10).toJavaBigDecimal)) .toDF("dec") .select($"dec" cast DecimalType(15, 10)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6d942c4c90289..0a5699b99cf0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -963,7 +964,56 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } - test("Corrupt records") { + test("Corrupt records: FAILFAST mode") { + val schema = StructType( + StructField("a", StringType, true) :: Nil) + // `FAILFAST` mode should throw an exception for corrupt records. + val exceptionOne = intercept[SparkException] { + sqlContext.read + .option("mode", "FAILFAST") + .json(corruptRecords) + .collect() + } + assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) + + val exceptionTwo = intercept[SparkException] { + sqlContext.read + .option("mode", "FAILFAST") + .schema(schema) + .json(corruptRecords) + .collect() + } + assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode: {")) + } + + test("Corrupt records: DROPMALFORMED mode") { + val schemaOne = StructType( + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + val schemaTwo = StructType( + StructField("a", StringType, true) :: Nil) + // `DROPMALFORMED` mode should skip corrupt records + val jsonDFOne = sqlContext.read + .option("mode", "DROPMALFORMED") + .json(corruptRecords) + checkAnswer( + jsonDFOne, + Row("str_a_4", "str_b_4", "str_c_4") :: Nil + ) + assert(jsonDFOne.schema === schemaOne) + + val jsonDFTwo = sqlContext.read + .option("mode", "DROPMALFORMED") + .schema(schemaTwo) + .json(corruptRecords) + checkAnswer( + jsonDFTwo, + Row("str_a_4") :: Nil) + assert(jsonDFTwo.schema === schemaTwo) + } + + test("Corrupt records: PERMISSIVE mode") { // Test if we can query corrupt records. 
withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { withTempTable("jsonTable") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index 29318d8b56053..88fcfce0ec1bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -36,7 +36,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) @@ -61,17 +61,17 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex data.repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader reader.initialize(file.asInstanceOf[String], null) val batch = reader.resultBatch() assert(reader.nextBatch()) assert(batch.numRows() == n) var i = 0 while (i < n) { - assert(batch.column(0).getIsNull(i)) - assert(batch.column(1).getIsNull(i)) - assert(batch.column(2).getIsNull(i)) - assert(batch.column(3).getIsNull(i)) + assert(batch.column(0).isNullAt(i)) + assert(batch.column(1).isNullAt(i)) + assert(batch.column(2).isNullAt(i)) + assert(batch.column(3).isNullAt(i)) i += 1 } reader.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index b394ffb366b88..51183e970d965 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -57,7 +57,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -446,7 +446,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { import testImplicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) @@ -520,7 +520,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex test("SPARK-11164: test the parquet filter in") { import testImplicits._ 
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/table1" (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 4d9a8d7eb1b7d..ebdb105743ea6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -656,7 +656,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { var hash1: Int = 0 var hash2: Int = 0 (false :: true :: Nil).foreach { v => - withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> v.toString) { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> v.toString) { val df = sqlContext.read.parquet(dir.getCanonicalPath) val rows = df.queryExecution.toRdd.map(_.copy()).collect() val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow]) @@ -672,13 +672,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - test("UnsafeRowParquetRecordReader - direct path read") { - val data = (0 to 10).map(i => (i, ((i + 'a').toChar.toString))) + test("VectorizedParquetRecordReader - direct path read") { + val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString)) withTempPath { dir => sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader try { reader.initialize(file, null) val result = mutable.ArrayBuffer.empty[(Int, String)] @@ -695,7 +695,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project just one column { - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader try { reader.initialize(file, ("_2" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String)] @@ -711,7 +711,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Project columns in opposite order { - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader try { reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) val result = mutable.ArrayBuffer.empty[(String, Int)] @@ -728,7 +728,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // Empty projection { - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader try { reader.initialize(file, List[String]().asJava) var result = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index 15bf00e6f47e1..cc0cc65d3eb59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -82,38 +82,17 @@ object ParquetReadBenchmark { } 
sqlBenchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(id) from tempTable").collect() - } - } - - sqlBenchmark.addCase("SQL Parquet Non-Vectorized") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { sqlContext.sql("select sum(id) from tempTable").collect() } } val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - // Driving the parquet reader directly without Spark. - parquetReaderBenchmark.addCase("ParquetReader Non-Vectorized") { num => - var sum = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new UnsafeRowParquetRecordReader - reader.initialize(p, ("id" :: Nil).asJava) - - while (reader.nextKeyValue()) { - val record = reader.getCurrentValue.asInstanceOf[InternalRow] - if (!record.isNullAt(0)) sum += record.getInt(0) - } - reader.close() - } - } - // Driving the parquet reader in batch mode directly. parquetReaderBenchmark.addCase("ParquetReader Vectorized") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -122,7 +101,7 @@ object ParquetReadBenchmark { val numRows = batch.numRows() var i = 0 while (i < numRows) { - if (!col.getIsNull(i)) sum += col.getInt(i) + if (!col.isNullAt(i)) sum += col.getInt(i) i += 1 } } @@ -136,7 +115,7 @@ object ParquetReadBenchmark { parquetReaderBenchmark.addCase("ParquetReader Vectorized -> Row") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader try { reader.initialize(p, ("id" :: Nil).asJava) val batch = reader.resultBatch() @@ -159,7 +138,6 @@ object ParquetReadBenchmark { ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 215 / 262 73.0 13.7 1.0X SQL Parquet MR 1946 / 2083 8.1 123.7 0.1X - SQL Parquet Non-Vectorized 1079 / 1213 14.6 68.6 0.2X */ sqlBenchmark.run() @@ -167,9 +145,8 @@ object ParquetReadBenchmark { Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz Parquet Reader Single Int Column Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - ParquetReader Non-Vectorized 610 / 737 25.8 38.8 1.0X - ParquetReader Vectorized 123 / 152 127.8 7.8 5.0X - ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 3.7X + ParquetReader Vectorized 123 / 152 127.8 7.8 1.0X + ParquetReader Vectorized -> Row 165 / 180 95.2 10.5 0.7X */ parquetReaderBenchmark.run() } @@ -191,32 +168,12 @@ object ParquetReadBenchmark { } benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect - } - } - - benchmark.addCase("SQL Parquet Non-vectorized") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect } } val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray - benchmark.addCase("ParquetReader Non-vectorized") { num => - var sum1 = 0L - var sum2 = 0L - files.map(_.asInstanceOf[String]).foreach { p => - val reader = new UnsafeRowParquetRecordReader - reader.initialize(p, null) - while 
(reader.nextKeyValue()) { - val record = reader.getCurrentValue.asInstanceOf[InternalRow] - if (!record.isNullAt(0)) sum1 += record.getInt(0) - if (!record.isNullAt(1)) sum2 += record.getUTF8String(1).numBytes() - } - reader.close() - } - } /* Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz @@ -224,8 +181,6 @@ object ParquetReadBenchmark { ------------------------------------------------------------------------------------------- SQL Parquet Vectorized 628 / 720 16.7 59.9 1.0X SQL Parquet MR 1905 / 2239 5.5 181.7 0.3X - SQL Parquet Non-vectorized 1429 / 1732 7.3 136.3 0.4X - ParquetReader Non-vectorized 989 / 1357 10.6 94.3 0.6X */ benchmark.run() } @@ -247,7 +202,7 @@ object ParquetReadBenchmark { } benchmark.addCase("SQL Parquet MR") { iter => - withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { sqlContext.sql("select sum(length(c1)) from tempTable").collect } } @@ -293,7 +248,7 @@ object ParquetReadBenchmark { Read data column 191 / 250 82.1 12.2 1.0X Read partition column 82 / 86 192.4 5.2 2.3X Read both columns 220 / 248 71.5 14.0 0.9X - */ + */ benchmark.run() } } @@ -319,7 +274,7 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized") { num => var sum = 0 files.map(_.asInstanceOf[String]).foreach { p => - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() @@ -340,7 +295,7 @@ object ParquetReadBenchmark { benchmark.addCase("PR Vectorized (Null Filtering)") { num => var sum = 0L files.map(_.asInstanceOf[String]).foreach { p => - val reader = new UnsafeRowParquetRecordReader + val reader = new VectorizedParquetRecordReader try { reader.initialize(p, ("c1" :: "c2" :: Nil).asJava) val batch = reader.resultBatch() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 04dd809df17ca..dd20855a81d9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -29,42 +29,6 @@ import org.apache.spark.util.collection.CompactBuffer class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { - // Key is simply the record itself - private val keyProjection = new Projection { - override def apply(row: InternalRow): InternalRow = row - } - - test("GeneralHashedRelation") { - val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) - assert(hashed.isInstanceOf[GeneralHashedRelation]) - - assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) - assert(hashed.get(InternalRow(10)) === null) - - val data2 = CompactBuffer[InternalRow](data(2)) - data2 += data(2) - assert(hashed.get(data(2)) === data2) - } - - test("UniqueKeyHashedRelation") { - val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val hashed = HashedRelation(data.iterator, keyProjection) - assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - - assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) - assert(hashed.get(data(2)) === 
CompactBuffer[InternalRow](data(2))) - assert(hashed.get(InternalRow(10)) === null) - - val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) === data(0)) - assert(uniqHashed.getValue(data(1)) === data(1)) - assert(uniqHashed.getValue(data(2)) === data(2)) - assert(uniqHashed.getValue(InternalRow(10)) === null) - } - test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 814e25d10e5cc..3cb3ef1ffa2f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder + import testImplicits.newProductEncoder + import testImplicits.localSeqToDatasetHolder private lazy val myUpperCaseData = sqlContext.createDataFrame( sparkContext.parallelize(Seq( @@ -101,6 +102,20 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) } + def makeShuffledHashJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { + val shuffledHashJoin = + joins.ShuffledHashJoin(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) + val filteredJoin = + boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) + EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin) + } + def makeSortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -136,6 +151,30 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + test(s"$testName using SortMergeJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 1c8b2ea808b30..4cacb20aa0791 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -76,6 +76,22 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { ExtractEquiJoinKeys.unapply(join) } + if (joinType != FullOuter) { + test(s"$testName using ShuffledHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext.sessionState.conf).apply( + ShuffledHashJoin( + leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } + } + } + if (joinType != FullOuter) { test(s"$testName using BroadcastHashJoin") { val buildSide = joinType match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 5eb6a745239ab..985a96f684541 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -72,12 +72,13 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { ExtractEquiJoinKeys.unapply(join) } - test(s"$testName using LeftSemiJoinHash") { + test(s"$testName using ShuffledHashJoin") { extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( - LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)), + ShuffledHashJoin( + leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 988852a4fc0b9..695b1824e8cf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -263,32 +263,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } - test("LeftSemiJoinHash metrics") { + test("ShuffledHashJoin metrics") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") // Assume the execution plan is - // ... -> LeftSemiJoinHash(nodeId = 0) + // ... -> ShuffledHashJoin(nodeId = 0) val df = df1.join(df2, $"key" === $"key2", "leftsemi") testSparkPlanMetrics(df, 1, Map( - 0L -> ("LeftSemiJoinHash", Map( + 0L -> ("ShuffledHashJoin", Map( "number of output rows" -> 2L))) ) } } - test("LeftSemiJoinBNL metrics") { - val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") - val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") - // Assume the execution plan is - // ... 
-> LeftSemiJoinBNL(nodeId = 0) - val df = df1.join(df2, $"key" < $"key2", "leftsemi") - testSparkPlanMetrics(df, 2, Map( - 0L -> ("LeftSemiJoinBNL", Map( - "number of output rows" -> 2L))) - ) - } - test("CartesianProduct metrics") { val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.registerTempTable("testDataForJoin") @@ -321,7 +309,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { val metricValues = sqlContext.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. - assert(metricValues.values.toSeq === Seq("2")) + assert(metricValues.values.toSeq.exists(_ === "2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 4641a1ad78920..09bd7f6e8f0a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -81,7 +81,16 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { test("basic") { def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = { - assert(actual === expected.mapValues(_.toString)) + assert(actual.size == expected.size) + expected.foreach { e => + // The values in actual can be SQL metrics meaning that they contain additional formatting + // when converted to string. Verify that they start with the expected value. + // TODO: this is brittle. There is no requirement that the actual string needs to start + // with the accumulator value. + assert(actual.contains(e._1)) + val v = actual.get(e._1).get.trim + assert(v.startsWith(e._2.toString)) + } } val listener = new SQLListener(sqlContext.sparkContext.conf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index fa2c74431ab45..4262097e8f81f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -68,7 +68,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.numNulls() == 4) reference.zipWithIndex.foreach { v => - assert(v._1 == column.getIsNull(v._2)) + assert(v._1 == column.isNullAt(v._2)) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.nullsNativeAddress() assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) @@ -489,10 +489,10 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == true) assert(batch.column(0).getInt(0) == 1) - assert(batch.column(0).getIsNull(0) == false) + assert(batch.column(0).isNullAt(0) == false) assert(batch.column(1).getDouble(0) == 1.1) - assert(batch.column(1).getIsNull(0) == false) - assert(batch.column(2).getIsNull(0) == true) + assert(batch.column(1).isNullAt(0) == false) + assert(batch.column(2).isNullAt(0) == true) assert(batch.column(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. 
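Note on the pattern used throughout the join-related suites above (JoinSuite, InnerJoinSuite, OuterJoinSuite, SemiJoinSuite, SQLMetricsSuite): after `LeftSemiJoinHash` is replaced by `ShuffledHashJoin`, the tests detect which physical operator was planned by running `collect` with a partial function over the plan tree. The sketch below is a minimal, self-contained illustration of that detection pattern only; the `PlanNode` hierarchy is hypothetical and merely stands in for Spark's physical operators, it is not Spark's actual API.

```scala
// Hypothetical plan-node hierarchy standing in for Spark's physical operators.
sealed trait PlanNode { def children: Seq[PlanNode] }
case class Scan(table: String) extends PlanNode { def children: Seq[PlanNode] = Nil }
case class ShuffledHashJoin(left: PlanNode, right: PlanNode) extends PlanNode {
  def children: Seq[PlanNode] = Seq(left, right)
}
case class BroadcastHashJoin(left: PlanNode, right: PlanNode) extends PlanNode {
  def children: Seq[PlanNode] = Seq(left, right)
}

object JoinDetectionSketch {
  // Recursively apply the partial function, keeping every node it is defined at --
  // the same idea as `physical.collect { case j: ShuffledHashJoin => j }` in the suites above.
  def collectNodes[T](plan: PlanNode)(pf: PartialFunction[PlanNode, T]): Seq[T] =
    pf.lift(plan).toSeq ++ plan.children.flatMap(c => collectNodes(c)(pf))

  def main(args: Array[String]): Unit = {
    val plan = ShuffledHashJoin(Scan("testData"), Scan("testData2"))
    val shuffledJoins = collectNodes(plan) { case j: ShuffledHashJoin => j }
    assert(shuffledJoins.size == 1, "expected exactly one ShuffledHashJoin in the plan")
    println(shuffledJoins.map(_.getClass.getSimpleName).mkString(", "))
  }
}
```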
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 69bccfba4aa42..27e4cfc103bee 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -421,10 +421,10 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte if (table.properties.get("spark.sql.sources.provider").isDefined) { val dataSourceTable = cachedDataSourceTables(qualifiedTableName) - val tableWithQualifiers = SubqueryAlias(qualifiedTableName.name, dataSourceTable) + val qualifiedTable = SubqueryAlias(qualifiedTableName.name, dataSourceTable) // Then, if alias is specified, wrap the table with a Subquery using the alias. // Otherwise, wrap the table with a Subquery using the table name. - alias.map(a => SubqueryAlias(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable) } else if (table.tableType == CatalogTableType.VIRTUAL_VIEW) { val viewText = table.viewText.getOrElse(sys.error("Invalid view without text.")) alias match { @@ -935,7 +935,7 @@ private[hive] case class MetastoreRelation( HiveMetastoreTypes.toDataType(f.dataType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true - )(qualifiers = Seq(alias.getOrElse(tableName))) + )(qualifier = Some(alias.getOrElse(tableName))) } /** PartitionKey attributes */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index cd417ce3cca91..e54358e657690 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -24,9 +24,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.CollapseProject +import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, CombineUnions} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -34,15 +33,6 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.execution.HiveScriptIOSchema import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType} -/** - * A place holder for generated SQL for subquery expression. - */ -case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable { - override def dataType: DataType = NullType - override def nullable: Boolean = true - override def sql: String = s"($query)" -} - /** * A builder class used to convert a resolved logical plan into a SQL query string. Note that not * all resolved logical plan are convertible. 
They either don't have corresponding SQL @@ -54,23 +44,26 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) + private val nextSubqueryId = new AtomicLong(0) + private def newSubqueryName(): String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" + def toSQL: String = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) val outputNames = logicalPlan.output.map(_.name) - val qualifiers = logicalPlan.output.flatMap(_.qualifiers).distinct + val qualifiers = logicalPlan.output.flatMap(_.qualifier).distinct // Keep the qualifier information by using it as sub-query name, if there is only one qualifier // present. val finalName = if (qualifiers.length == 1) { qualifiers.head } else { - SQLBuilder.newSubqueryName + newSubqueryName() } // Canonicalizer will remove all naming information, we should add it back by adding an extra // Project and alias the outputs. val aliasedOutput = canonicalizedPlan.output.zip(outputNames).map { - case (attr, name) => Alias(attr.withQualifiers(Nil), name)() + case (attr, name) => Alias(attr.withQualifier(None), name)() } val finalPlan = Project(aliasedOutput, SubqueryAlias(finalName, canonicalizedPlan)) @@ -126,6 +119,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case w: Window => windowToSQL(w) + case g: Generate => + generateToSQL(g) + case Limit(limitExpr, child) => s"${toSQL(child)} LIMIT ${limitExpr.sql}" @@ -250,6 +246,42 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi ) } + private def generateToSQL(g: Generate): String = { + val columnAliases = g.generatorOutput.map(_.sql).mkString(", ") + + val childSQL = if (g.child == OneRowRelation) { + // This only happens when we put UDTF in project list and there is no FROM clause. Because we + // always generate LATERAL VIEW for `Generate`, here we use a trick to put a dummy sub-query + // after FROM clause, so that we can generate a valid LATERAL VIEW SQL string. + // For example, if the original SQL is: "SELECT EXPLODE(ARRAY(1, 2))", we will convert it to + // LATERAL VIEW format, and generate: + // SELECT col FROM (SELECT 1) sub_q0 LATERAL VIEW EXPLODE(ARRAY(1, 2)) sub_q1 AS col + s"(SELECT 1) ${newSubqueryName()}" + } else { + toSQL(g.child) + } + + // The final SQL string for Generate contains 7 parts: + // 1. the SQL of child, can be a table or sub-query + // 2. the LATERAL VIEW keyword + // 3. an optional OUTER keyword + // 4. the SQL of generator, e.g. EXPLODE(array_col) + // 5. the table alias for output columns of generator. + // 6. the AS keyword + // 7. the column alias, can be more than one, e.g. AS key, value + // A concrete example: "tbl LATERAL VIEW EXPLODE(map_col) sub_q AS key, value", and the builder + // will put it in FROM clause later.
+ build( + childSQL, + "LATERAL VIEW", + if (g.outer) "OUTER" else "", + g.generator.sql, + newSubqueryName(), + "AS", + columnAliases + ) + } + private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean = output1.size == output2.size && output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2)) @@ -342,14 +374,19 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi object Canonicalizer extends RuleExecutor[LogicalPlan] { override protected def batches: Seq[Batch] = Seq( - Batch("Collapse Project", FixedPoint(100), + Batch("Prepare", FixedPoint(100), // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over // `Aggregate`s to perform type casting. This rule merges these `Project`s into // `Aggregate`s. - CollapseProject), + CollapseProject, + // Parser is unable to parse the following query: + // SELECT `u_1`.`id` + // FROM (((SELECT `t0`.`id` FROM `default`.`t0`) + // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) + // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 + // This rule combine adjacent Unions together so we can generate flat UNION ALL SQL string. + CombineUnions), Batch("Recover Scoping Info", Once, - // Remove all sub queries, as we will insert new ones when it's necessary. - EliminateSubqueryAliases, // A logical plan is allowed to have same-name outputs with different qualifiers(e.g. the // `Join` operator). However, this kind of plan can't be put under a sub query as we will // erase and assign a new qualifier to all outputs and make it impossible to distinguish @@ -358,6 +395,10 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // qualifiers, as attributes have unique names now and we don't need qualifiers to resolve // ambiguity. NormalizedAttribute, + // Our analyzer will add one or more sub-queries above table relation, this rule removes + // these sub-queries so that next rule can combine adjacent table relation and sample to + // SQLTable. + RemoveSubqueriesAboveSQLTable, // Finds the table relations and wrap them with `SQLTable`s. If there are any `Sample` // operators on top of a table relation, merge the sample information into `SQLTable` of // that table relation, as we can only convert table sample to standard SQL string. @@ -370,9 +411,15 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi object NormalizedAttribute extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { case a: AttributeReference => - AttributeReference(normalizedName(a), a.dataType)(exprId = a.exprId, qualifiers = Nil) + AttributeReference(normalizedName(a), a.dataType)(exprId = a.exprId, qualifier = None) case a: Alias => - Alias(a.child, normalizedName(a))(exprId = a.exprId, qualifiers = Nil) + Alias(a.child, normalizedName(a))(exprId = a.exprId, qualifier = None) + } + } + + object RemoveSubqueriesAboveSQLTable extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case SubqueryAlias(_, t @ ExtractSQLTable(_)) => t } } @@ -423,11 +470,22 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case j: Join => j.copy( left = addSubqueryIfNeeded(j.left), right = addSubqueryIfNeeded(j.right)) + + // A special case for Generate. When we put UDTF in project list, followed by WHERE, e.g. 
+ // SELECT EXPLODE(arr) FROM tbl WHERE id > 1, the Filter operator will be under Generate + // operator and we need to add a sub-query between them, as it's not allowed to have a WHERE + // before LATERAL VIEW, e.g. "... FROM tbl WHERE id > 2 EXPLODE(arr) ..." is illegal. + case g @ Generate(_, _, _, _, _, f: Filter) => + // Add an extra `Project` to make sure we can generate legal SQL string for sub-query, + // for example, Subquery -> Filter -> Table will generate "(tbl WHERE ...) AS name", which + // misses the SELECT part. + val proj = Project(f.output, f) + g.copy(child = addSubquery(proj)) } } private def addSubquery(plan: LogicalPlan): SubqueryAlias = { - SubqueryAlias(SQLBuilder.newSubqueryName, plan) + SubqueryAlias(newSubqueryName(), plan) } private def addSubqueryIfNeeded(plan: LogicalPlan): LogicalPlan = plan match { @@ -437,6 +495,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case _: LocalLimit => plan case _: GlobalLimit => plan case _: SQLTable => plan + case _: Generate => plan case OneRowRelation => plan case _ => addSubquery(plan) } @@ -454,18 +513,21 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi object ExtractSQLTable { def unapply(plan: LogicalPlan): Option[SQLTable] = plan match { case l @ LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => - Some(SQLTable(database, table, l.output.map(_.withQualifiers(Nil)))) + Some(SQLTable(database, table, l.output.map(_.withQualifier(None)))) case m: MetastoreRelation => - Some(SQLTable(m.databaseName, m.tableName, m.output.map(_.withQualifiers(Nil)))) + Some(SQLTable(m.databaseName, m.tableName, m.output.map(_.withQualifier(None)))) case _ => None } } -} - -object SQLBuilder { - private val nextSubqueryId = new AtomicLong(0) - private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" + /** + * A place holder for generated SQL for subquery expression. + */ + case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable { + override def dataType: DataType = NullType + override def nullable: Boolean = true + override def sql: String = s"($query)" + } } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java index fc24600a1e4a7..a8cbd4fab15bb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/Complex.java @@ -39,7 +39,7 @@ * does not contain union fields that are not supported by Spark SQL. 
*/ -@SuppressWarnings({"ALL", "unchecked"}) +@SuppressWarnings("all") public class Complex implements org.apache.thrift.TBase, java.io.Serializable, Cloneable { private static final org.apache.thrift.protocol.TStruct STRUCT_DESC = new org.apache.thrift.protocol.TStruct("Complex"); diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index ca46c229f1952..f6b9072da4449 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import scala.util.control.NonFatal +import org.apache.spark.sql.Column import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils @@ -45,12 +46,28 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) .write .saveAsTable("parquet_t2") + + def createArray(id: Column): Column = { + when(id % 3 === 0, lit(null)).otherwise(array('id, 'id + 1)) + } + + sqlContext + .range(10) + .select( + createArray('id).as("arr"), + array(array('id), createArray('id)).as("arr2"), + lit("""{"f1": "1", "f2": "2", "f3": 3}""").as("json"), + 'id + ) + .write + .saveAsTable("parquet_t3") } override protected def afterAll(): Unit = { sql("DROP TABLE IF EXISTS parquet_t0") sql("DROP TABLE IF EXISTS parquet_t1") sql("DROP TABLE IF EXISTS parquet_t2") + sql("DROP TABLE IF EXISTS parquet_t3") sql("DROP TABLE IF EXISTS t0") } @@ -124,12 +141,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT * FROM t0 UNION SELECT * FROM t0") } - // Parser is unable to parse the following query: - // SELECT `u_1`.`id` - // FROM (((SELECT `t0`.`id` FROM `default`.`t0`) - // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) - // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 - ignore("three-child union") { + test("three-child union") { checkHiveQl( """ |SELECT id FROM parquet_t0 @@ -625,4 +637,103 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { |HAVING MAX(a.KEY) > 0 """.stripMargin) } + + test("generator in project list without FROM clause") { + checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3))") + checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) AS val") + } + + test("generator in project list with non-referenced table") { + checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) FROM t0") + checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) AS val FROM t0") + } + + test("generator in project list with referenced table") { + checkHiveQl("SELECT EXPLODE(arr) FROM parquet_t3") + checkHiveQl("SELECT EXPLODE(arr) AS val FROM parquet_t3") + } + + test("generator in project list with non-UDTF expressions") { + checkHiveQl("SELECT EXPLODE(arr), id FROM parquet_t3") + checkHiveQl("SELECT EXPLODE(arr) AS val, id as a FROM parquet_t3") + } + + test("generator in lateral view") { + checkHiveQl("SELECT val, id FROM parquet_t3 LATERAL VIEW EXPLODE(arr) exp AS val") + checkHiveQl("SELECT val, id FROM parquet_t3 LATERAL VIEW OUTER EXPLODE(arr) exp AS val") + } + + test("generator in lateral view with ambiguous names") { + checkHiveQl( + """ + |SELECT exp.id, parquet_t3.id + |FROM parquet_t3 + |LATERAL VIEW EXPLODE(arr) exp AS id + """.stripMargin) + checkHiveQl( + """ + |SELECT exp.id, parquet_t3.id + |FROM parquet_t3 + |LATERAL VIEW OUTER EXPLODE(arr) exp AS id + """.stripMargin) + } + + test("use JSON_TUPLE 
as generator") { + checkHiveQl( + """ + |SELECT c0, c1, c2 + |FROM parquet_t3 + |LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt + """.stripMargin) + checkHiveQl( + """ + |SELECT a, b, c + |FROM parquet_t3 + |LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt AS a, b, c + """.stripMargin) + } + + test("nested generator in lateral view") { + checkHiveQl( + """ + |SELECT val, id + |FROM parquet_t3 + |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array + |LATERAL VIEW EXPLODE(nested_array) exp1 AS val + """.stripMargin) + + checkHiveQl( + """ + |SELECT val, id + |FROM parquet_t3 + |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array + |LATERAL VIEW OUTER EXPLODE(nested_array) exp1 AS val + """.stripMargin) + } + + test("generate with other operators") { + checkHiveQl( + """ + |SELECT EXPLODE(arr) AS val, id + |FROM parquet_t3 + |WHERE id > 2 + |ORDER BY val, id + |LIMIT 5 + """.stripMargin) + + checkHiveQl( + """ + |SELECT val, id + |FROM parquet_t3 + |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array + |LATERAL VIEW EXPLODE(nested_array) exp1 AS val + |WHERE val > 2 + |ORDER BY val, id + |LIMIT 5 + """.stripMargin) + } + + test("filter after subquery") { + checkHiveQl("SELECT a FROM (SELECT key + 1 AS a FROM parquet_t1) t WHERE a > 5") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 1468be4670f26..151aacbdd1c44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -230,7 +230,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") val shj = df.queryExecution.sparkPlan.collect { - case j: LeftSemiJoinHash => j + case j: ShuffledHashJoin => j } assert(shj.size === 1, "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index d21bb573d491b..cfca93bbf0659 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -132,25 +132,7 @@ abstract class HiveComparisonTest new java.math.BigInteger(1, digest.digest).toString(16) } - /** Used for testing [[SQLBuilder]] */ - private var numConvertibleQueries: Int = 0 - private var numTotalQueries: Int = 0 - override protected def afterAll(): Unit = { - logInfo({ - val percentage = if (numTotalQueries > 0) { - numConvertibleQueries.toDouble / numTotalQueries * 100 - } else { - 0D - } - - s"""SQLBuilder statistics: - |- Total query number: $numTotalQueries - |- Number of convertible queries: $numConvertibleQueries - |- Percentage of convertible queries: $percentage% - """.stripMargin - }) - try { TestHive.reset() } finally { @@ -412,32 +394,38 @@ abstract class HiveComparisonTest if (containsCommands) { originalQuery } else { - numTotalQueries += 1 + val convertedSQL = try { + new SQLBuilder(originalQuery.analyzed, TestHive).toSQL + } catch { + case NonFatal(e) => fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$queryString + | + |# Resolved query plan: + |${originalQuery.analyzed.treeString} + """.stripMargin, e) + } + try { 
- val sql = new SQLBuilder(originalQuery.analyzed, TestHive).toSQL - numConvertibleQueries += 1 - logInfo( - s""" - |### Running SQL generation round-trip test {{{ - |${originalQuery.analyzed.treeString} - |Original SQL: - |$queryString - | - |Generated SQL: - |$sql - |}}} - """.stripMargin.trim) - new TestHive.QueryExecution(sql) - } catch { case NonFatal(e) => - logInfo( - s""" - |### Cannot convert the following logical plan back to SQL {{{ - |${originalQuery.analyzed.treeString} - |Original SQL: - |$queryString - |}}} - """.stripMargin.trim) - originalQuery + val queryExecution = new TestHive.QueryExecution(convertedSQL) + // Trigger the analysis of this converted SQL query. + queryExecution.analyzed + queryExecution + } catch { + case NonFatal(e) => fail( + s"""Failed to analyze the converted SQL string: + | + |# Original HiveQL query string: + |$queryString + | + |# Resolved query plan: + |${originalQuery.analyzed.treeString} + | + |# Converted SQL query string: + |$convertedSQL + """.stripMargin, e) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index ab4047df1ea3f..5fe85eaef2b55 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -950,9 +950,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(checkAddFileRDD.first()) } - case class LogEntry(filename: String, message: String) - case class LogFile(name: String) - createQueryTest("dynamic_partition", """ |DROP TABLE IF EXISTS dynamic_part_table; @@ -1249,3 +1246,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { // for SPARK-2180 test case class HavingRow(key: Int, value: String, attr: Int) + +case class LogEntry(filename: String, message: String) +case class LogFile(name: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9667b53e48e40..2806b87f33618 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -729,7 +729,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-5203 union with different decimal precision") { - Seq.empty[(Decimal, Decimal)] + Seq.empty[(java.math.BigDecimal, java.math.BigDecimal)] .toDF("d1", "d2") .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") @@ -738,20 +738,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .queryExecution.analyzed } + test("Star Expansion - script transform") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(100000 === sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans").count()) + } + test("test script transform for stdout") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(100000 === - sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") - .queryExecution.toRdd.count()) + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans").count()) } test("test script transform for stderr") { val data = (1 to 100000).map { i 
=> (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === - sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") - .queryExecution.toRdd.count()) + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans").count()) } test("test script transform data type") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index c395d361a1182..cc412241fb4da 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -79,7 +79,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } test("Read/write all types with non-primitive type") { - val data = (0 to 255).map { i => + val data: Seq[AllDataTypesWithNonPrimitiveType] = (0 to 255).map { i => AllDataTypesWithNonPrimitiveType( s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, 0 until i, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index f811784b25c82..ace67a639c6b8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -28,11 +28,13 @@ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.io.ChunkedByteBuffer /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. * It contains information about the id of the blocks having this partition's data and * the corresponding record handle in the write ahead log that backs the partition. + * * @param index index of the partition * @param blockId id of the block having the partition data * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark @@ -59,7 +61,6 @@ class WriteAheadLogBackedBlockRDDPartition( * correctness, and it can be used in situations where it is known that the block * does not exist in the Spark executors (e.g. after a failed driver is restarted). 
* - * * @param sc SparkContext * @param _blockIds Ids of the blocks that contains this RDD's data * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data @@ -156,11 +157,12 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, dataRead, storageLevel) + blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } - blockManager.dataDeserialize(blockId, dataRead).asInstanceOf[Iterator[T]] + blockManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead)) + .asInstanceOf[Iterator[T]] } if (partition.isBlockIdValid) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 4880884b0509d..6d4f4b99c175f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -30,6 +30,7 @@ import org.apache.spark.storage._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} +import org.apache.spark.util.io.ChunkedByteBuffer /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -84,7 +85,8 @@ private[streaming] class BlockManagerBasedBlockHandler( numRecords = countIterator.count putResult case ByteBufferBlock(byteBuffer) => - blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) + blockManager.putBytes( + blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true) case o => throw new SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") @@ -178,15 +180,18 @@ private[streaming] class WriteAheadLogBasedBlockHandler( numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => - byteBuffer + new ChunkedByteBuffer(byteBuffer.duplicate()) case _ => throw new Exception(s"Could not push $blockId to block manager, unexpected block type") } // Store the block in block manager val storeInBlockManagerFuture = Future { - val putSucceeded = - blockManager.putBytes(blockId, serializedBlock, effectiveStorageLevel, tellMaster = true) + val putSucceeded = blockManager.putBytes( + blockId, + serializedBlock, + effectiveStorageLevel, + tellMaster = true) if (!putSucceeded) { throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") @@ -195,7 +200,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Store the block in write ahead log val storeInWriteAheadLogFuture = Future { - writeAheadLog.write(serializedBlock, clock.getTimeMillis()) + writeAheadLog.write(serializedBlock.toByteBuffer, clock.getTimeMillis()) } // Combine the futures, wait for both to complete, and return the write ahead log record handle diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 66448fd40057d..01f0c4de9e3c9 100644 
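The streaming hunks above adapt to BlockManager.putBytes and dataDeserialize taking a ChunkedByteBuffer rather than a raw java.nio.ByteBuffer: incoming buffers are wrapped via `new ChunkedByteBuffer(buf.duplicate())` so the caller's position and limit stay untouched, and `toByteBuffer` is used where an API such as the write ahead log still expects a plain ByteBuffer. The sketch below is not part of the patch; it assumes it sits inside Spark's own packages (BlockManager and ChunkedByteBuffer are package-private) and uses simplified, hypothetical names, but the individual calls match the ones in the diff.

```scala
// Sketch only: the ByteBuffer -> ChunkedByteBuffer pattern used in the hunks above.
import java.nio.ByteBuffer

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkException
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel}
import org.apache.spark.streaming.util.WriteAheadLog
import org.apache.spark.util.io.ChunkedByteBuffer

def storeBlockAndLog(
    blockManager: BlockManager,
    writeAheadLog: WriteAheadLog,
    blockId: BlockId,
    bytes: ByteBuffer,
    level: StorageLevel,
    time: Long)(implicit ec: ExecutionContext): Unit = {
  // Duplicate first so the caller's buffer position/limit are not disturbed.
  val serialized = new ChunkedByteBuffer(bytes.duplicate())

  // Store in the block manager and in the write ahead log in parallel,
  // mirroring the two futures combined in WriteAheadLogBasedBlockHandler.
  val storeInBlockManager = Future {
    val ok = blockManager.putBytes(blockId, serialized, level, tellMaster = true)
    if (!ok) throw new SparkException(s"Could not store $blockId with storage level $level")
  }
  val storeInWriteAheadLog = Future {
    // The WAL API still takes a plain ByteBuffer, hence toByteBuffer.
    writeAheadLog.write(serialized.toByteBuffer, time)
  }
  // A real implementation would zip the two futures and await both, as the handler does.
}
```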
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -349,7 +349,9 @@ private void testReduceByWindow(boolean withInverse) { JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), + new Duration(2000), + new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), new Duration(2000), new Duration(1000)); @@ -497,7 +499,8 @@ public JavaRDD call(JavaRDD in) { pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) { + @Override public JavaPairRDD call(JavaPairRDD in, + Time time) { return null; } } @@ -606,7 +609,8 @@ public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time ti pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, + Time time) { return null; } } @@ -616,7 +620,8 @@ public JavaRDD call(JavaRDD rdd1, JavaPairRDD stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, + Time time) { return null; } } @@ -624,9 +629,12 @@ public JavaPairRDD call(JavaRDD rdd1, JavaRDD r stream1.transformWithToPair( pairStream1, - new Function3, JavaPairRDD, Time, JavaPairRDD>() { + new Function3, JavaPairRDD, Time, + JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { + public JavaPairRDD call(JavaRDD rdd1, + JavaPairRDD rdd2, + Time time) { return null; } } @@ -636,7 +644,8 @@ public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, + Time time) { return null; } } @@ -644,9 +653,12 @@ public JavaRDD call(JavaPairRDD rdd1, JavaRDD r pairStream1.transformWith( pairStream1, - new Function3, JavaPairRDD, Time, JavaRDD>() { + new Function3, JavaPairRDD, Time, + JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { + public JavaRDD call(JavaPairRDD rdd1, + JavaPairRDD rdd2, + Time time) { return null; } } @@ -654,9 +666,12 @@ public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD, JavaRDD, Time, JavaPairRDD>() { + new Function3, JavaRDD, Time, + JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { + public JavaPairRDD call(JavaPairRDD rdd1, + JavaRDD rdd2, + Time time) { return null; } } @@ -664,9 +679,12 @@ public JavaPairRDD call(JavaPairRDD rdd1, JavaR pairStream1.transformWithToPair( pairStream2, - new Function3, JavaPairRDD, Time, JavaPairRDD>() { + new Function3, JavaPairRDD, Time, + JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { + public JavaPairRDD call(JavaPairRDD rdd1, + JavaPairRDD rdd2, + Time time) { return null; } } @@ -722,13 +740,16 @@ public JavaRDD call(List> listOfRDDs, Time time) { listOfDStreams2, new Function2>, Time, JavaPairRDD>>() { @Override - public JavaPairRDD> call(List> listOfRDDs, Time time) { + public JavaPairRDD> call(List> listOfRDDs, + Time time) { Assert.assertEquals(3, listOfRDDs.size()); JavaRDD rdd1 = (JavaRDD)listOfRDDs.get(0); JavaRDD 
rdd2 = (JavaRDD)listOfRDDs.get(1); - JavaRDD> rdd3 = (JavaRDD>)listOfRDDs.get(2); + JavaRDD> rdd3 = + (JavaRDD>)listOfRDDs.get(2); JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); - PairFunction mapToTuple = new PairFunction() { + PairFunction mapToTuple = + new PairFunction() { @Override public Tuple2 call(Integer i) { return new Tuple2<>(i, i); @@ -739,7 +760,8 @@ public Tuple2 call(Integer i) { } ); JavaTestUtils.attachTestOutputStream(transformed2); - List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List>>> result = + JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } @@ -981,7 +1003,8 @@ public void testPairMap() { // Maps pair -> pair of different type new Tuple2<>(3, "new york"), new Tuple2<>(1, "new york"))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @@ -1014,7 +1037,8 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type new Tuple2<>(3, "new york"), new Tuple2<>(1, "new york"))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @@ -1044,7 +1068,8 @@ public void testPairMap2() { // Maps pair -> single Arrays.asList(1, 3, 4, 1), Arrays.asList(5, 5, 3, 1)); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( new Function, Integer>() { @@ -1116,7 +1141,8 @@ public void testPairGroupByKey() { new Tuple2<>("california", Arrays.asList("sharks", "ducks")), new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream> grouped = pairStream.groupByKey(); @@ -1241,7 +1267,8 @@ public void testGroupByKeyAndWindow() { ) ); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream> groupWindowed = @@ -1255,7 +1282,8 @@ public void testGroupByKeyAndWindow() { } } - private static Set>> convert(List>> listOfTuples) { + private static Set>> + convert(List>> listOfTuples) { List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); @@ -1280,7 +1308,8 @@ public void testReduceByKeyAndWindow() { Arrays.asList(new Tuple2<>("california", 10), new Tuple2<>("new york", 4))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = @@ -1304,7 +1333,8 @@ public void 
testUpdateStateByKey() { Arrays.asList(new Tuple2<>("california", 14), new Tuple2<>("new york", 9))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream updated = pairStream.updateStateByKey( @@ -1347,7 +1377,8 @@ public void testUpdateStateByKeyWithInitial() { Arrays.asList(new Tuple2<>("california", 15), new Tuple2<>("new york", 11))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream updated = pairStream.updateStateByKey( @@ -1383,7 +1414,8 @@ public void testReduceByKeyAndWindowWithInverse() { Arrays.asList(new Tuple2<>("california", 10), new Tuple2<>("new york", 4))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = @@ -1630,19 +1662,27 @@ public void testCoGroup() { ssc, stringStringKVStream2, 1); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - JavaPairDStream, Iterable>> grouped = pairStream1.cogroup(pairStream2); + JavaPairDStream, Iterable>> grouped = + pairStream1.cogroup(pairStream2); JavaTestUtils.attachTestOutputStream(grouped); - List, Iterable>>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List, Iterable>>>> result = + JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected.size(), result.size()); - Iterator, Iterable>>>> resultItr = result.iterator(); - Iterator, List>>>> expectedItr = expected.iterator(); + Iterator, Iterable>>>> resultItr = + result.iterator(); + Iterator, List>>>> expectedItr = + expected.iterator(); while (resultItr.hasNext() && expectedItr.hasNext()) { - Iterator, Iterable>>> resultElements = resultItr.next().iterator(); - Iterator, List>>> expectedElements = expectedItr.next().iterator(); + Iterator, Iterable>>> resultElements = + resultItr.next().iterator(); + Iterator, List>>> expectedElements = + expectedItr.next().iterator(); while (resultElements.hasNext() && expectedElements.hasNext()) { - Tuple2, Iterable>> resultElement = resultElements.next(); - Tuple2, List>> expectedElement = expectedElements.next(); + Tuple2, Iterable>> resultElement = + resultElements.next(); + Tuple2, List>> expectedElement = + expectedElements.next(); Assert.assertEquals(expectedElement._1(), resultElement._1()); equalIterable(expectedElement._2()._1(), resultElement._2()._1()); equalIterable(expectedElement._2()._2(), resultElement._2()._2()); @@ -1719,7 +1759,8 @@ public void testLeftOuterJoin() { ssc, stringStringKVStream2, 1); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - JavaPairDStream>> joined = pairStream1.leftOuterJoin(pairStream2); + JavaPairDStream>> joined = + pairStream1.leftOuterJoin(pairStream2); JavaDStream counted = joined.count(); JavaTestUtils.attachTestOutputStream(counted); List> result = JavaTestUtils.runStreams(ssc, 2, 2); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java index 67b2a0703e02b..ff0be820e0a9a 100644 --- 
a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -77,12 +77,14 @@ public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) } @Override - public void onOutputOperationStarted(JavaStreamingListenerOutputOperationStarted outputOperationStarted) { + public void onOutputOperationStarted( + JavaStreamingListenerOutputOperationStarted outputOperationStarted) { super.onOutputOperationStarted(outputOperationStarted); } @Override - public void onOutputOperationCompleted(JavaStreamingListenerOutputOperationCompleted outputOperationCompleted) { + public void onOutputOperationCompleted( + JavaStreamingListenerOutputOperationCompleted outputOperationCompleted) { super.onOutputOperationCompleted(outputOperationCompleted); } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 2d509af85ae33..122ca0627f720 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -34,12 +34,13 @@ import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.util.io.ChunkedByteBuffer class ReceivedBlockHandlerSuite extends SparkFunSuite @@ -155,7 +156,7 @@ class ReceivedBlockHandlerSuite val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) val bytes = reader.read(fileSegment) reader.close() - blockManager.dataDeserialize(generateBlockId(), bytes).toList + blockManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList } loggedData shouldEqual data } @@ -264,7 +265,8 @@ class ReceivedBlockHandlerSuite name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf, + val serializerManager = new SerializerManager(serializer, conf) + val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") @@ -339,7 +341,7 @@ class ReceivedBlockHandlerSuite storeAndVerify(blocks.map { b => IteratorBlock(b.toIterator) }) storeAndVerify(blocks.map { b => ArrayBufferBlock(new ArrayBuffer ++= b) }) - storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b)) }) + storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b).toByteBuffer) }) } /** Test error handling when blocks that cannot be stored */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 79ac833c1846b..c4bf42d0f272d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -223,7 +223,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(blockManager.dataSerialize(id, data.iterator)) + writer.write(blockManager.dataSerialize(id, data.iterator).toByteBuffer) } writer.close() segments diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 84445d60cd803..e941089d1b096 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -137,12 +137,9 @@ private[spark] class ApplicationMaster( System.setProperty("spark.master", "yarn") System.setProperty("spark.submit.deployMode", "cluster") - // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. + // Set this internal configuration if it is running on cluster mode, this + // configuration will be checked in SparkContext to avoid misuse of yarn cluster mode. System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) - - // Propagate the attempt if, so that in case of event logging, - // different attempt's logs gets created in different directory - System.setProperty("spark.yarn.app.attemptId", appAttemptId.getAttemptId().toString()) } logInfo("ApplicationAttemptId: " + appAttemptId) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 10cd6d00b0edb..0789567ae6a18 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -104,7 +104,7 @@ package object config { /* Cluster-mode launcher configuration. */ private[spark] val WAIT_FOR_APP_COMPLETION = ConfigBuilder("spark.yarn.submit.waitAppCompletion") - .doc("In cluster mode, whether to wait for the application to finishe before exiting the " + + .doc("In cluster mode, whether to wait for the application to finish before exiting the " + "launcher process.") .booleanConf .withDefault(true) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 0cc158b15a791..a8781636f25f4 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -96,11 +96,12 @@ private[spark] abstract class YarnSchedulerBackend( /** * Get the attempt ID for this run, if the cluster manager supports multiple * attempts. Applications run in client mode will not have attempt IDs. + * This attempt ID only includes attempt counter, like "1", "2". * * @return The application attempt id, if available. */ override def applicationAttemptId(): Option[String] = { - attemptId.map(_.toString) + attemptId.map(_.getAttemptId.toString) } /**
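The YarnSchedulerBackend hunk above changes applicationAttemptId() to return only the YARN attempt counter (getAttemptId) rather than the full ApplicationAttemptId string, which is what the updated doc comment means by "only includes attempt counter, like \"1\", \"2\"". A spark-shell style sketch of the difference, using the standard Hadoop YARN records API (the concrete values below are made up for illustration):

```scala
// Sketch only: full attempt identifier vs. the bare attempt counter.
import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}

val appId = ApplicationId.newInstance(1460000000000L, 42)
val attemptId = ApplicationAttemptId.newInstance(appId, 2)

attemptId.toString      // e.g. "appattempt_1460000000000_0042_000002" (old behavior)
attemptId.getAttemptId  // 2 -- what applicationAttemptId() now returns, as the string "2"
```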