Skip to content

Commit

Permalink
[SPARK-14393][SQL] values generated by non-deterministic functions sh…
Browse files Browse the repository at this point in the history
…ouldn't change after coalesce or union

## What changes were proposed in this pull request?

When a user appended a column using a "nondeterministic" function to a DataFrame, e.g., `rand`, `randn`, and `monotonically_increasing_id`, the expected semantic is the following:
- The value in each row should remain unchanged, as if we materialize the column immediately, regardless of later DataFrame operations.

However, since we use `TaskContext.getPartitionId` to get the partition index from the current thread, the values from nondeterministic columns might change if we call `union` or `coalesce` after. `TaskContext.getPartitionId` returns the partition index of the current Spark task, which might not be the corresponding partition index of the DataFrame where we defined the column.

See the unit tests below or JIRA for examples.

This PR uses the partition index from `RDD.mapPartitionWithIndex` instead of `TaskContext` and fixes the partition initialization logic in whole-stage codegen, normal codegen, and codegen fallback. `initializeStatesForPartition(partitionIndex: Int)` was added to `Projection`, `Nondeterministic`, and `Predicate` (codegen) and initialized right after object creation in `mapPartitionWithIndex`. `newPredicate` now returns a `Predicate` instance rather than a function for proper initialization.
## How was this patch tested?

Unit tests. (Actually I'm not very confident that this PR fixed all issues without introducing new ones ...)

cc: rxin davies

Author: Xiangrui Meng <[email protected]>

Closes #15567 from mengxr/SPARK-14393.
  • Loading branch information
mengxr authored and rxin committed Nov 2, 2016
1 parent 742e0fe commit 02f2031
Show file tree
Hide file tree
Showing 32 changed files with 231 additions and 78 deletions.
16 changes: 14 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag](
}

/**
* [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a
* performance API to be used carefully only if we are sure that the RDD elements are
* [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning.
* It is a performance API to be used carefully only if we are sure that the RDD elements are
* serializable and don't require closure cleaning.
*
* @param preservesPartitioning indicates whether the input function preserves the partitioner,
* which should be `false` unless this is a pair RDD and the input function doesn't modify
* the keys.
*/
private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
new MapPartitionsRDD(
this,
(context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
preservesPartitioning)
}

