Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-35363][SQL] Refactor sort merge join code-gen be agnostic to join type #32495

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClustered
* Holds common logic for join operators by shuffling two child relations
* using the join keys.
*/
trait ShuffledJoin extends BaseJoinExec {
trait ShuffledJoin extends JoinCodegenSupport {
def isSkewJoin: Boolean

override def nodeName: String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ case class SortMergeJoinExec(
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan,
isSkewJoin: Boolean = false) extends ShuffledJoin with CodegenSupport {
isSkewJoin: Boolean = false) extends ShuffledJoin {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
Expand Down Expand Up @@ -353,12 +353,22 @@ case class SortMergeJoinExec(
}
}

private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike => ((left, leftKeys), (right, rightKeys))
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType")
}

private lazy val streamedOutput = streamedPlan.output
private lazy val bufferedOutput = bufferedPlan.output

override def supportCodegen: Boolean = {
joinType.isInstanceOf[InnerLike]
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
left.execute() :: right.execute() :: Nil
streamedPlan.execute() :: bufferedPlan.execute() :: Nil
}

private def createJoinKey(
Expand Down Expand Up @@ -392,24 +402,24 @@ case class SortMergeJoinExec(
}

/**
* Generate a function to scan both left and right to find a match, returns the term for
* matched one row from left side and buffered rows from right side.
* Generate a function to scan both sides to find a match, returns the term for
* matched one row from streamed side and buffered rows from buffered side.
*/
private def genScanner(ctx: CodegenContext): (String, String) = {
// Create class member for next row from both sides.
// Inline mutable state since not many join operations in a task
val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)
val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true)
val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow", forceInline = true)

// Create variables for join keys from both sides.
val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
// Copy the right key as class members so they could be used in next function call.
val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)

// A list to hold all matched rows from right side.
val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, streamedOutput)
val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ")
val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, bufferedOutput)
val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ")
// Copy the buffered key as class members so they could be used in next function call.
val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars)

// A list to hold all matched rows from buffered side.
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName

