From f3b9b1017a9b93dbf07ccbfee3f140eab51ae1ea Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 19 Oct 2016 21:30:02 -0700 Subject: [PATCH 01/18] add test cases --- .../spark/sql/DataFrameFunctionsSuite.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 45db61515e9b6..a5a90fd831fed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -406,4 +406,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(true), Row(true)) ) } + + test("SPARK-14393: monotonically_increasing_id shouldn't change after coalesce") { + val df = spark.range(0, 4, 1, 4).withColumn("long_id", monotonically_increasing_id()) + val rows = df.collect() + val rowsAfterCoalesce = df.coalesce(2).collect() + assert(rows === rowsAfterCoalesce) + } + + test("SPARK-14393: monotonically_increasing_id shouldn't change after union") { + val df1 = spark.range(0, 2, 1, 2).withColumn("long_id", monotonically_increasing_id()) + val rows1 = df1.collect() + val df2 = spark.range(2, 4, 1, 2).withColumn("long_id", monotonically_increasing_id()) + val rows2 = df2.collect() + val rowsAfterUnion = df1.union(df2).collect() + assert(rowsAfterUnion === rows1 ++ rows2) + } } From 1a668586e2fd0c53c7c17d388eba256d1601339e Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 19 Oct 2016 21:31:38 -0700 Subject: [PATCH 02/18] fix WholeStageCodegen --- .../sql/catalyst/expressions/MonotonicallyIncreasingID.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 5b4922e0cf2b7..1a7030a37b242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -70,7 +70,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis val partitionMaskTerm = ctx.freshName("partitionMask") ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, - s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;") + s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; From 06a39e11f7a47a3e027e20043ef4ec524e1468a8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 19 Oct 2016 22:09:15 -0700 Subject: [PATCH 03/18] add initializeStateForPartition to Projection --- .../spark/sql/catalyst/expressions/package.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 1510a4796683c..e23eda3b3c549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -64,7 +64,15 @@ package object expressions { * column of the new row. 
If the schema of the input row is specified, then the given expression * will be bound to that schema. */ - abstract class Projection extends (InternalRow => InternalRow) + abstract class Projection extends (InternalRow => InternalRow) { + + /** + * Initialize internal state given the current partition index. + * This is used by non-deterministic expressions to set the initial state + * The default implementation does nothing. + */ + def initializeStateForPartition(partitionIndex: Int): Unit = {} + } /** * Converts a [[InternalRow]] to another Row given a sequence of expression that define each From 7840c95f556a46cf5934e66c3f8d99b053227577 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 19 Oct 2016 22:15:01 -0700 Subject: [PATCH 04/18] add RDD.mapPartitionsWithIndexInternal --- .../main/scala/org/apache/spark/rdd/RDD.scala | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6dc334ceb52ea..f6e92569631dd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -787,14 +787,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. */ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. + */ private[spark] def mapPartitionsInternal[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { @@ -804,6 +816,7 @@ abstract class RDD[T: ClassTag]( preservesPartitioning) } + /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. 
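
Patches 02-04 above introduce the core of the fix for the behavior that
the [PATCH 01/18] tests pin down: monotonically_increasing_id() used to
read the partition id from TaskContext.getPartitionId() when it was
initialized, so re-evaluating it inside a coalesced or unioned task
produced values different from the first collect(). Threading the RDD
partition index through mapPartitionsWithIndexInternal into
Projection.initializeStateForPartition ties the ids to the partition each
row originally came from. The standalone Scala sketch below models the id
scheme to show why the source of the partition index matters; it is an
illustration only, not code from the patch series, and its names
(MonotonicIdModel, idsForPartition) are invented for the example.

    // Model of monotonically_increasing_id: the upper bits of the 64-bit
    // id hold the partition index, the lower 33 bits a per-partition count.
    object MonotonicIdModel {
      def idsForPartition(partitionIndex: Int, numRows: Int): Seq[Long] = {
        val partitionMask = partitionIndex.toLong << 33
        (0 until numRows).map(count => partitionMask + count)
      }

      def main(args: Array[String]): Unit = {
        // Four original partitions with one row each, as in the new tests.
        println((0 until 4).flatMap(idsForPartition(_, 1)))
        // Vector(0, 8589934592, 17179869184, 25769803776)

        // After coalesce(2) the *task* partition ids are 0 and 1, so an
        // expression that re-reads TaskContext.getPartitionId() at task
        // start effectively recomputes the ids as 2 partitions x 2 rows:
        println((0 until 2).flatMap(idsForPartition(_, 2)))
        // Vector(0, 1, 8589934592, 8589934593)
      }
    }

With the index instead captured per original RDD partition, the masks stay
0..3 after coalesce(2), the per-partition counts restart for each parent
partition, and the two collect() calls in the [PATCH 01/18] tests return
identical rows.
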
From ccd2fe70a7e0d3f7f9a1bf9ce54b1a00a544d5cc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 19 Oct 2016 22:39:20 -0700 Subject: [PATCH 05/18] fix issue without whole stage codegen --- .../expressions/codegen/GenerateUnsafeProjection.scala | 3 +++ .../org/apache/spark/sql/catalyst/expressions/package.scala | 6 +++--- .../apache/spark/sql/execution/basicPhysicalOperators.scala | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7cc45372daa5a..5225becdb0e74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -377,6 +377,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro public SpecificUnsafeProjection(Object[] references) { this.references = references; + } + + public void initializeStatesForPartition(int partitionIndex) { ${ctx.initMutableStates()} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index e23eda3b3c549..cd1abc7119492 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -67,11 +67,11 @@ package object expressions { abstract class Projection extends (InternalRow => InternalRow) { /** - * Initialize internal state given the current partition index. - * This is used by non-deterministic expressions to set the initial state + * Initialize internal states given the current partition index. + * This is used by non-deterministic expressions to set initial states. * The default implementation does nothing. 
*/ - def initializeStateForPartition(partitionIndex: Int): Unit = {} + def initializeStatesForPartition(partitionIndex: Int): Unit = {} } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index dd78a784915d2..30ff60ec900c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -70,9 +70,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val project = UnsafeProjection.create(projectList, child.output, subexpressionEliminationEnabled) + project.initializeStatesForPartition(index) iter.map(project) } } From 1ca355e456da868eb2411c851b1e5b88779fd854 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 19 Oct 2016 23:26:13 -0700 Subject: [PATCH 06/18] Nondeterministic.setInitialValues => initializeStatesForPartition(partitionIndex) --- .../sql/catalyst/expressions/Expression.scala | 6 ++--- .../catalyst/expressions/InputFileName.scala | 2 +- .../MonotonicallyIncreasingID.scala | 4 ++-- .../sql/catalyst/expressions/Projection.scala | 22 ++++++++++++------- .../expressions/SparkPartitionID.scala | 2 +- .../expressions/codegen/CodegenFallback.scala | 7 +----- .../sql/catalyst/expressions/predicates.scala | 4 ---- .../expressions/randomExpressions.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 2 +- .../CodegenExpressionCachingSuite.scala | 2 +- 10 files changed, 25 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index fa1a2ad56ccb3..f038a132d85a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -274,12 +274,12 @@ trait Nondeterministic extends Expression { private[this] var initialized = false - final def setInitialValues(): Unit = { - initInternal() + final def initializeStatesForPartition(partitionIndex: Int): Unit = { + initializeStatesForPartitionInternal(partitionIndex) initialized = true } - protected def initInternal(): Unit + protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit final override def eval(input: InternalRow = null): Any = { require(initialized, "nondeterministic expression should be initialized before evaluate") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 96929ecf56375..5c5f2e5b5e806 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def prettyName: String = "input_file_name" - override protected def initInternal(): Unit = {} + override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = {} override protected def 
evalInternal(input: InternalRow): UTF8String = { InputFileNameHolder.getInputFileName() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 1a7030a37b242..6f0c1146a4ba4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis @transient private[this] var partitionMask: Long = _ - override protected def initInternal(): Unit = { + override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = { count = 0L - partitionMask = TaskContext.getPartitionId().toLong << 33 + partitionMask = partitionIndex.toLong << 33 } override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a81fa1ce3adcc..190b530138e7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. + * * @param expressions a sequence of expressions that determine the value of each column of the * output row. */ @@ -30,10 +31,12 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) - expressions.foreach(_.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - }) + override def initializeStatesForPartition(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initializeStatesForPartition(partitionIndex) + case _ => + }) + } // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { /** * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified * expressions. + * * @param expressions a sequence of expressions that determine the value of each column of the * output row. 
*/ @@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu private[this] val buffer = new Array[Any](expressions.size) - expressions.foreach(_.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - }) + override def initializeStatesForPartition(partitionIndex: Int): Unit = { + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initializeStatesForPartition(partitionIndex) + case _ => + }) + } private[this] val exprArray = expressions.toArray private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 1f675d5b07270..61c8b31df297d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -38,7 +38,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override val prettyName = "SPARK_PARTITION_ID" - override protected def initInternal(): Unit = { + override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = { partitionId = TaskContext.getPartitionId() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6a5a3e7933eea..eb2c3d92cfb19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} /** * A trait that can be used to provide a fallback mode for expression code generation. 
@@ -25,11 +25,6 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No trait CodegenFallback extends Expression { protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - } - // LeafNode does not need `input` val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW val idx = ctx.references.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 799858a6865e5..4717f3f9d752f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -31,10 +31,6 @@ object InterpretedPredicate { create(BindReferences.bindReference(expression, inputSchema)) def create(expression: Expression): (InternalRow => Boolean) = { - expression.foreach { - case n: Nondeterministic => n.setInitialValues() - case _ => - } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index ca200768b2286..d3a4a2dafe55e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -42,7 +42,7 @@ abstract class RDG extends LeafExpression with Nondeterministic { */ @transient protected var rng: XORShiftRandom = _ - override protected def initInternal(): Unit = { + override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = { rng = new XORShiftRandom(seed + TaskContext.getPartitionId) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index f0c149c02b9aa..a28d44cca5b53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -75,7 +75,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { - case n: Nondeterministic => n.setInitialValues() + case n: Nondeterministic => n.initializeStatesForPartition(0) case _ => } expression.eval(inputRow) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 06dc3bd33b90e..e2acdd6166038 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -89,7 +89,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { */ case class NondeterministicExpression() extends LeafExpression with Nondeterministic with CodegenFallback { - override protected def initInternal(): Unit = { } + 
override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = {} override protected def evalInternal(input: InternalRow): Any = false override def nullable: Boolean = false override def dataType: DataType = BooleanType From 9478fd6630c5ab31b9361c387292e739c4e78cbd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Oct 2016 00:28:31 -0700 Subject: [PATCH 07/18] test all code paths --- .../spark/sql/DataFrameFunctionsSuite.scala | 56 +++++++++++++++---- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a5a90fd831fed..ab50f77f10a19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,11 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -407,19 +411,47 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - test("SPARK-14393: monotonically_increasing_id shouldn't change after coalesce") { - val df = spark.range(0, 4, 1, 4).withColumn("long_id", monotonically_increasing_id()) - val rows = df.collect() - val rowsAfterCoalesce = df.coalesce(2).collect() - assert(rows === rowsAfterCoalesce) + + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { + import DataFrameFunctionsSuite.CodegenFallbackExpr + for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { + val c = if (codegenFallback) { + Column(CodegenFallbackExpr(v.expr)) + } else { + v + } + withSQLConf( + (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString), + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) { + val df = spark.range(0, 4, 1, 4).withColumn("c", c) + val rows = df.collect() + val rowsAfterCoalesce = df.coalesce(2).collect() + assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + + val df1 = spark.range(0, 2, 1, 2).withColumn("c", c) + val rows1 = df1.collect() + val df2 = spark.range(2, 4, 1, 2).withColumn("c", c) + val rows2 = df2.collect() + val rowsAfterUnion = df1.union(df2).collect() + assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " + + s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.") + } + } } - test("SPARK-14393: monotonically_increasing_id shouldn't change after union") { - val df1 = spark.range(0, 2, 1, 2).withColumn("long_id", monotonically_increasing_id()) - val rows1 = df1.collect() - val df2 = spark.range(2, 4, 1, 2).withColumn("long_id", monotonically_increasing_id()) - val rows2 = df2.collect() - val rowsAfterUnion = df1.union(df2).collect() - assert(rowsAfterUnion === rows1 ++ rows2) + test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + + "coalesce or union") { + assertValuesDoNotChangeAfterCoalesceOrUnion(monotonically_increasing_id()) + } +} + +object DataFrameFunctionsSuite { + case class CodegenFallbackExpr(child: Expression) 
extends Expression with CodegenFallback { + override def children: Seq[Expression] = Seq(child) + override def nullable: Boolean = child.nullable + override def dataType: DataType = child.dataType + override lazy val resolved = true + override def eval(input: InternalRow): Any = child.eval(input) } } From bc4ea2c355e324e247028ab26e17a258948bddbd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Oct 2016 01:56:31 -0700 Subject: [PATCH 08/18] fixed codegen fallback case --- .../expressions/MonotonicallyIncreasingID.scala | 7 ++++--- .../expressions/codegen/CodeGenerator.scala | 14 ++++++++++++++ .../expressions/codegen/CodegenFallback.scala | 15 ++++++++++++++- .../codegen/GenerateUnsafeProjection.scala | 3 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 1 - 5 files changed, 34 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 6f0c1146a4ba4..0a2d19b69a1b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") - ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, - s"$partitionMaskTerm = ((long) partitionIndex) << 33;") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") + ctx.addPartitionInitializationStatements(s"$countTerm = 0L;") + ctx.addPartitionInitializationStatements(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6cab50ae1bf8d..c5f05e2e66a7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -184,6 +184,20 @@ class CodegenContext { splitExpressions(initCodes, "init", Nil) } + /** + * Code statements to initialize states that depends on the partition index. + * An integer `partitionIndex` will be available within the scope. + */ + val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty + + def addPartitionInitializationStatements(statement: String): Unit = { + partitionInitializationStatements += statement + } + + def initPartition(): String = { + partitionInitializationStatements.mkString("\n") + } + /** * Holding all the functions those will be added into generated class. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index eb2c3d92cfb19..9248c9f0b98fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -29,6 +29,19 @@ trait CodegenFallback extends Expression { val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW val idx = ctx.references.length ctx.references += this + var childIndex = idx + this.foreach { + case n: Nondeterministic => + // This might add the current expression twice, but it won't hurt. + ctx.references += n + childIndex += 1 + ctx.addPartitionInitializationStatements( + s""" + |((Nondeterministic) references[$childIndex]) + | .initializeStatesForPartition(partitionIndex); + """.stripMargin) + case _ => + } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) if (nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5225becdb0e74..1adad01662431 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -377,10 +377,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro public SpecificUnsafeProjection(Object[] references) { this.references = references; + ${ctx.initMutableStates()} } public void initializeStatesForPartition(int partitionIndex) { - ${ctx.initMutableStates()} + ${ctx.initPartition()} } ${ctx.declareAddedFunctions()} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ab50f77f10a19..b913db658c334 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -411,7 +411,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } - private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 7ffe0ed88fc6e2bbd68b9fa807dacafbe80ed05b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Oct 2016 02:13:07 -0700 Subject: [PATCH 09/18] also initialize predicate --- .../codegen/GenerateMutableProjection.scala | 4 ++++ .../expressions/codegen/GeneratePredicate.scala | 11 +++++++++++ .../expressions/codegen/GenerateSafeProjection.scala | 4 ++++ .../spark/sql/execution/WholeStageCodegenExec.scala | 1 + 4 files changed, 20 insertions(+) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 5c4b56b0b224c..9ffcf5794094a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ${ctx.initMutableStates()} } + public void initializeStatesForPartition(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public ${classOf[BaseMutableProjection].getName} target(InternalRow row) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 39aa7b17de6c9..c7a37b10450f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -25,6 +25,13 @@ import org.apache.spark.sql.catalyst.expressions._ */ abstract class Predicate { def eval(r: InternalRow): Boolean + + /** + * Initialize internal states given the current partition index. + * This is used by non-deterministic expressions to set initial states. + * The default implementation does nothing. + */ + def initializeStatesForPartition(partitionIndex: Int): Unit = {} } /** @@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool ${ctx.initMutableStates()} } + public void initializeStatesForPartition(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public boolean eval(InternalRow ${ctx.INPUT_ROW}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 2773e1a666212..04482b5fd5c17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -173,6 +173,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ${ctx.initMutableStates()} } + public void initializeStatesForPartition(int partitionIndex) { + ${ctx.initPartition()} + } + ${ctx.declareAddedFunctions()} public java.lang.Object apply(java.lang.Object _i) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 62bf6f4a81eec..7992f3fbadb49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -327,6 +327,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co partitionIndex = index; this.inputs = inputs; ${ctx.initMutableStates()} + ${ctx.initPartition()} } ${ctx.declareAddedFunctions()} From 38dcb7aa81e94b0db81a9cc1be24e8e4524562a1 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng 
Date: Thu, 20 Oct 2016 02:21:00 -0700 Subject: [PATCH 10/18] fix other nondeterministic functions --- .../expressions/MonotonicallyIncreasingID.scala | 4 ++-- .../sql/catalyst/expressions/SparkPartitionID.scala | 7 +++---- .../catalyst/expressions/codegen/CodeGenerator.scala | 2 +- .../expressions/codegen/CodegenFallback.scala | 2 +- .../sql/catalyst/expressions/randomExpressions.scala | 12 +++++++----- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 0a2d19b69a1b3..2a9cfb6b55327 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -70,8 +70,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis val partitionMaskTerm = ctx.freshName("partitionMask") ctx.addMutableState(ctx.JAVA_LONG, countTerm, "") ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "") - ctx.addPartitionInitializationStatements(s"$countTerm = 0L;") - ctx.addPartitionInitializationStatements(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") + ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") + ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 61c8b31df297d..35d1c9744196c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} @@ -39,15 +38,15 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override val prettyName = "SPARK_PARTITION_ID" override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = { - partitionId = TaskContext.getPartitionId() + partitionId = partitionIndex } override protected def evalInternal(input: InternalRow): Int = partitionId override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") - ctx.addMutableState(ctx.JAVA_INT, idTerm, - s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") + ctx.addMutableState(ctx.JAVA_INT, idTerm, "") + ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c5f05e2e66a7e..bd1cdb6eba1fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -190,7 +190,7 @@ class CodegenContext { */ val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty - def addPartitionInitializationStatements(statement: String): Unit = { + def addPartitionInitializationStatement(statement: String): Unit = { partitionInitializationStatements += statement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 9248c9f0b98fa..3066a1570d3cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -35,7 +35,7 @@ trait CodegenFallback extends Expression { // This might add the current expression twice, but it won't hurt. ctx.references += n childIndex += 1 - ctx.addPartitionInitializationStatements( + ctx.addPartitionInitializationStatement( s""" |((Nondeterministic) references[$childIndex]) | .initializeStatesForPartition(partitionIndex); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index d3a4a2dafe55e..a649411edf562 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -43,7 +43,7 @@ abstract class RDG extends LeafExpression with Nondeterministic { @transient protected var rng: XORShiftRandom = _ override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = { - rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + rng = new XORShiftRandom(seed + partitionIndex) } override def nullable: Boolean = false @@ -70,8 +70,9 @@ case class Rand(seed: Long) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") + ctx.addMutableState(className, rngTerm, "") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } @@ -93,8 +94,9 @@ case class Randn(seed: Long) extends RDG { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName - ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") + ctx.addMutableState(className, rngTerm, "") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } From da9d2619ebf69e8573a6352d39455883f882f86b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Oct 2016 02:26:16 -0700 Subject: [PATCH 11/18] test all nondeterministic functions --- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 7 
++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b913db658c334..455419214041f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import scala.util.Random + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback @@ -441,7 +443,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " + "coalesce or union") { - assertValuesDoNotChangeAfterCoalesceOrUnion(monotonically_increasing_id()) + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } } From 2ec32063a2fee26a42c74026eaf287189a4ec0ac Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Oct 2016 02:41:52 -0700 Subject: [PATCH 12/18] fix partition initialization in local relation --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e5e2cd7d27d15..ca45e213914d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1060,6 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { case Project(projectList, LocalRelation(output, data)) if !projectList.exists(hasUnevaluableExpr) => val projection = new InterpretedProjection(projectList, output) + projection.initializeStatesForPartition(0) LocalRelation(projectList.map(_.toAttribute), data.map(projection)) } From e2ebd88b796c560bfa5653491956bdeaa7021c86 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Oct 2016 02:53:48 -0700 Subject: [PATCH 13/18] minor --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 1 - .../apache/spark/sql/catalyst/expressions/Projection.scala | 2 -- .../spark/sql/catalyst/expressions/SparkPartitionID.scala | 4 ++-- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 4 ++-- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index f6e92569631dd..d0c96d4b1475e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -816,7 +816,6 @@ abstract class RDD[T: ClassTag]( preservesPartitioning) } - /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 190b530138e7e..5554278c5d2a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. - * * @param expressions a sequence of expressions that determine the value of each column of the * output row. */ @@ -57,7 +56,6 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { /** * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified * expressions. - * * @param expressions a sequence of expressions that determine the value of each column of the * output row. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 35d1c9744196c..1d478e7972e2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -22,10 +22,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.types.{DataType, IntegerType} /** - * Expression that returns the current partition id of the Spark task. + * Expression that returns the current partition id. */ @ExpressionDescription( - usage = "_FUNC_() - Returns the current partition id of the Spark task", + usage = "_FUNC_() - Returns the current partition id", extended = "> SELECT _FUNC_();\n 0") case class SparkPartitionID() extends LeafExpression with Nondeterministic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bd1cdb6eba1fb..9c3c6d3b2a7f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -185,8 +185,8 @@ class CodegenContext { } /** - * Code statements to initialize states that depends on the partition index. - * An integer `partitionIndex` will be available within the scope. + * Code statements to initialize states that depend on the partition index. + * An integer `partitionIndex` will be made available within the scope. 
*/ val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty From 0de225db6dde57c94fa90dee65f3f0b0a9d9c481 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Oct 2016 08:27:14 -0700 Subject: [PATCH 14/18] fix catalyst expression tests --- .../spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 3 +++ .../expressions/codegen/CodegenExpressionCachingSuite.scala | 2 ++ 2 files changed, 5 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a28d44cca5b53..90deb6e940508 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -121,6 +121,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initializeStatesForPartition(0) val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { @@ -182,12 +183,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { var plan = generateProject( GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initializeStatesForPartition(0) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) + plan.initializeStatesForPartition(0) actual = FromUnsafeProjection(expression.dataType :: Nil)( plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index e2acdd6166038..58e90aa119388 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -31,12 +31,14 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { // Use an Add to wrap two of them together in case we only initialize the top level expressions. 
val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = UnsafeProjection.create(Seq(expr)) + instance.initializeStatesForPartition(0) assert(instance.apply(null).getBoolean(0) === false) } test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr)) + instance.initializeStatesForPartition(0) assert(instance.apply(null).getBoolean(0) === false) } From 6659795fac07e5046af04407454b235c012b9ef6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Oct 2016 08:56:33 -0700 Subject: [PATCH 15/18] fix generated predicate --- .../catalyst/expressions/codegen/GeneratePredicate.scala | 7 +++---- .../codegen/CodegenExpressionCachingSuite.scala | 9 +++++---- .../scala/org/apache/spark/sql/execution/SparkPlan.scala | 4 ++-- .../spark/sql/execution/basicPhysicalOperators.scala | 5 +++-- .../sql/execution/columnar/InMemoryTableScanExec.scala | 5 +++-- .../execution/joins/BroadcastNestedLoopJoinExec.scala | 2 +- .../spark/sql/execution/joins/CartesianProductExec.scala | 8 ++++---- .../org/apache/spark/sql/execution/joins/HashJoin.scala | 2 +- .../spark/sql/execution/joins/SortMergeJoinExec.scala | 2 +- 9 files changed, 23 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c7a37b10450f3..de89830509ee9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -37,14 +37,14 @@ abstract class Predicate { /** * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]]. 
*/ -object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] { +object GeneratePredicate extends CodeGenerator[Expression, Predicate] { protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) - protected def create(predicate: Expression): ((InternalRow) => Boolean) = { + protected def create(predicate: Expression): Predicate = { val ctx = newCodeGenContext() val eval = predicate.genCode(ctx) @@ -78,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] - (r: InternalRow) => p.eval(r) + CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 58e90aa119388..50800f649645d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -45,7 +45,8 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GeneratePredicate should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GeneratePredicate.generate(expr) - assert(instance.apply(null) === false) + instance.initializeStatesForPartition(0) + assert(instance.eval(null) === false) } test("GenerateUnsafeProjection should not share expression instances") { @@ -75,13 +76,13 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GeneratePredicate should not share expression instances") { val expr1 = MutableExpression() val instance1 = GeneratePredicate.generate(expr1) - assert(instance1.apply(null) === false) + assert(instance1.eval(null) === false) val expr2 = MutableExpression() expr2.mutableState = true val instance2 = GeneratePredicate.generate(expr2) - assert(instance1.apply(null) === false) - assert(instance2.apply(null) === true) + assert(instance1.eval(null) === false) + assert(instance2.eval(null) === true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 48d6ef6dcd44a..38fd66301d553 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetric @@ -354,7 +354,7 @@ abstract class 
SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } protected def newPredicate( - expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = { GeneratePredicate.generate(expression, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 30ff60ec900c3..7e1e6a72ae51c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -204,10 +204,11 @@ case class FilterExec(condition: Expression, child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitionsInternal { iter => + child.execute().mapPartitionsWithIndexInternal { (index, iter) => val predicate = newPredicate(condition, child.output) + predicate.initializeStatesForPartition(0) iter.filter { row => - val r = predicate(row) + val r = predicate.eval(row) if (r) numOutputRows += 1 r } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index b87016d5a5696..10b348bba4adc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -132,10 +132,11 @@ case class InMemoryTableScanExec( val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers - buffers.mapPartitionsInternal { cachedBatchIterator => + buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) + partitionFilter.initializeStatesForPartition(index) // Find the ordinals and data types of the requested columns. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index b87016d5a5696..10b348bba4adc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -132,10 +132,11 @@ case class InMemoryTableScanExec(
     val relOutput: AttributeSeq = relation.output
     val buffers = relation.cachedColumnBuffers
 
-    buffers.mapPartitionsInternal { cachedBatchIterator =>
+    buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
       val partitionFilter = newPredicate(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)),
         schema)
+      partitionFilter.initializeStatesForPartition(index)
 
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
@@ -147,7 +148,7 @@ case class InMemoryTableScanExec(
 
       val cachedBatchesToScan =
         if (inMemoryPartitionPruningEnabled) {
          cachedBatchIterator.filter { cachedBatch =>
-            if (!partitionFilter(cachedBatch.stats)) {
+            if (!partitionFilter.eval(cachedBatch.stats)) {
               def statsString: String = schemaIndex.map { case (a, i) =>
                 val value = cachedBatch.stats.get(i, a.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index bfe7e3dea45df..790d71b036cc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -84,7 +84,7 @@ case class BroadcastNestedLoopJoinExec(
 
   @transient private lazy val boundCondition = {
     if (condition.isDefined) {
-      newPredicate(condition.get, streamed.output ++ broadcast.output)
+      newPredicate(condition.get, streamed.output ++ broadcast.output).eval _
     } else {
       (r: InternalRow) => true
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 15dc9b40662e2..53169682d08f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -98,15 +98,15 @@ case class CartesianProductExec(
     val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
 
     val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
-    pair.mapPartitionsInternal { iter =>
+    pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
-        val boundCondition: (InternalRow) => Boolean =
-          newPredicate(condition.get, left.output ++ right.output)
+        val boundCondition = newPredicate(condition.get, left.output ++ right.output)
+        boundCondition.initializeStatesForPartition(index)
         val joined = new JoinedRow
         iter.filter { r =>
-          boundCondition(joined(r._1, r._2))
+          boundCondition.eval(joined(r._1, r._2))
         }
       } else {
         iter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 05c5e2f4cd77b..1aef5f6864263 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -81,7 +81,7 @@ trait HashJoin {
     UnsafeProjection.create(streamedKeys)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
+    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _
   } else {
     (r: InternalRow) => true
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ecf7cf289f034..ca9c0ed8cec32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -101,7 +101,7 @@ case class SortMergeJoinExec(
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
-          newPredicate(cond, left.output ++ right.output)
+          newPredicate(cond, left.output ++ right.output).eval _
        }.getOrElse {
          (r: InternalRow) => true
        }
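The trailing `.eval _` in the join operators above is ordinary eta-expansion: it lifts the Predicate's eval method into an InternalRow => Boolean function value, so both branches of the if/else keep the old closure type. A toy illustration (hypothetical names, not Spark code):

    abstract class Pred { def eval(r: Int): Boolean }

    object EtaDemo extends App {
      val p: Pred = new Pred { def eval(r: Int): Boolean = r > 0 }
      val f: Int => Boolean = p.eval _  // method lifted to a function value
      assert(f(1) && !f(-1))
      println("eta-expansion ok")
    }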
From 80f26c6e010b18e712e9514146975f21af9ac8ab Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 20 Oct 2016 10:29:31 -0700
Subject: [PATCH 16/18] checked all mapPartitionsInternal usage

---
 .../apache/spark/sql/execution/DataSourceScanExec.scala   | 6 ++++--
 .../scala/org/apache/spark/sql/execution/ExistingRDD.scala | 3 ++-
 .../org/apache/spark/sql/execution/GenerateExec.scala      | 3 ++-
 .../apache/spark/sql/execution/WholeStageCodegenExec.scala | 7 +++++--
 .../sql/execution/joins/BroadcastNestedLoopJoinExec.scala  | 5 +++--
 .../scala/org/apache/spark/sql/execution/objects.scala     | 6 ++++--
 .../spark/sql/hive/execution/HiveTableScanExec.scala       | 3 ++-
 7 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index fdd1fa3648251..b2c0279956800 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -71,8 +71,9 @@ case class RowDataSourceScanExec(
     val unsafeRow = if (outputUnsafeRows) {
       rdd
     } else {
-      rdd.mapPartitionsInternal { iter =>
+      rdd.mapPartitionsWithIndexInternal { (index, iter) =>
         val proj = UnsafeProjection.create(schema)
+        proj.initializeStatesForPartition(index)
         iter.map(proj)
       }
     }
@@ -284,8 +285,9 @@ case class FileSourceScanExec(
     val unsafeRows = {
       val scan = inputRDD
       if (needsUnsafeRowConversion) {
-        scan.mapPartitionsInternal { iter =>
+        scan.mapPartitionsWithIndexInternal { (index, iter) =>
           val proj = UnsafeProjection.create(schema)
+          proj.initializeStatesForPartition(index)
           iter.map(proj)
         }
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index d3a22228623e1..5415239c20190 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -165,8 +165,9 @@ case class RDDScanExec(
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    rdd.mapPartitionsInternal { iter =>
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(schema)
+      proj.initializeStatesForPartition(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 39189a2b0c72c..89a9d74f10fae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -91,8 +91,9 @@ case class GenerateExec(
     }
 
     val numOutputRows = longMetric("numOutputRows")
-    rows.mapPartitionsInternal { iter =>
+    rows.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(output, output)
+      proj.initializeStatesForPartition(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 7992f3fbadb49..8241ab927f795 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -380,10 +380,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
     } else {
       // Right now, we support up to two input RDDs.
       rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
-        val partitionIndex = TaskContext.getPartitionId()
+        Iterator((leftIter, rightIter))
+        // a small hack to obtain the correct partition index
+      }.mapPartitionsWithIndex { (index, zippedIter) =>
+        val (leftIter, rightIter) = zippedIter.next()
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(partitionIndex, Array(leftIter, rightIter))
+        buffer.init(index, Array(leftIter, rightIter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
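The "small hack" above works because zipPartitions exposes no partition index: the two iterators are wrapped into a single-element Iterator, and the index is recovered by the mapPartitionsWithIndex that follows. A self-contained sketch with plain RDDs (illustrative only, not the operator's actual code):

    import org.apache.spark.sql.SparkSession

    object ZipWithIndexDemo extends App {
      val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
      val sc = spark.sparkContext
      val left = sc.parallelize(Seq(1, 2, 3, 4), 2)
      val right = sc.parallelize(Seq(10, 20, 30, 40), 2)
      val out = left.zipPartitions(right) { (l, r) =>
        Iterator((l, r))             // defer consumption to the next transformation
      }.mapPartitionsWithIndex { (index, zipped) =>
        val (l, r) = zipped.next()   // exactly one element per partition
        l.zip(r).map { case (a, b) => (index, a + b) }
      }
      out.collect().foreach(println) // e.g. (0,11), (0,22), (1,33), (1,44)
      spark.stop()
    }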
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 790d71b036cc8..6e5262e496456 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -52,7 +52,7 @@ case class BroadcastNestedLoopJoinExec(
       UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
   }
 
-  private[this] def genResultProjection: InternalRow => InternalRow = joinType match {
+  private[this] def genResultProjection: UnsafeProjection = joinType match {
     case LeftExistence(j) =>
       UnsafeProjection.create(output, output)
     case other =>
@@ -366,8 +366,9 @@ case class BroadcastNestedLoopJoinExec(
     }
 
     val numOutputRows = longMetric("numOutputRows")
-    resultRdd.mapPartitionsInternal { iter =>
+    resultRdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val resultProj = genResultProjection
+      resultProj.initializeStatesForPartition(index)
       iter.map { r =>
         numOutputRows += 1
         resultProj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 2acc5110e8950..e23b180430137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -85,8 +85,9 @@ case class DeserializeToObjectExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+      projection.initializeStatesForPartition(index)
       iter.map(projection)
     }
   }
@@ -120,8 +121,9 @@ case class SerializeFromObjectExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = UnsafeProjection.create(serializer)
+      projection.initializeStatesForPartition(index)
       iter.map(projection)
     }
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 231f204b12b47..fb9ff7b4407db 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -154,8 +154,9 @@ case class HiveTableScanExec(
     val numOutputRows = longMetric("numOutputRows")
     // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649)
     val outputSchema = schema
-    rdd.mapPartitionsInternal { iter =>
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(outputSchema)
+      proj.initializeStatesForPartition(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)

From 553c6a543dd18a7278bf989e9197e74dc3cece7c Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 1 Nov 2016 22:34:23 -0700
Subject: [PATCH 17/18] rename to initialize

---
 .../spark/sql/catalyst/expressions/Expression.scala        | 10 ++++++----
 .../spark/sql/catalyst/expressions/InputFileName.scala     |  2 +-
 .../expressions/MonotonicallyIncreasingID.scala            |  2 +-
 .../spark/sql/catalyst/expressions/Projection.scala        | 10 ++++++----
 .../sql/catalyst/expressions/SparkPartitionID.scala        |  2 +-
 .../catalyst/expressions/codegen/CodegenFallback.scala     |  2 +-
 .../codegen/GenerateMutableProjection.scala                |  2 +-
 .../expressions/codegen/GeneratePredicate.scala            |  4 ++--
 .../expressions/codegen/GenerateSafeProjection.scala       |  2 +-
 .../expressions/codegen/GenerateUnsafeProjection.scala     |  2 +-
 .../spark/sql/catalyst/expressions/package.scala           |  2 +-
 .../sql/catalyst/expressions/randomExpressions.scala       |  2 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala           |  2 +-
 .../catalyst/expressions/ExpressionEvalHelper.scala        |  8 ++++----
 .../codegen/CodegenExpressionCachingSuite.scala            |  8 ++++----
 .../spark/sql/execution/DataSourceScanExec.scala           |  4 ++--
 .../org/apache/spark/sql/execution/ExistingRDD.scala       |  2 +-
 .../org/apache/spark/sql/execution/GenerateExec.scala      |  2 +-
 .../spark/sql/execution/basicPhysicalOperators.scala       |  4 ++--
 .../sql/execution/columnar/InMemoryTableScanExec.scala     |  2 +-
 .../execution/joins/BroadcastNestedLoopJoinExec.scala      |  2 +-
 .../sql/execution/joins/CartesianProductExec.scala         |  2 +-
 .../scala/org/apache/spark/sql/execution/objects.scala     |  4 ++--
 .../spark/sql/hive/execution/HiveTableScanExec.scala       |  2 +-
 24 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 13993303513c9..effb8cc6c68eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -272,17 +272,19 @@ trait Nondeterministic extends Expression {
   final override def deterministic: Boolean = false
   final override def foldable: Boolean = false
 
+  @transient
   private[this] var initialized = false
 
-  final def initializeStatesForPartition(partitionIndex: Int): Unit = {
-    initializeStatesForPartitionInternal(partitionIndex)
+  final def initialize(partitionIndex: Int): Unit = {
+    initializeInternal(partitionIndex)
     initialized = true
   }
 
-  protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit
+  protected def initializeInternal(partitionIndex: Int): Unit
 
   final override def eval(input: InternalRow = null): Any = {
-    require(initialized, "nondeterministic expression should be initialized before evaluate")
+    require(initialized,
+      s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
     evalInternal(input)
   }
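The contract this rename settles on: initialize records that per-partition setup ran, and eval fails fast otherwise. A compact, self-contained model of that guard (hypothetical classes mirroring the trait above):

    abstract class NondeterministicLike {
      @transient private var initialized = false
      final def initialize(partitionIndex: Int): Unit = {
        initializeInternal(partitionIndex)
        initialized = true
      }
      protected def initializeInternal(partitionIndex: Int): Unit
      final def eval(): Long = {
        require(initialized, s"${getClass.getName} should be initialized before eval.")
        evalInternal()
      }
      protected def evalInternal(): Long
    }

    class PartitionIdLike extends NondeterministicLike {
      private var partitionId = 0
      override protected def initializeInternal(partitionIndex: Int): Unit = {
        partitionId = partitionIndex
      }
      override protected def evalInternal(): Long = partitionId.toLong
    }

    object GuardDemo extends App {
      val e = new PartitionIdLike
      // e.eval() here would throw IllegalArgumentException: not initialized yet.
      e.initialize(3)
      println(e.eval())  // 3
    }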
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index 5c5f2e5b5e806..b6c12c5351119 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
 
   override def prettyName: String = "input_file_name"
 
-  override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = {}
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
 
   override protected def evalInternal(input: InternalRow): UTF8String = {
     InputFileNameHolder.getInputFileName()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 2a9cfb6b55327..72b8dcca26e2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -50,7 +50,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
 
   @transient private[this] var partitionMask: Long = _
 
-  override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = {
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
     count = 0L
     partitionMask = partitionIndex.toLong << 33
   }
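A worked example of the layout behind `partitionIndex.toLong << 33`: the upper bits carry the partition index and the lower 33 bits carry the per-partition count, so the generated IDs are unique across partitions and increasing within each one.

    object IdLayoutDemo extends App {
      // monotonicId is a hypothetical helper mirroring partitionMask + count.
      def monotonicId(partitionIndex: Int, count: Long): Long =
        (partitionIndex.toLong << 33) + count

      println(monotonicId(0, 0))  // 0
      println(monotonicId(0, 2))  // 2
      println(monotonicId(1, 0))  // 8589934592 (= 1L << 33)
      println(monotonicId(1, 2))  // 8589934594
    }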
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 5554278c5d2a7..53bce29c686b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
@@ -30,9 +31,9 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
-  override def initializeStatesForPartition(partitionIndex: Int): Unit = {
+  override def initialize(partitionIndex: Int): Unit = {
     expressions.foreach(_.foreach {
-      case n: Nondeterministic => n.initializeStatesForPartition(partitionIndex)
+      case n: Nondeterministic => n.initialize(partitionIndex)
       case _ =>
     })
   }
@@ -56,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
 /**
  * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
  * expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
@@ -65,9 +67,9 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
 
   private[this] val buffer = new Array[Any](expressions.size)
 
-  override def initializeStatesForPartition(partitionIndex: Int): Unit = {
+  override def initialize(partitionIndex: Int): Unit = {
     expressions.foreach(_.foreach {
-      case n: Nondeterministic => n.initializeStatesForPartition(partitionIndex)
+      case n: Nondeterministic => n.initialize(partitionIndex)
       case _ =>
     })
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 1d478e7972e2a..6bef473cac060 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -37,7 +37,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
 
   override val prettyName = "SPARK_PARTITION_ID"
 
-  override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = {
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
     partitionId = partitionIndex
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 3066a1570d3cb..0322d1dd6a9ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -38,7 +38,7 @@ trait CodegenFallback extends Expression {
         ctx.addPartitionInitializationStatement(
           s"""
              |((Nondeterministic) references[$childIndex])
-             |  .initializeStatesForPartition(partitionIndex);
+             |  .initialize(partitionIndex);
           """.stripMargin)
       case _ =>
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 9ffcf5794094a..4d732445544a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -111,7 +111,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
         ${ctx.initMutableStates()}
       }
 
-      public void initializeStatesForPartition(int partitionIndex) {
+      public void initialize(int partitionIndex) {
         ${ctx.initPartition()}
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index de89830509ee9..4ed7355cf044b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -31,7 +31,7 @@ abstract class Predicate {
    * This is used by non-deterministic expressions to set initial states.
    * The default implementation does nothing.
    */
-  def initializeStatesForPartition(partitionIndex: Int): Unit = {}
+  def initialize(partitionIndex: Int): Unit = {}
 }
 
 /**
@@ -62,7 +62,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
         ${ctx.initMutableStates()}
       }
 
-      public void initializeStatesForPartition(int partitionIndex) {
+      public void initialize(int partitionIndex) {
         ${ctx.initPartition()}
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 04482b5fd5c17..b1cb6edefb852 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -173,7 +173,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
         ${ctx.initMutableStates()}
       }
 
-      public void initializeStatesForPartition(int partitionIndex) {
+      public void initialize(int partitionIndex) {
         ${ctx.initPartition()}
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1adad01662431..7e4c9089a2cb9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -380,7 +380,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
         ${ctx.initMutableStates()}
       }
 
-      public void initializeStatesForPartition(int partitionIndex) {
+      public void initialize(int partitionIndex) {
         ${ctx.initPartition()}
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index cd1abc7119492..e9269dc297913 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -71,7 +71,7 @@ package object expressions {
      * This is used by non-deterministic expressions to set initial states.
      * The default implementation does nothing.
      */
-    def initializeStatesForPartition(partitionIndex: Int): Unit = {}
+    def initialize(partitionIndex: Int): Unit = {}
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index a649411edf562..e09029f5aab9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -42,7 +42,7 @@ abstract class RDG extends LeafExpression with Nondeterministic {
    */
   @transient protected var rng: XORShiftRandom = _
 
-  override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = {
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
     rng = new XORShiftRandom(seed + partitionIndex)
   }
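Seeding with `seed + partitionIndex` makes each partition's random stream a pure function of (seed, partitionIndex): a retried or recomputed task replays exactly the same values, while distinct partitions draw from distinct streams. A sketch with java.util.Random standing in for Spark's XORShiftRandom:

    import java.util.Random

    object SeedDemo extends App {
      def partitionStream(seed: Long, partitionIndex: Int, n: Int): Seq[Double] = {
        val rng = new Random(seed + partitionIndex)
        Seq.fill(n)(rng.nextDouble())
      }
      // Recomputing partition 1 yields an identical stream:
      assert(partitionStream(42L, 1, 3) == partitionStream(42L, 1, 3))
      // Different partitions draw from different streams:
      assert(partitionStream(42L, 0, 3) != partitionStream(42L, 1, 3))
      println("per-partition streams are reproducible")
    }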
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ca45e213914d3..b6ad5db74e3c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1060,7 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
     case Project(projectList, LocalRelation(output, data))
         if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedProjection(projectList, output)
-      projection.initializeStatesForPartition(0)
+      projection.initialize(0)
       LocalRelation(projectList.map(_.toAttribute), data.map(projection))
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 90deb6e940508..9ceb709185417 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -75,7 +75,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
 
   protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
     expression.foreach {
-      case n: Nondeterministic => n.initializeStatesForPartition(0)
+      case n: Nondeterministic => n.initialize(0)
       case _ =>
     }
     expression.eval(inputRow)
@@ -121,7 +121,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     val plan = generateProject(
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
-    plan.initializeStatesForPartition(0)
+    plan.initialize(0)
     val actual = plan(inputRow).get(0, expression.dataType)
 
     if (!checkResult(actual, expected)) {
@@ -183,14 +183,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     var plan = generateProject(
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
-    plan.initializeStatesForPartition(0)
+    plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
     assert(checkResult(actual, expected))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
-    plan.initializeStatesForPartition(0)
+    plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
     assert(checkResult(actual, expected))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
index 50800f649645d..fe5cb8eda824f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
@@ -31,21 +31,21 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
     // Use an Add to wrap two of them together in case we only initialize the top level expressions.
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = UnsafeProjection.create(Seq(expr))
-    instance.initializeStatesForPartition(0)
+    instance.initialize(0)
     assert(instance.apply(null).getBoolean(0) === false)
   }
 
   test("GenerateMutableProjection should initialize expressions") {
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = GenerateMutableProjection.generate(Seq(expr))
-    instance.initializeStatesForPartition(0)
+    instance.initialize(0)
     assert(instance.apply(null).getBoolean(0) === false)
   }
 
   test("GeneratePredicate should initialize expressions") {
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = GeneratePredicate.generate(expr)
-    instance.initializeStatesForPartition(0)
+    instance.initialize(0)
     assert(instance.eval(null) === false)
   }
 
@@ -92,7 +92,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
  */
case class NondeterministicExpression()
  extends LeafExpression with Nondeterministic with CodegenFallback {
-  override protected def initializeStatesForPartitionInternal(partitionIndex: Int): Unit = {}
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
   override protected def evalInternal(input: InternalRow): Any = false
   override def nullable: Boolean = false
   override def dataType: DataType = BooleanType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index b2c0279956800..e485b52b43f76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -73,7 +73,7 @@ case class RowDataSourceScanExec(
     } else {
       rdd.mapPartitionsWithIndexInternal { (index, iter) =>
         val proj = UnsafeProjection.create(schema)
-        proj.initializeStatesForPartition(index)
+        proj.initialize(index)
         iter.map(proj)
       }
     }
@@ -287,7 +287,7 @@ case class FileSourceScanExec(
       if (needsUnsafeRowConversion) {
         scan.mapPartitionsWithIndexInternal { (index, iter) =>
           val proj = UnsafeProjection.create(schema)
-          proj.initializeStatesForPartition(index)
+          proj.initialize(index)
           iter.map(proj)
         }
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 37f2eabd69a3f..aab087cd98716 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -192,7 +192,7 @@ case class RDDScanExec(
     val numOutputRows = longMetric("numOutputRows")
     rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(schema)
-      proj.initializeStatesForPartition(index)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index a2037cdf7065a..19fbf0c162048 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -96,7 +96,7 @@ case class GenerateExec(
     val numOutputRows = longMetric("numOutputRows")
     rows.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(output, output)
-      proj.initializeStatesForPartition(index)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 60f8b78cdbf62..32133f52630cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -73,7 +73,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
     child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val project = UnsafeProjection.create(projectList, child.output,
         subexpressionEliminationEnabled)
-      project.initializeStatesForPartition(index)
+      project.initialize(index)
       iter.map(project)
     }
   }
@@ -208,7 +208,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     val numOutputRows = longMetric("numOutputRows")
     child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val predicate = newPredicate(condition, child.output)
-      predicate.initializeStatesForPartition(0)
+      predicate.initialize(0)
       iter.filter { row =>
         val r = predicate.eval(row)
         if (r) numOutputRows += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 10b348bba4adc..9028caa446e8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -136,7 +136,7 @@ case class InMemoryTableScanExec(
       val partitionFilter = newPredicate(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)),
         schema)
-      partitionFilter.initializeStatesForPartition(index)
+      partitionFilter.initialize(index)
 
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 6e5262e496456..f526a19876670 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -368,7 +368,7 @@ case class BroadcastNestedLoopJoinExec(
     val numOutputRows = longMetric("numOutputRows")
     resultRdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val resultProj = genResultProjection
-      resultProj.initializeStatesForPartition(index)
+      resultProj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         resultProj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 53169682d08f4..8341fe2ffd078 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -102,7 +102,7 @@ case class CartesianProductExec(
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
         val boundCondition = newPredicate(condition.get, left.output ++ right.output)
-        boundCondition.initializeStatesForPartition(index)
+        boundCondition.initialize(index)
         val joined = new JoinedRow
         iter.filter { r =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 6f6c4d1c5c23a..fde3b2a528994 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -89,7 +89,7 @@ case class DeserializeToObjectExec(
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
-      projection.initializeStatesForPartition(index)
+      projection.initialize(index)
       iter.map(projection)
     }
   }
@@ -127,7 +127,7 @@ case class SerializeFromObjectExec(
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = UnsafeProjection.create(serializer)
-      projection.initializeStatesForPartition(index)
+      projection.initialize(index)
       iter.map(projection)
     }
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index fb9ff7b4407db..c80695bd3e0fe 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -156,7 +156,7 @@ case class HiveTableScanExec(
     val outputSchema = schema
     rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(outputSchema)
-      proj.initializeStatesForPartition(index)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)

From ababaa9dfc7d38a7e276c61d5f35b28ad9718b19 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 1 Nov 2016 22:57:59 -0700
Subject: [PATCH 18/18] add doc

---
 .../spark/sql/catalyst/expressions/Expression.scala         | 9 +++++++++
 .../catalyst/expressions/codegen/GeneratePredicate.scala    | 4 ++--
 .../apache/spark/sql/catalyst/expressions/package.scala     | 4 ++--
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index effb8cc6c68eb..726a231fd814e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -275,6 +275,10 @@ trait Nondeterministic extends Expression {
   @transient
   private[this] var initialized = false
 
+  /**
+   * Initializes internal states given the current partition index and marks this as initialized.
+   * Subclasses should override [[initializeInternal()]].
+   */
   final def initialize(partitionIndex: Int): Unit = {
     initializeInternal(partitionIndex)
     initialized = true
@@ -282,6 +286,11 @@ trait Nondeterministic extends Expression {
 
   protected def initializeInternal(partitionIndex: Int): Unit
 
+  /**
+   * @inheritdoc
+   * Throws an exception if [[initialize()]] has not been called yet.
+   * Subclasses should override [[evalInternal()]].
+   */
   final override def eval(input: InternalRow = null): Any = {
     require(initialized,
       s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 4ed7355cf044b..dcd1ed96a298e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -27,8 +27,8 @@ abstract class Predicate {
   def eval(r: InternalRow): Boolean
 
   /**
-   * Initialize internal states given the current partition index.
-   * This is used by non-deterministic expressions to set initial states.
+   * Initializes internal states given the current partition index.
+   * This is used by nondeterministic expressions to set initial states.
    * The default implementation does nothing.
    */
   def initialize(partitionIndex: Int): Unit = {}
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index e9269dc297913..1b00c9e79da22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -67,8 +67,8 @@ package object expressions {
   abstract class Projection extends (InternalRow => InternalRow) {
 
     /**
-     * Initialize internal states given the current partition index.
-     * This is used by non-deterministic expressions to set initial states.
+     * Initializes internal states given the current partition index.
+     * This is used by nondeterministic expressions to set initial states.
      * The default implementation does nothing.
      */
     def initialize(partitionIndex: Int): Unit = {}
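Taken together, the series leaves one calling convention: build the evaluator per partition, call initialize(index) once, then apply it row by row. A compact self-contained model of that convention (hypothetical classes, not the Spark API):

    import scala.util.Random

    trait ProjectionLike extends (Long => Long) {
      def initialize(partitionIndex: Int): Unit = {}  // default no-op, as in Projection
    }

    class RandProjection(seed: Long) extends ProjectionLike {
      private var rng: Random = _
      override def initialize(partitionIndex: Int): Unit = {
        rng = new Random(seed + partitionIndex)       // per-partition, reproducible state
      }
      override def apply(row: Long): Long = row + rng.nextInt(100)
    }

    object ContractDemo extends App {
      val proj = new RandProjection(seed = 7L)
      proj.initialize(0)                  // once per partition, before any row
      println((0L until 3L).map(proj))    // then applied row by row
    }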