/**
* [performance] Spark's internal mapPartitions method that skips closure cleaning.
*/
private[spark] def mapPartitionsInternal[U: ClassTag](
f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = withScope {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,17 +272,28 @@ trait Nondeterministic extends Expression {
final override def deterministic: Boolean = false
final override def foldable: Boolean = false

@transient
private[this] var initialized = false

final def setInitialValues(): Unit = {
initInternal()
/**
* Initializes internal states given the current partition index and mark this as initialized.
* Subclasses should override [[initializeInternal()]].
*/
final def initialize(partitionIndex: Int): Unit = {
initializeInternal(partitionIndex)
initialized = true
}

protected def initInternal(): Unit
protected def initializeInternal(partitionIndex: Int): Unit

/**
* @inheritdoc
* Throws an exception if [[initialize()]] is not called yet.
* Subclasses should override [[evalInternal()]].
*/
final override def eval(input: InternalRow = null): Any = {
require(initialized, "nondeterministic expression should be initialized before evaluate")
require(initialized,
s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
evalInternal(input)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {

override def prettyName: String = "input_file_name"

override protected def initInternal(): Unit = {}
override protected def initializeInternal(partitionIndex: Int): Unit = {}

override protected def evalInternal(input: InternalRow): UTF8String = {
InputFileNameHolder.getInputFileName()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis

@transient private[this] var partitionMask: Long = _

override protected def initInternal(): Unit = {
override protected def initializeInternal(partitionIndex: Int): Unit = {
count = 0L
partitionMask = TaskContext.getPartitionId().toLong << 33
partitionMask = partitionIndex.toLong << 33
}

override def nullable: Boolean = false
Expand All @@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val countTerm = ctx.freshName("count")
val partitionMaskTerm = ctx.freshName("partitionMask")
ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")
ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")

ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,20 @@ import org.apache.spark.sql.types.{DataType, StructType}

/**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
*
* @param expressions a sequence of expressions that determine the value of each column of the
* output row.
*/
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

expressions.foreach(_.foreach {
case n: Nondeterministic => n.setInitialValues()
case _ =>
})
override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
case n: Nondeterministic => n.initialize(partitionIndex)
case _ =>
})
}

// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) expressions.toArray else null
Expand All @@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
/**
* A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
* expressions.
*
* @param expressions a sequence of expressions that determine the value of each column of the
* output row.
*/
Expand All @@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu

private[this] val buffer = new Array[Any](expressions.size)

expressions.foreach(_.foreach {
case n: Nondeterministic => n.setInitialValues()
case _ =>
})
override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
case n: Nondeterministic => n.initialize(partitionIndex)
case _ =>
})
}

private[this] val exprArray = expressions.toArray
private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

/**
* Expression that returns the current partition id of the Spark task.
* Expression that returns the current partition id.
*/
@ExpressionDescription(
usage = "_FUNC_() - Returns the current partition id of the Spark task",
usage = "_FUNC_() - Returns the current partition id",
extended = "> SELECT _FUNC_();\n 0")
case class SparkPartitionID() extends LeafExpression with Nondeterministic {

Expand All @@ -38,16 +37,16 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {

override val prettyName = "SPARK_PARTITION_ID"

override protected def initInternal(): Unit = {
partitionId = TaskContext.getPartitionId()
override protected def initializeInternal(partitionIndex: Int): Unit = {
partitionId = partitionIndex
}

override protected def evalInternal(input: InternalRow): Int = partitionId

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = ctx.freshName("partitionId")
ctx.addMutableState(ctx.JAVA_INT, idTerm,
s"$idTerm = org.apache.spark.TaskContext.getPartitionId();")
ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ class CodegenContext {
splitExpressions(initCodes, "init", Nil)
}

/**
* Code statements to initialize states that depend on the partition index.
* An integer `partitionIndex` will be made available within the scope.
*/
val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty

def addPartitionInitializationStatement(statement: String): Unit = {
partitionInitializationStatements += statement
}

def initPartition(): String = {
partitionInitializationStatements.mkString("\n")
}

/**
* Holding all the functions those will be added into generated class.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,23 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No
trait CodegenFallback extends Expression {

protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
foreach {
case n: Nondeterministic => n.setInitialValues()
case _ =>
}

// LeafNode does not need `input`
val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW
val idx = ctx.references.length
ctx.references += this
var childIndex = idx
this.foreach {
case n: Nondeterministic =>
// This might add the current expression twice, but it won't hurt.
ctx.references += n
childIndex += 1
ctx.addPartitionInitializationStatement(
s"""
|((Nondeterministic) references[$childIndex])
| .initialize(partitionIndex);
""".stripMargin)
case _ =>
}
val objectTerm = ctx.freshName("obj")
val placeHolder = ctx.registerComment(this.toString)
if (nullable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
${ctx.initMutableStates()}
}

public void initialize(int partitionIndex) {
${ctx.initPartition()}
}

${ctx.declareAddedFunctions()}

public ${classOf[BaseMutableProjection].getName} target(InternalRow row) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,26 @@ import org.apache.spark.sql.catalyst.expressions._
*/
abstract class Predicate {
def eval(r: InternalRow): Boolean

/**
* Initializes internal states given the current partition index.
* This is used by nondeterministic expressions to set initial states.
* The default implementation does nothing.
*/
def initialize(partitionIndex: Int): Unit = {}
}

/**
* Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
*/
object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] {
object GeneratePredicate extends CodeGenerator[Expression, Predicate] {

protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)

protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
BindReferences.bindReference(in, inputSchema)

protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
protected def create(predicate: Expression): Predicate = {
val ctx = newCodeGenContext()
val eval = predicate.genCode(ctx)

Expand All @@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
${ctx.initMutableStates()}
}

public void initialize(int partitionIndex) {
${ctx.initPartition()}
}

${ctx.declareAddedFunctions()}

public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
Expand All @@ -67,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")

val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
(r: InternalRow) => p.eval(r)
CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
${ctx.initMutableStates()}
}

public void initialize(int partitionIndex) {
${ctx.initPartition()}
}

${ctx.declareAddedFunctions()}

public java.lang.Object apply(java.lang.Object _i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${ctx.initMutableStates()}
}

public void initialize(int partitionIndex) {
${ctx.initPartition()}
}

${ctx.declareAddedFunctions()}

// Scala.Function1 need this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,15 @@ package object expressions {
* column of the new row. If the schema of the input row is specified, then the given expression
* will be bound to that schema.
*/
abstract class Projection extends (InternalRow => InternalRow)
abstract class Projection extends (InternalRow => InternalRow) {

/**
* Initializes internal states given the current partition index.
* This is used by nondeterministic expressions to set initial states.
* The default implementation does nothing.
*/
def initialize(partitionIndex: Int): Unit = {}
}

/**
* Converts a [[InternalRow]] to another Row given a sequence of expression that define each
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ object InterpretedPredicate {
create(BindReferences.bindReference(expression, inputSchema))

def create(expression: Expression): (InternalRow => Boolean) = {
expression.foreach {
case n: Nondeterministic => n.setInitialValues()
case _ =>
}
(r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ abstract class RDG extends LeafExpression with Nondeterministic {
*/
@transient protected var rng: XORShiftRandom = _

override protected def initInternal(): Unit = {
rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
override protected def initializeInternal(partitionIndex: Int): Unit = {
rng = new XORShiftRandom(seed + partitionIndex)
}

override def nullable: Boolean = false
Expand All @@ -70,8 +70,9 @@ case class Rand(seed: Long) extends RDG {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rngTerm = ctx.freshName("rng")
val className = classOf[XORShiftRandom].getName
ctx.addMutableState(className, rngTerm,
s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
ctx.addMutableState(className, rngTerm, "")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false")
}
Expand All @@ -93,8 +94,9 @@ case class Randn(seed: Long) extends RDG {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val rngTerm = ctx.freshName("rng")
val className = classOf[XORShiftRandom].getName
ctx.addMutableState(className, rngTerm,
s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
ctx.addMutableState(className, rngTerm, "")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
case Project(projectList, LocalRelation(output, data))
if !projectList.exists(hasUnevaluableExpr) =>
val projection = new InterpretedProjection(projectList, output)
projection.initialize(0)
LocalRelation(projectList.map(_.toAttribute), data.map(projection))
}

Expand Down
Loading

0 comments on commit 02f2031

Please sign in to comment.