val spillThreshold = getSpillThreshold
Expand All @@ -418,115 +428,106 @@ case class SortMergeJoinExec(
// Inline mutable state since not many join operations in a task
val matches = ctx.addMutableState(clsName, "matches",
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
// Copy the streamed keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, streamedKeyVars)

ctx.addNewFunction("findNextInnerJoinRows",
ctx.addNewFunction("findNextJoinRows",
s"""
|private boolean findNextInnerJoinRows(
| scala.collection.Iterator leftIter,
| scala.collection.Iterator rightIter) {
| $leftRow = null;
|private boolean findNextJoinRows(
| scala.collection.Iterator streamedIter,
| scala.collection.Iterator bufferedIter) {
| $streamedRow = null;
| int comp = 0;
| while ($leftRow == null) {
| if (!leftIter.hasNext()) return false;
| $leftRow = (InternalRow) leftIter.next();
| ${leftKeyVars.map(_.code).mkString("\n")}
| if ($leftAnyNull) {
| $leftRow = null;
| while ($streamedRow == null) {
| if (!streamedIter.hasNext()) return false;
| $streamedRow = (InternalRow) streamedIter.next();
| ${streamedKeyVars.map(_.code).mkString("\n")}
| if ($streamedAnyNull) {
| $streamedRow = null;
| continue;
| }
| if (!$matches.isEmpty()) {
| ${genComparison(ctx, leftKeyVars, matchedKeyVars)}
| ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
| if (comp == 0) {
| return true;
| }
| $matches.clear();
| }
|
| do {
| if ($rightRow == null) {
| if (!rightIter.hasNext()) {
| if ($bufferedRow == null) {
| if (!bufferedIter.hasNext()) {
| ${matchedKeyVars.map(_.code).mkString("\n")}
| return !$matches.isEmpty();
| }
| $rightRow = (InternalRow) rightIter.next();
| ${rightKeyTmpVars.map(_.code).mkString("\n")}
| if ($rightAnyNull) {
| $rightRow = null;
| $bufferedRow = (InternalRow) bufferedIter.next();
| ${bufferedKeyTmpVars.map(_.code).mkString("\n")}
| if ($bufferedAnyNull) {
| $bufferedRow = null;
| continue;
| }
| ${rightKeyVars.map(_.code).mkString("\n")}
| ${bufferedKeyVars.map(_.code).mkString("\n")}
| }
| ${genComparison(ctx, leftKeyVars, rightKeyVars)}
| ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)}
| if (comp > 0) {
| $rightRow = null;
| $bufferedRow = null;
| } else if (comp < 0) {
| if (!$matches.isEmpty()) {
| ${matchedKeyVars.map(_.code).mkString("\n")}
| return true;
| }
| $leftRow = null;
| $streamedRow = null;
| } else {
| $matches.add((UnsafeRow) $rightRow);
| $rightRow = null;
| $matches.add((UnsafeRow) $bufferedRow);
| $bufferedRow = null;
| }
| } while ($leftRow != null);
| } while ($streamedRow != null);
| }
| return false; // unreachable
|}
""".stripMargin, inlineToOuterClass = true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have multiple SMJ in one whole-stage, will we have multiple findNextJoinRows methods in the outer class and fail the compilation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This applied to inner join before this PR as well. There won't be an issue, and we will never have multiple findNextJoinRows methods in one class. The reason is that, with the current design, sort merge join always does code-gen for its children separately, so there won't be two SortMergeJoinExecs code-gen-ed in the same class.

Verified with example query:

val df1 = spark.range(10).select($"id".as("k1"))
val df2 = spark.range(4).select($"id".as("k2"))
val df3 = spark.range(6).select($"id".as("k3"))
df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_outer")
  .join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "right_outer")
  .explain("codegen")

Query plan:

*(8) SortMergeJoin [k3#10L], [k1#2L], RightOuter
:- *(5) SortMergeJoin [k3#10L], [k2#6L], LeftOuter
:  :- *(2) Sort [k3#10L ASC NULLS FIRST], false, 0
:  :  +- Exchange hashpartitioning(k3#10L, 5), ENSURE_REQUIREMENTS, [id=#43]
:  :     +- *(1) Project [id#8L AS k3#10L]
:  :        +- *(1) Range (0, 6, step=1, splits=2)
:  +- *(4) Sort [k2#6L ASC NULLS FIRST], false, 0
:     +- Exchange hashpartitioning(k2#6L, 5), ENSURE_REQUIREMENTS, [id=#49]
:        +- *(3) Project [id#4L AS k2#6L]
:           +- *(3) Range (0, 4, step=1, splits=2)
+- *(7) Sort [k1#2L ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(k1#2L, 5), ENSURE_REQUIREMENTS, [id=#58]
      +- *(6) Project [id#0L AS k1#2L]
         +- *(6) Range (0, 10, step=1, splits=2)

All generated code is in https://gist.github.com/c21/873775bcd08583105b289e67221f6e17.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth adding a comment for it, in case we change this in the future (e.g. code-gen one side together with the SMJ).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or we can just make it future-proof and create a fresh function name here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - thanks for calling it out. I would like to make it future-proof by giving it a fresh name. Let me do it in a follow-up PR. Actually, I was wondering about the same question when implementing this and spent some time figuring it out.


(leftRow, matches)
(streamedRow, matches)
}

/**
* Creates variables and declarations for left part of result row.
* Creates variables and declarations for streamed part of result row.
*
* In order to defer the access after condition and also only access once in the loop,
* the variables should be declared separately from accessing the columns, we can't use the
* codegen of BoundReference here.
*/
private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = {
ctx.INPUT_ROW = leftRow
private def createStreamedVars(
ctx: CodegenContext,
streamedRow: String): (Seq[ExprCode], Seq[String]) = {
ctx.INPUT_ROW = streamedRow
left.output.zipWithIndex.map { case (a, i) =>
val value = ctx.freshName("value")
val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString)
val valueCode = CodeGenerator.getValue(streamedRow, a.dataType, i.toString)
val javaType = CodeGenerator.javaType(a.dataType)
val defaultValue = CodeGenerator.defaultValue(a.dataType)
if (a.nullable) {
val isNull = ctx.freshName("isNull")
val code =
code"""
|$isNull = $leftRow.isNullAt($i);
|$isNull = $streamedRow.isNullAt($i);
|$value = $isNull ? $defaultValue : ($valueCode);
""".stripMargin
val leftVarsDecl =
val streamedVarsDecl =
s"""
|boolean $isNull = false;
|$javaType $value = $defaultValue;
""".stripMargin
(ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
leftVarsDecl)
streamedVarsDecl)
} else {
val code = code"$value = $valueCode;"
val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl)
val streamedVarsDecl = s"""$javaType $value = $defaultValue;"""
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), streamedVarsDecl)
}
}.unzip
}

/**
* Creates the variables for right part of result row, using BoundReference, since the right
* part are accessed inside the loop.
*/
private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
ctx.INPUT_ROW = rightRow
right.output.zipWithIndex.map { case (a, i) =>
BoundReference(i, a.dataType, a.nullable).genCode(ctx)
}
}

/**
* Splits variables based on whether it's used by condition or not, returns the code to create
* these variables before the condition and after the condition.
Expand Down Expand Up @@ -554,62 +555,64 @@ case class SortMergeJoinExec(

override def doProduce(ctx: CodegenContext): String = {
// Inline mutable state since not many join operations in a task
val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
v => s"$v = inputs[0];", forceInline = true)
val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput",
v => s"$v = inputs[1];", forceInline = true)

val (leftRow, matches) = genScanner(ctx)
val (streamedRow, matches) = genScanner(ctx)

// Create variables for row from both sides.
val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow)
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow)
val bufferedRow = ctx.freshName("bufferedRow")
val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan)

val iterator = ctx.freshName("iterator")
val numOutput = metricTerm(ctx, "numOutputRows")
val resultVars = streamedVars ++ bufferedVars

val (beforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
val loaded = ctx.freshName("loaded")
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
// Generate code for condition
ctx.currentVars = leftVars ++ rightVars
ctx.currentVars = resultVars
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
// evaluate the columns those used by condition before loop
val before = s"""
|boolean $loaded = false;
|$leftBefore
|$streamedBefore
""".stripMargin

val checking = s"""
|$rightBefore
|$bufferedBefore
|${cond.code}
|if (${cond.isNull} || !${cond.value}) continue;
|if (!$loaded) {
| $loaded = true;
| $leftAfter
| $streamedAfter
|}
|$rightAfter
|$bufferedAfter
""".stripMargin
(before, checking)
} else {
(evaluateVariables(leftVars), "")
(evaluateVariables(streamedVars), "")
}

val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"

s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
| ${leftVarDecl.mkString("\n")}
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
| ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| while ($iterator.hasNext()) {
| InternalRow $rightRow = (InternalRow) $iterator.next();
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, leftVars ++ rightVars)}
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
Expand Down