From f48420fde58d554480cc8830d2f8c4d17618f283 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 3 Sep 2014 18:57:20 -0700 Subject: [PATCH 1/5] [SPARK-2973][SQL] Lightweight SQL commands without distributed jobs when calling .collect() By overriding `executeCollect()` in physical plan classes of all commands, we can avoid to kick off a distributed job when collecting result of a SQL command, e.g. `sql("SET").collect()`. Previously, `Command.sideEffectResult` returns a `Seq[Any]`, and the `execute()` method in sub-classes of `Command` typically convert that to a `Seq[Row]` then parallelize it to an RDD. Now with this PR, `sideEffectResult` is required to return a `Seq[Row]` directly, so that `executeCollect()` can directly leverage that and be factored to the `Command` parent class. Author: Cheng Lian Closes #2215 from liancheng/lightweight-commands and squashes the following commits: 3fbef60 [Cheng Lian] Factored execute() method of physical commands to parent class Command 5a0e16c [Cheng Lian] Passes test suites e0e12e9 [Cheng Lian] Refactored Command.sideEffectResult and Command.executeCollect 995bdd8 [Cheng Lian] Cleaned up DescribeHiveTableCommand 542977c [Cheng Lian] Avoids confusion between logical and physical plan by adding package prefixes 55b2aa5 [Cheng Lian] Avoids distributed jobs when execution SQL commands --- .../apache/spark/sql/execution/commands.scala | 63 +++++++------------ .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 14 +++-- .../execution/DescribeHiveTableCommand.scala | 30 +++------ .../sql/hive/execution/NativeCommand.scala | 11 +--- .../spark/sql/hive/execution/commands.scala | 20 ++---- 6 files changed, 48 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 031b695169cea..286c6d264f86a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,11 +21,13 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.{Row, SQLConf, SQLContext} trait Command { + this: SparkPlan => + /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field @@ -35,7 +37,11 @@ trait Command { * The `execute()` method of all the physical command classes should reference `sideEffectResult` * so that the command can be executed eagerly right after the command query is created. 
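 *
 * As an illustrative sketch (not part of this patch; `NoopCommand` is a hypothetical name), a
 * concrete command only needs to fill in `sideEffectResult`; `executeCollect()` and `execute()`
 * are inherited from this trait, so collecting the result launches no distributed job:
 *
 * {{{
 *   case class NoopCommand(@transient context: SQLContext) extends LeafNode with Command {
 *     override def output: Seq[Attribute] = Seq.empty
 *     // Evaluated lazily and at most once; returned directly by executeCollect().
 *     override protected[sql] lazy val sideEffectResult: Seq[Row] = Seq(Row("OK"))
 *   }
 *   // NoopCommand(context).executeCollect()  // Array(Row("OK")), computed locally
 * }}}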
*/ - protected[sql] lazy val sideEffectResult: Seq[Any] = Seq.empty[Any] + protected[sql] lazy val sideEffectResult: Seq[Row] = Seq.empty[Row] + + override def executeCollect(): Array[Row] = sideEffectResult.toArray + + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } /** @@ -47,17 +53,17 @@ case class SetCommand( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { + override protected[sql] lazy val sideEffectResult: Seq[Row] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) - Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + Array(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")) } else { context.setConf(k, v) - Array(s"$k=$v") + Array(Row(s"$k=$v")) } // Query the value bound to key k. @@ -73,28 +79,22 @@ case class SetCommand( "hive-0.12.0.jar").mkString(":") Array( - "system:java.class.path=" + hiveJars, - "system:sun.java.command=shark.SharkServer2") - } - else { - Array(s"$k=${context.getConf(k, "")}") + Row("system:java.class.path=" + hiveJars), + Row("system:sun.java.command=shark.SharkServer2")) + } else { + Array(Row(s"$k=${context.getConf(k, "")}")) } // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => context.getAllConfs.map { case (k, v) => - s"$k=$v" + Row(s"$k=$v") }.toSeq case _ => throw new IllegalArgumentException() } - def execute(): RDD[Row] = { - val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) } - context.sparkContext.parallelize(rows, 1) - } - override def otherCopyArgs = context :: Nil } @@ -113,19 +113,14 @@ case class ExplainCommand( extends LeafNode with Command { // Run through the optimizer to generate the physical plan. - override protected[sql] lazy val sideEffectResult: Seq[String] = try { + override protected[sql] lazy val sideEffectResult: Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. 
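    // Illustrative note (not part of this patch): every line of the formatted plan becomes one
    // single-column Row, so, assuming a registered table `src`, something like
    //   sql("EXPLAIN SELECT * FROM src").collect().foreach(r => println(r.getString(0)))
    // prints a plan that was computed locally, without launching a distributed job.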
val queryExecution = context.executePlan(logicalPlan) val outputString = if (extended) queryExecution.toString else queryExecution.simpleString - outputString.split("\n") + outputString.split("\n").map(Row(_)) } catch { case cause: TreeNodeException[_] => - ("Error occurred during query planning: \n" + cause.getMessage).split("\n") - } - - def execute(): RDD[Row] = { - val explanation = sideEffectResult.map(row => new GenericRow(Array[Any](row))) - context.sparkContext.parallelize(explanation, 1) + ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) } override def otherCopyArgs = context :: Nil @@ -144,12 +139,7 @@ case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: } else { context.uncacheTable(tableName) } - Seq.empty[Any] - } - - override def execute(): RDD[Row] = { - sideEffectResult - context.emptyResult + Seq.empty[Row] } override def output: Seq[Attribute] = Seq.empty @@ -163,15 +153,8 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { - Seq(("# Registered as a temporary table", null, null)) ++ - child.output.map(field => (field.name, field.dataType.toString, null)) - } - - override def execute(): RDD[Row] = { - val rows = sideEffectResult.map { - case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) - } - context.sparkContext.parallelize(rows, 1) + override protected[sql] lazy val sideEffectResult: Seq[Row] = { + Row("# Registered as a temporary table", null, null) +: + child.output.map(field => Row(field.name, field.dataType.toString, null)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d9b2bc7348ad2..ced8397972fbd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -389,7 +389,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -409,7 +409,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // be similar with Hive. 
describeHiveTableCommand.hiveString case command: PhysicalCommand => - command.sideEffectResult.map(_.toString) + command.sideEffectResult.map(_.head.toString) case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 47e24f0dec146..24abb1b5bd1a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,17 +18,19 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.Experimental -import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.parquet.{ParquetRelation, ParquetTableScan} +import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} +import org.apache.spark.sql.hive +import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.{SQLContext, SchemaRDD} import scala.collection.JavaConversions._ @@ -196,9 +198,9 @@ private[hive] trait HiveStrategies { case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil - case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case hive.DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil - case AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil + case hive.AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index a40e89e0d382b..317801001c7a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{Command, LeafNode} import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} @@ -41,26 +41,21 @@ case class DescribeHiveTableCommand( extends LeafNode with Command { // Strings with the format like Hive. It is used for result comparison in our unit tests. 
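  // Illustrative example (not part of this patch): Row("key", "string", null) is rendered as
  //   String.format("%-20s", "key") + "\t" + String.format("%-20s", "string") + "\t" + String.format("%-20s", "None")
  // i.e. three left-justified 20-character columns joined by tabs, with a null comment shown as "None".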
- lazy val hiveString: Seq[String] = { - val alignment = 20 - val delim = "\t" - - sideEffectResult.map { - case (name, dataType, comment) => - String.format("%-" + alignment + "s", name) + delim + - String.format("%-" + alignment + "s", dataType) + delim + - String.format("%-" + alignment + "s", Option(comment).getOrElse("None")) - } + lazy val hiveString: Seq[String] = sideEffectResult.map { + case Row(name: String, dataType: String, comment) => + Seq(name, dataType, Option(comment.asInstanceOf[String]).getOrElse("None")) + .map(s => String.format(s"%-20s", s)) + .mkString("\t") } - override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil val columns: Seq[FieldSchema] = table.hiveQlTable.getCols val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols results ++= columns.map(field => (field.getName, field.getType, field.getComment)) - if (!partitionColumns.isEmpty) { + if (partitionColumns.nonEmpty) { val partColumnInfo = partitionColumns.map(field => (field.getName, field.getType, field.getComment)) results ++= @@ -74,14 +69,9 @@ case class DescribeHiveTableCommand( results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) } - results - } - - override def execute(): RDD[Row] = { - val rows = sideEffectResult.map { - case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + results.map { case (name, dataType, comment) => + Row(name, dataType, comment) } - context.sparkContext.parallelize(rows, 1) } override def otherCopyArgs = context :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala index fe6031678f70f..8f10e1ba7f426 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala @@ -32,16 +32,7 @@ case class NativeCommand( @transient context: HiveContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) - - override def execute(): RDD[Row] = { - if (sideEffectResult.size == 0) { - context.emptyResult - } else { - val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) - context.sparkContext.parallelize(rows, 1) - } - } + override protected[sql] lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_)) override def otherCopyArgs = context :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 2985169da033c..a1a4aa7de7bf7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -33,19 +33,13 @@ import org.apache.spark.sql.hive.HiveContext */ @DeveloperApi case class AnalyzeTable(tableName: String) extends LeafNode with Command { - def hiveContext = sqlContext.asInstanceOf[HiveContext] def output = Seq.empty - override protected[sql] lazy val sideEffectResult = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { hiveContext.analyze(tableName) - Seq.empty[Any] - } - - override def execute(): RDD[Row] = { - 
sideEffectResult - sparkContext.emptyRDD[Row] + Seq.empty[Row] } } @@ -55,20 +49,14 @@ case class AnalyzeTable(tableName: String) extends LeafNode with Command { */ @DeveloperApi case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with Command { - def hiveContext = sqlContext.asInstanceOf[HiveContext] def output = Seq.empty - override protected[sql] lazy val sideEffectResult: Seq[Any] = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { val ifExistsClause = if (ifExists) "IF EXISTS " else "" hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") hiveContext.catalog.unregisterTable(None, tableName) - Seq.empty - } - - override def execute(): RDD[Row] = { - sideEffectResult - sparkContext.emptyRDD[Row] + Seq.empty[Row] } } From 248067adbe90f93c7d5e23aa61b3072dfdf48a8a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 3 Sep 2014 18:59:26 -0700 Subject: [PATCH 2/5] [SPARK-2961][SQL] Use statistics to prune batches within cached partitions This PR is based on #1883 authored by marmbrus. Key differences: 1. Batch pruning instead of partition pruning When #1883 was authored, batched column buffer building (#1880) hadn't been introduced. This PR combines these two and provide partition batch level pruning, which leads to smaller memory footprints and can generally skip more elements. The cost is that the pruning predicates are evaluated more frequently (partition number multiplies batch number per partition). 1. More filters are supported Filter predicates consist of `=`, `<`, `<=`, `>`, `>=` and their conjunctions and disjunctions are supported. Author: Cheng Lian Closes #2188 from liancheng/in-mem-batch-pruning and squashes the following commits: 68cf019 [Cheng Lian] Marked sqlContext as @transient 4254f6c [Cheng Lian] Enables in-memory partition pruning in PartitionBatchPruningSuite 3784105 [Cheng Lian] Overrides InMemoryColumnarTableScan.sqlContext d2a1d66 [Cheng Lian] Disables in-memory partition pruning by default 062c315 [Cheng Lian] HiveCompatibilitySuite code cleanup 16b77bf [Cheng Lian] Fixed pruning predication conjunctions and disjunctions 16195c5 [Cheng Lian] Enabled both disjunction and conjunction 89950d0 [Cheng Lian] Worked around Scala style check 9c167f6 [Cheng Lian] Minor code cleanup 3c4d5c7 [Cheng Lian] Minor code cleanup ea59ee5 [Cheng Lian] Renamed PartitionSkippingSuite to PartitionBatchPruningSuite fc517d0 [Cheng Lian] More test cases 1868c18 [Cheng Lian] Code cleanup, bugfix, and adding tests cb76da4 [Cheng Lian] Added more predicate filters, fixed table scan stats for testing purposes 385474a [Cheng Lian] Merge branch 'inMemStats' into in-mem-batch-pruning --- .../catalyst/expressions/AttributeMap.scala | 41 ++ .../catalyst/expressions/BoundAttribute.scala | 12 +- .../scala/org/apache/spark/sql/SQLConf.scala | 7 + .../spark/sql/columnar/ColumnBuilder.scala | 10 +- .../spark/sql/columnar/ColumnStats.scala | 434 +++++------------- .../columnar/InMemoryColumnarTableScan.scala | 131 +++++- .../sql/columnar/NullableColumnBuilder.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 39 +- .../columnar/NullableColumnBuilderSuite.scala | 2 +- .../columnar/PartitionBatchPruningSuite.scala | 95 ++++ .../compression/BooleanBitSetSuite.scala | 4 +- .../compression/DictionaryEncodingSuite.scala | 2 +- .../compression/IntegralDeltaSuite.scala | 2 +- .../compression/RunLengthEncodingSuite.scala | 4 +- .../TestCompressibleColumnBuilder.scala | 4 +- 
.../execution/HiveCompatibilitySuite.scala | 13 +- 17 files changed, 446 insertions(+), 359 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala new file mode 100644 index 0000000000000..8364379644c90 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +/** + * Builds a map that is keyed by an Attribute's expression id. Using the expression id allows values + * to be looked up even when the attributes used differ cosmetically (i.e., the capitalization + * of the name, or the expected nullability). 
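 *
 * As an illustrative sketch (not part of this patch), a lookup succeeds even through a
 * cosmetically different reference, as long as the expression id matches:
 *
 * {{{
 *   val i = AttributeReference("i", IntegerType)()
 *   val map = AttributeMap(Seq(i -> "stats for i"))
 *   map.get(AttributeReference("I", IntegerType)(i.exprId))  // Some("stats for i")
 * }}}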
+ */ +object AttributeMap { + def apply[A](kvs: Seq[(Attribute, A)]) = + new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap) +} + +class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) + extends Map[Attribute, A] with Serializable { + + override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) + + override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = + (baseMap.map(_._2) + kv).toMap + + override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator + + override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 54c6baf1af3bf..fa80b07f8e6be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -38,12 +38,20 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } object BindReferences extends Logging { - def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { + + def bindReference[A <: Expression]( + expression: A, + input: Seq[Attribute], + allowFailures: Boolean = false): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) if (ordinal == -1) { - sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + if (allowFailures) { + a + } else { + sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + } } else { BoundReference(ordinal, a.dataType, a.nullable) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 64d49354dadcd..4137ac7663739 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -26,6 +26,7 @@ import java.util.Properties private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" + val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" @@ -124,6 +125,12 @@ trait SQLConf { private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** + * When set to true, partition pruning for in-memory columnar tables is enabled. + */ + private[spark] def inMemoryPartitionPruning: Boolean = + getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 247337a875c75..b3ec5ded22422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -38,7 +38,7 @@ private[sql] trait ColumnBuilder { /** * Column statistics information */ - def columnStats: ColumnStats[_, _] + def columnStats: ColumnStats /** * Returns the final columnar byte buffer. @@ -47,7 +47,7 @@ private[sql] trait ColumnBuilder { } private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( - val columnStats: ColumnStats[T, JvmType], + val columnStats: ColumnStats, val columnType: ColumnType[T, JvmType]) extends ColumnBuilder { @@ -81,18 +81,18 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder private[sql] abstract class NativeColumnBuilder[T <: NativeType]( - override val columnStats: NativeColumnStats[T], + override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType) with NullableColumnBuilder with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) +private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN) private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 6502110e903fe..fc343ccb995c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,381 +17,193 @@ package org.apache.spark.sql.columnar +import java.sql.Timestamp + import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.types._ +private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { + val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)() + val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)() + val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() + + val schema = Seq(lowerBound, upperBound, nullCount) +} + +private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { + val (forAttribute, schema) = { + val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) + (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) + } +} + /** * Used to collect statistical information when building in-memory columns. * * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. 
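 *
 * As an illustrative sketch (not part of this patch), a column builder feeds each appended value
 * into its stats object and reads the result back as a single statistics row:
 *
 * {{{
 *   val stats = new IntColumnStats
 *   Seq(Row(3), Row(1), Row(7)).foreach(stats.gatherStats(_, 0))
 *   stats.collectedStatistics  // Row(1, 7, 0): lower bound, upper bound, null count
 * }}}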
*/ -private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable { - /** - * Closed lower bound of this column. - */ - def lowerBound: JvmType - - /** - * Closed upper bound of this column. - */ - def upperBound: JvmType - +private[sql] sealed trait ColumnStats extends Serializable { /** * Gathers statistics information from `row(ordinal)`. */ - def gatherStats(row: Row, ordinal: Int) - - /** - * Returns `true` if `lower <= row(ordinal) <= upper`. - */ - def contains(row: Row, ordinal: Int): Boolean + def gatherStats(row: Row, ordinal: Int): Unit /** - * Returns `true` if `row(ordinal) < upper` holds. + * Column statistics represented as a single row, currently including closed lower bound, closed + * upper bound and null count. */ - def isAbove(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `lower < row(ordinal)` holds. - */ - def isBelow(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `row(ordinal) <= upper` holds. - */ - def isAtOrAbove(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `lower <= row(ordinal)` holds. - */ - def isAtOrBelow(row: Row, ordinal: Int): Boolean -} - -private[sql] sealed abstract class NativeColumnStats[T <: NativeType] - extends ColumnStats[T, T#JvmType] { - - type JvmType = T#JvmType - - protected var (_lower, _upper) = initialBounds - - def initialBounds: (JvmType, JvmType) - - protected def columnType: NativeColumnType[T] - - override def lowerBound: T#JvmType = _lower - - override def upperBound: T#JvmType = _upper - - override def isAtOrAbove(row: Row, ordinal: Int) = { - contains(row, ordinal) || isAbove(row, ordinal) - } - - override def isAtOrBelow(row: Row, ordinal: Int) = { - contains(row, ordinal) || isBelow(row, ordinal) - } + def collectedStatistics: Row } -private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] { - override def isAtOrBelow(row: Row, ordinal: Int) = true - - override def isAtOrAbove(row: Row, ordinal: Int) = true - - override def isBelow(row: Row, ordinal: Int) = true - - override def isAbove(row: Row, ordinal: Int) = true +private[sql] class NoopColumnStats extends ColumnStats { - override def contains(row: Row, ordinal: Int) = true + override def gatherStats(row: Row, ordinal: Int): Unit = {} - override def gatherStats(row: Row, ordinal: Int) {} - - override def upperBound = null.asInstanceOf[JvmType] - - override def lowerBound = null.asInstanceOf[JvmType] + override def collectedStatistics = Row() } -private[sql] abstract class BasicColumnStats[T <: NativeType]( - protected val columnType: NativeColumnType[T]) - extends NativeColumnStats[T] - -private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) { - override def initialBounds = (true, false) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class ByteColumnStats extends ColumnStats { + var upper = Byte.MinValue + var lower = Byte.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) { - override def 
initialBounds = (Byte.MaxValue, Byte.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound + if (!row.isNullAt(ordinal)) { + val value = row.getByte(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) { - override def initialBounds = (Short.MaxValue, Short.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class ShortColumnStats extends ColumnStats { + var upper = Short.MinValue + var lower = Short.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class LongColumnStats extends BasicColumnStats(LONG) { - override def initialBounds = (Long.MaxValue, Long.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound + if (!row.isNullAt(ordinal)) { + val value = row.getShort(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) { - override def initialBounds = (Double.MaxValue, Double.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class LongColumnStats extends ColumnStats { + var upper = Long.MinValue + var lower = Long.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) { - override def initialBounds = (Float.MaxValue, Float.MinValue) - - override def isBelow(row: Row, 
ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row.getLong(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } + def collectedStatistics = Row(lower, upper, nullCount) +} - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class DoubleColumnStats extends ColumnStats { + var upper = Double.MinValue + var lower = Double.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field + if (!row.isNullAt(ordinal)) { + val value = row.getDouble(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } -} -private[sql] object IntColumnStats { - val UNINITIALIZED = 0 - val INITIALIZED = 1 - val ASCENDING = 2 - val DESCENDING = 3 - val UNORDERED = 4 + def collectedStatistics = Row(lower, upper, nullCount) } -/** - * Statistical information for `Int` columns. More information is collected since `Int` is - * frequently used. Extra information include: - * - * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search - * is applicable when searching elements. - * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression - * scheme. - * - * (This two kinds of information are not used anywhere yet and might be removed later.) - */ -private[sql] class IntColumnStats extends BasicColumnStats(INT) { - import IntColumnStats._ - - private var orderedState = UNINITIALIZED - private var lastValue: Int = _ - private var _maxDelta: Int = _ - - def isAscending = orderedState != DESCENDING && orderedState != UNORDERED - def isDescending = orderedState != ASCENDING && orderedState != UNORDERED - def isOrdered = isAscending || isDescending - def maxDelta = _maxDelta - - override def initialBounds = (Int.MaxValue, Int.MinValue) +private[sql] class FloatColumnStats extends ColumnStats { + var upper = Float.MinValue + var lower = Float.MaxValue + var nullCount = 0 - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) + override def gatherStats(row: Row, ordinal: Int) { + if (!row.isNullAt(ordinal)) { + val value = row.getFloat(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } + def collectedStatistics = Row(lower, upper, nullCount) +} - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class IntColumnStats extends ColumnStats { + var upper = Int.MinValue + var lower = Int.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - - orderedState = orderedState match { - case UNINITIALIZED => - lastValue = field - INITIALIZED - - case INITIALIZED => - // If all the integers in the column are the same, ordered state is set to Ascending. 
- // TODO (lian) Confirm whether this is the standard behaviour. - val nextState = if (field >= lastValue) ASCENDING else DESCENDING - _maxDelta = math.abs(field - lastValue) - lastValue = field - nextState - - case ASCENDING if field < lastValue => - UNORDERED - - case DESCENDING if field > lastValue => - UNORDERED - - case state @ (ASCENDING | DESCENDING) => - _maxDelta = _maxDelta.max(field - lastValue) - lastValue = field - state - - case _ => - orderedState + if (!row.isNullAt(ordinal)) { + val value = row.getInt(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 } } + + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class StringColumnStats extends BasicColumnStats(STRING) { - override def initialBounds = (null, null) +private[sql] class StringColumnStats extends ColumnStats { + var upper: String = null + var lower: String = null + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field - if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field - } - - override def contains(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 - } - } - - override def isAbove(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - field.compareTo(upperBound) < 0 + if (!row.isNullAt(ordinal)) { + val value = row.getString(ordinal) + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 } } - override def isBelow(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) < 0 - } - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class TimestampColumnStats extends BasicColumnStats(TIMESTAMP) { - override def initialBounds = (null, null) +private[sql] class TimestampColumnStats extends ColumnStats { + var upper: Timestamp = null + var lower: Timestamp = null + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field - if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field - } - - override def contains(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Timestamp] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 } } - override def isAbove(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - field.compareTo(upperBound) < 0 - } - } - - override def isBelow(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) < 0 - } - } + def collectedStatistics = Row(lower, upper, nullCount) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index cb055cd74a5e5..dc668e7dc934c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{LeafNode, SparkPlan} @@ -31,23 +33,27 @@ object InMemoryRelation { new InMemoryRelation(child.output, useCompression, batchSize, child)() } +private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) + private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, batchSize: Int, child: SparkPlan) - (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) + (private var _cachedColumnBuffers: RDD[CachedBatch] = null) extends LogicalPlan with MultiInstanceRelation { override lazy val statistics = Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes) + val partitionStatistics = new PartitionStatistics(output) + // If the cached column buffers were not passed in, we calculate them in the constructor. // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output val cached = child.execute().mapPartitions { baseIterator => - new Iterator[Array[ByteBuffer]] { + new Iterator[CachedBatch] { def next() = { val columnBuilders = output.map { attribute => val columnType = ColumnType(attribute.dataType) @@ -68,7 +74,10 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - columnBuilders.map(_.build()) + val stats = Row.fromSeq( + columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) + + CachedBatch(columnBuilders.map(_.build()), stats) } def hasNext = baseIterator.hasNext @@ -79,7 +88,6 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers = cached } - override def children = Seq.empty override def newInstance() = { @@ -96,13 +104,98 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], + predicates: Seq[Expression], relation: InMemoryRelation) extends LeafNode { + @transient override val sqlContext = relation.child.sqlContext + override def output: Seq[Attribute] = attributes + // Returned filter predicate should return false iff it is impossible for the input expression + // to evaluate to `true' based on statistics collected about this partition batch. 
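  // Illustrative example (not part of this patch): for a query predicate `i < 12` the generated
  // batch filter is `i.lowerBound < 12`, so a batch whose collected lower bound is 20 is skipped
  // without being decompressed; `i = 5` keeps a batch only when
  // `i.lowerBound <= 5 && 5 <= i.upperBound` holds for that batch's statistics row.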
+ val buildFilter: PartialFunction[Expression, Expression] = { + case And(lhs: Expression, rhs: Expression) + if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => + buildFilter(lhs) && buildFilter(rhs) + + case Or(lhs: Expression, rhs: Expression) + if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => + buildFilter(lhs) || buildFilter(rhs) + + case EqualTo(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l && l <= aStats.upperBound + + case EqualTo(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l && l <= aStats.upperBound + + case LessThan(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound < l + + case LessThan(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + l < aStats.upperBound + + case LessThanOrEqual(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l + + case LessThanOrEqual(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + l <= aStats.upperBound + + case GreaterThan(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + l < aStats.upperBound + + case GreaterThan(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound < l + + case GreaterThanOrEqual(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + l <= aStats.upperBound + + case GreaterThanOrEqual(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l + } + + val partitionFilters = { + predicates.flatMap { p => + val filter = buildFilter.lift(p) + val boundFilter = + filter.map( + BindReferences.bindReference( + _, + relation.partitionStatistics.schema, + allowFailures = true)) + + boundFilter.foreach(_ => + filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f"))) + + // If the filter can't be resolved then we are missing required statistics. + boundFilter.filter(_.resolved) + } + } + + val readPartitions = sparkContext.accumulator(0) + val readBatches = sparkContext.accumulator(0) + + private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning + override def execute() = { + readPartitions.setValue(0) + readBatches.setValue(0) + relation.cachedColumnBuffers.mapPartitions { iterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + relation.partitionStatistics.schema) + // Find the ordinals of the requested columns. If none are requested, use the first. 
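      // Illustrative example (not part of this patch): with relation output [a, b, c] and a
      // projection of only `b`, requestedColumns is Seq(1), so just column buffer 1 is read
      // from each batch that survives the filter below.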
val requestedColumns = if (attributes.isEmpty) { Seq(0) @@ -110,8 +203,26 @@ private[sql] case class InMemoryColumnarTableScan( attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) } - iterator - .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_))) + val rows = iterator + // Skip pruned batches + .filter { cachedBatch => + if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) { + def statsString = relation.partitionStatistics.schema + .zip(cachedBatch.stats) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") + logInfo(s"Skipping partition based on stats $statsString") + false + } else { + readBatches += 1 + true + } + } + // Build column accessors + .map { cachedBatch => + requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + } + // Extract rows via column accessors .flatMap { columnAccessors => val nextRow = new GenericMutableRow(columnAccessors.length) new Iterator[Row] { @@ -127,6 +238,12 @@ private[sql] case class InMemoryColumnarTableScan( override def hasNext = columnAccessors.head.hasNext } } + + if (rows.hasNext) { + readPartitions += 1 + } + + rows } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index f631ee76fcd78..a72970eef7aa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -49,6 +49,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { } abstract override def appendFrom(row: Row, ordinal: Int) { + columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) nulls.putInt(pos) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8dacb84c8a17e..7943d6e1b6fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -243,8 +243,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { pruneFilterProject( projectList, filters, - identity[Seq[Expression]], // No filters are pushed down. - InMemoryColumnarTableScan(_, mem)) :: Nil + identity[Seq[Expression]], // All filters still need to be evaluated. 
+ InMemoryColumnarTableScan(_, filters, mem)) :: Nil case _ => Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 5f61fb5e16ea3..cde91ceb68c98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,29 +19,30 @@ package org.apache.spark.sql.columnar import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.types._ class ColumnStatsSuite extends FunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN) - testColumnStats(classOf[ByteColumnStats], BYTE) - testColumnStats(classOf[ShortColumnStats], SHORT) - testColumnStats(classOf[IntColumnStats], INT) - testColumnStats(classOf[LongColumnStats], LONG) - testColumnStats(classOf[FloatColumnStats], FLOAT) - testColumnStats(classOf[DoubleColumnStats], DOUBLE) - testColumnStats(classOf[StringColumnStats], STRING) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP) - - def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]]( + testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) + + def testColumnStats[T <: NativeType, U <: ColumnStats]( columnStatsClass: Class[U], - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + initialStatistics: Row) { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - assertResult(columnStats.initialBounds, "Wrong initial bounds") { - (columnStats.lowerBound, columnStats.upperBound) + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => + assert(actual === expected) } } @@ -49,14 +50,16 @@ class ColumnStatsSuite extends FunSuite { import ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() - val rows = Seq.fill(10)(makeRandomRow(columnType)) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.map(_.head.asInstanceOf[T#JvmType]) + val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound) - assertResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound) + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index dc813fe146c47..a77262534a352 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkSqlSerializer class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala new file mode 100644 index 0000000000000..5d2fd4959197c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.columnar + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql._ +import org.apache.spark.sql.test.TestSQLContext._ + +case class IntegerData(i: Int) + +class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { + val originalColumnBatchSize = columnBatchSize + val originalInMemoryPartitionPruning = inMemoryPartitionPruning + + override protected def beforeAll() { + // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch + setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) + rawData.registerTempTable("intData") + + // Enable in-memory partition pruning + setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + } + + override protected def afterAll() { + setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + } + + before { + cacheTable("intData") + } + + after { + uncacheTable("intData") + } + + // Comparisons + checkBatchPruning("i = 1", Seq(1), 1, 1) + checkBatchPruning("1 = i", Seq(1), 1, 1) + checkBatchPruning("i < 12", 1 to 11, 1, 2) + checkBatchPruning("i <= 11", 1 to 11, 1, 2) + checkBatchPruning("i > 88", 89 to 100, 1, 2) + checkBatchPruning("i >= 89", 89 to 100, 1, 2) + checkBatchPruning("12 > i", 1 to 11, 1, 2) + checkBatchPruning("11 >= i", 1 to 11, 1, 2) + checkBatchPruning("88 < i", 89 to 100, 1, 2) + checkBatchPruning("89 <= i", 89 to 100, 1, 2) + + // Conjunction and disjunction + checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) + checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) + checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) + + // With unsupported predicate + checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) + checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10) + + def checkBatchPruning( + filter: String, + expectedQueryResult: Seq[Int], + expectedReadPartitions: Int, + expectedReadBatches: Int) { + + test(filter) { + val query = sql(s"SELECT * FROM intData WHERE $filter") + assertResult(expectedQueryResult.toArray, "Wrong query result") { + query.collect().map(_.head).toArray + } + + val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect { + case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) + }.head + + assert(readBatches === expectedReadBatches, "Wrong number of read batches") + assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 5fba00480967c..e01cc8b4d20f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.Row -import org.apache.spark.sql.columnar.{BOOLEAN, BooleanColumnStats} +import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ class BooleanBitSetSuite extends FunSuite { @@ -31,7 +31,7 @@ class BooleanBitSetSuite extends FunSuite { // Tests encoder // ------------- - val 
builder = TestCompressibleColumnBuilder(new BooleanColumnStats, BOOLEAN, BooleanBitSet) + val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) val values = rows.map(_.head) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index d8ae2a26778c9..d2969d906c943 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -31,7 +31,7 @@ class DictionaryEncodingSuite extends FunSuite { testDictionaryEncoding(new StringColumnStats, STRING) def testDictionaryEncoding[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T]) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 17619dcf974e3..322f447c24840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -29,7 +29,7 @@ class IntegralDeltaSuite extends FunSuite { testIntegralDelta(new LongColumnStats, LONG, LongDelta) def testIntegralDelta[I <: IntegralType]( - columnStats: NativeColumnStats[I], + columnStats: ColumnStats, columnType: NativeColumnType[I], scheme: IntegralDelta[I]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 40115beb98899..218c09ac26362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ class RunLengthEncodingSuite extends FunSuite { - testRunLengthEncoding(new BooleanColumnStats, BOOLEAN) + testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) testRunLengthEncoding(new IntColumnStats, INT) @@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new StringColumnStats, STRING) def testRunLengthEncoding[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T]) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 72c19fa31d980..7db723d648d80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ 
class TestCompressibleColumnBuilder[T <: NativeType]( - override val columnStats: NativeColumnStats[T], + override val columnStats: ColumnStats, override val columnType: NativeColumnType[T], override val schemes: Seq[CompressionScheme]) extends NativeColumnBuilder(columnStats, columnType) @@ -33,7 +33,7 @@ class TestCompressibleColumnBuilder[T <: NativeType]( object TestCompressibleColumnBuilder { def apply[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T], scheme: CompressionScheme) = { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b589994bd25fa..ab487d673e813 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -35,26 +35,29 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - private val originalUseCompression = TestHive.useCompression + private val originalColumnBatchSize = TestHive.columnBatchSize + private val originalInMemoryPartitionPruning = TestHive.inMemoryPartitionPruning def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) override def beforeAll() { - // Enable in-memory columnar caching TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - // Enable in-memory columnar compression - TestHive.setConf(SQLConf.COMPRESS_CACHED, "true") + // Set a relatively small column batch size for testing purposes + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5") + // Enable in-memory partition pruning for testing purposes + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.setConf(SQLConf.COMPRESS_CACHED, originalUseCompression.toString) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ From c5cbc49233193836b321cb6b77ce69dae798570b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Sep 2014 19:08:39 -0700 Subject: [PATCH 3/5] [SPARK-3335] [SQL] [PySpark] support broadcast in Python UDF After this patch, broadcast can be used in Python UDF. 
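As a rough usage sketch of what this patch enables (not part of the patch itself; it mirrors the test_broadcast_in_udf case added below, and the names sc, sqlCtx, and the lookup dict are illustrative):

    # Illustrative sketch, assuming Spark 1.1-era PySpark APIs; mirrors the
    # test_broadcast_in_udf test added in this patch.
    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext(appName="BroadcastUdfSketch")
    sqlCtx = SQLContext(sc)

    # The broadcast value is shipped to the executors once and read inside the
    # UDF via .value, instead of being pickled into every UDF invocation.
    lookup = sc.broadcast({"a": "aa", "b": "bb", "c": "abc"})
    sqlCtx.registerFunction("MYUDF", lambda x: lookup.value[x] if x else "")

    [row] = sqlCtx.sql("SELECT MYUDF('c')").collect()
    print(row[0])  # expected: abc

    sc.stop()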
Author: Davies Liu Closes #2243 from davies/udf_broadcast and squashes the following commits: 7b88861 [Davies Liu] support broadcast in UDF --- python/pyspark/sql.py | 17 +++++++------- python/pyspark/tests.py | 22 +++++++++++++++++++ .../apache/spark/sql/UdfRegistration.scala | 3 +++ .../spark/sql/execution/pythonUdfs.scala | 3 ++- 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 44316926ba334..aaa35dadc203e 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -942,9 +942,7 @@ def __init__(self, sparkContext, sqlContext=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray - - if sqlContext: - self._scala_SQLContext = sqlContext + self._scala_SQLContext = sqlContext @property def _ssql_ctx(self): @@ -953,7 +951,7 @@ def _ssql_ctx(self): Subclasses can override this property to provide their own JVM Contexts. """ - if not hasattr(self, '_scala_SQLContext'): + if self._scala_SQLContext is None: self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext @@ -970,23 +968,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] - >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - [Row(c0=5)] """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) + pickled_command = CloudPickleSerializer().dumps(command) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self._sc._pickled_broadcast_vars], + self._sc._gateway._gateway_client) + self._sc._pickled_broadcast_vars.clear() env = MapConverter().convert(self._sc.environment, self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, self._sc._gateway._gateway_client) self._ssql_ctx.registerPython(name, - bytearray(CloudPickleSerializer().dumps(command)), + bytearray(pickled_command), env, includes, self._sc.pythonExec, + broadcast_vars, self._sc._javaAccumulator, str(returnType)) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f1a75cbff5c19..3e74799e82845 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -43,6 +43,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.sql import SQLContext, IntegerType _have_scipy = False _have_numpy = False @@ -525,6 +526,27 @@ def test_histogram(self): self.assertRaises(TypeError, lambda: rdd.histogram(2)) +class TestSQL(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.sqlCtx = SQLContext(self.sc) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = 
self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + + class TestIO(PySparkTestCase): def test_stdout_redirection(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 0b48e9e659faa..0ea1105f082a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.{List => JList, Map => JMap} import org.apache.spark.Accumulator +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} import org.apache.spark.sql.execution.PythonUDF @@ -38,6 +39,7 @@ protected[sql] trait UDFRegistration { envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]], stringDataType: String): Unit = { log.debug( @@ -61,6 +63,7 @@ protected[sql] trait UDFRegistration { envVars, pythonIncludes, pythonExec, + broadcastVars, accumulator, dataType, e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 3dc8be2456781..0977da3e8577c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -42,6 +42,7 @@ private[spark] case class PythonUDF( envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, children: Seq[Expression]) extends Expression with SparkLogging { @@ -145,7 +146,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: udf.pythonIncludes, false, udf.pythonExec, - Seq[Broadcast[Array[Byte]]](), + udf.broadcastVars, udf.accumulator ).mapPartitions { iter => val pickle = new Unpickler From 7c6e71f05f4f5e0cd2d038ee81d1cda4a3e5cb39 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 3 Sep 2014 19:37:37 -0700 Subject: [PATCH 4/5] [SPARK-2435] Add shutdown hook to pyspark Author: Matthew Farrellee Closes #2183 from mattf/SPARK-2435 and squashes the following commits: ee0ee99 [Matthew Farrellee] [SPARK-2435] Add shutdown hook to pyspark --- python/pyspark/shell.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index e1e7cd954189f..fde3c29e5e790 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -28,6 +28,7 @@ sys.exit(1) +import atexit import os import platform import pyspark @@ -42,6 +43,7 @@ SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) sc = SparkContext(appName="PySparkShell", pyFiles=add_files) +atexit.register(lambda: sc.stop()) print("""Welcome to ____ __ From 1bed0a3869a526241381d2a74ba064e5b3721336 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 3 Sep 2014 20:47:00 -0700 Subject: [PATCH 5/5] [SPARK-3372] [MLlib] MLlib doesn't pass maven build / checkstyle due to multi-byte character contained in Gradient.scala Author: Kousuke Saruta Closes #2248 from sarutak/SPARK-3372 and squashes the following commits: 73a28b8 [Kousuke Saruta] Replaced UTF-8 hyphen with ascii hyphen --- 
.../scala/org/apache/spark/mllib/optimization/Gradient.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index fdd67160114ca..45dbf6044fcc5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -128,7 +128,7 @@ class LeastSquaresGradient extends Gradient { class HingeGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { val dotProduct = dot(data, weights) - // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { @@ -146,7 +146,7 @@ class HingeGradient extends Gradient { weights: Vector, cumGradient: Vector): Double = { val dotProduct = dot(data, weights) - // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) {