Skip to content

Commit

Permalink
[SPARK-34708][SQL] Code-gen for left semi/anti broadcast nested loop …
Browse files Browse the repository at this point in the history
…join (build right side)

### What changes were proposed in this pull request?

This PR is to add code-gen support for left semi / left anti BroadcastNestedLoopJoin (build side is right side). The execution code path for build left side cannot fit into whole stage code-gen framework, so only add the code-gen for build right side here.

Reference: the iterator (non-code-gen) code path is `BroadcastNestedLoopJoinExec.leftExistenceJoin()` with `BuildRight`.

### Why are the changes needed?

Improve query CPU performance.
Tested with a simple query:

```
val N = 20 << 20
val M = 1 << 4

val dim = broadcast(spark.range(M).selectExpr("id as k2"))
codegenBenchmark("left semi broadcast nested loop join", N) {
  park.range(N).selectExpr(s"id as k1").join(
    dim, col("k1") + 1 <= col("k2"), "left_semi")
}
```

Seeing 5x run time improvement:

```
Running benchmark: left semi broadcast nested loop join
  Running case: left semi broadcast nested loop join codegen off
  Stopped after 2 iterations, 6958 ms
  Running case: left semi broadcast nested loop join codegen on
  Stopped after 5 iterations, 3383 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
left semi broadcast nested loop join:             Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------------------------------------------------------
left semi broadcast nested loop join codegen off           3434           3479          65          6.1         163.7       1.0X
left semi broadcast nested loop join codegen on             672            677           5         31.2          32.1       5.1X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Changed existing unit test in `ExistenceJoinSuite.scala` to cover all code paths:
* left semi/anti + empty right side + empty condition
* left semi/anti + non-empty right side + empty condition
* left semi/anti + right side + non-empty condition

Added unit test in `WholeStageCodegenSuite.scala` to make sure code-gen for broadcast nested loop join is taking effect, and test for multiple join case as well.

Example query:

```
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 <= $"k2", "left_semi").explain("codegen")
```

Example generated code (`bnlj_doConsume_0` method):
This is for left semi join. The generated code for left anti join is mostly to be same as here, except L55 to be `if (bnlj_findMatchedRow_0 == false) {`.
```
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:203(0.31% used); numInnerClasses:0) ==
*(2) Project [id#0L AS k1#2L]
+- *(2) BroadcastNestedLoopJoin BuildRight, LeftSemi, ((id#0L + 1) <= k2#6L)
   :- *(2) Range (0, 4, step=1, splits=2)
   +- BroadcastExchange IdentityBroadcastMode, [id=#23]
      +- *(1) Project [id#4L AS k2#6L]
         +- *(1) Range (0, 3, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 031 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     boolean bnlj_findMatchedRow_0 = false;
/* 038 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 039 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 040 */
/* 041 */       long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 042 */
/* 043 */       long bnlj_value_3 = -1L;
/* 044 */
/* 045 */       bnlj_value_3 = bnlj_expr_0_0 + 1L;
/* 046 */
/* 047 */       boolean bnlj_value_2 = false;
/* 048 */       bnlj_value_2 = bnlj_value_3 <= bnlj_value_1;
/* 049 */       if (!(false || !bnlj_value_2))
/* 050 */       {
/* 051 */         bnlj_findMatchedRow_0 = true;
/* 052 */         break;
/* 053 */       }
/* 054 */     }
/* 055 */     if (bnlj_findMatchedRow_0 == true) {
/* 056 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 057 */
/* 058 */       // common sub-expressions
/* 059 */
/* 060 */       range_mutableStateArray_0[3].reset();
/* 061 */
/* 062 */       range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 063 */       append((range_mutableStateArray_0[3].getRow()).copy());
/* 064 */
/* 065 */     }
/* 066 */
/* 067 */   }
/* 068 */
/* 069 */   private void initRange(int idx) {
/* 070 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 071 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 072 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 073 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 074 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 075 */     long partitionEnd;
/* 076 */
/* 077 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 078 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 079 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 080 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 081 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 082 */     } else {
/* 083 */       range_nextIndex_0 = st.longValue();
/* 084 */     }
/* 085 */     range_batchEnd_0 = range_nextIndex_0;
/* 086 */
/* 087 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 088 */     .multiply(step).add(start);
/* 089 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 090 */       partitionEnd = Long.MAX_VALUE;
/* 091 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 092 */       partitionEnd = Long.MIN_VALUE;
/* 093 */     } else {
/* 094 */       partitionEnd = end.longValue();
/* 095 */     }
/* 096 */
/* 097 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 098 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 099 */     range_numElementsTodo_0  = startToEnd.divide(step).longValue();
/* 100 */     if (range_numElementsTodo_0 < 0) {
/* 101 */       range_numElementsTodo_0 = 0;
/* 102 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 103 */       range_numElementsTodo_0++;
/* 104 */     }
/* 105 */   }
/* 106 */
/* 107 */   protected void processNext() throws java.io.IOException {
/* 108 */     // initialize Range
/* 109 */     if (!range_initRange_0) {
/* 110 */       range_initRange_0 = true;
/* 111 */       initRange(partitionIndex);
/* 112 */     }
/* 113 */
/* 114 */     while (true) {
/* 115 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 116 */         long range_nextBatchTodo_0;
/* 117 */         if (range_numElementsTodo_0 > 1000L) {
/* 118 */           range_nextBatchTodo_0 = 1000L;
/* 119 */           range_numElementsTodo_0 -= 1000L;
/* 120 */         } else {
/* 121 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 122 */           range_numElementsTodo_0 = 0;
/* 123 */           if (range_nextBatchTodo_0 == 0) break;
/* 124 */         }
/* 125 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 126 */       }
/* 127 */
/* 128 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 129 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 130 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 131 */
/* 132 */         bnlj_doConsume_0(range_value_0);
/* 133 */
/* 134 */         if (shouldStop()) {
/* 135 */           range_nextIndex_0 = range_value_0 + 1L;
/* 136 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 137 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 138 */           return;
/* 139 */         }
/* 140 */
/* 141 */       }
/* 142 */       range_nextIndex_0 = range_batchEnd_0;
/* 143 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 144 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 145 */       range_taskContext_0.killTaskIfInterrupted();
/* 146 */     }
/* 147 */   }
/* 148 */
/* 149 */ }
```

Closes #31874 from c21/code-semi-anti.

Authored-by: Cheng Su <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
c21 authored and cloud-fan committed Mar 22, 2021
1 parent c7bf8ad commit f8838fe
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,9 @@ case class BroadcastNestedLoopJoinExec(
}
}

override def supportCodegen: Boolean = {
joinType.isInstanceOf[InnerLike]
override def supportCodegen: Boolean = (joinType, buildSide) match {
case (_: InnerLike, _) | (LeftSemi | LeftAnti, BuildRight) => true
case _ => false
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
Expand All @@ -410,29 +411,33 @@ case class BroadcastNestedLoopJoinExec(
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
joinType match {
case _: InnerLike => codegenInner(ctx, input)
(joinType, buildSide) match {
case (_: InnerLike, _) => codegenInner(ctx, input)
case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true)
case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false)
case _ =>
throw new IllegalArgumentException(
s"BroadcastNestedLoopJoin code-gen should not take $joinType as the JoinType")
s"BroadcastNestedLoopJoin code-gen should not take neither $joinType as the JoinType " +
s"nor $buildSide as the BuildSide")
}
}

/**
* Returns the variable name for [[Broadcast]] side.
* Returns a tuple of [[Broadcast]] side and the variable name for it.
*/
private def prepareBroadcast(ctx: CodegenContext): String = {
private def prepareBroadcast(ctx: CodegenContext): (Array[InternalRow], String) = {
// Create a name for broadcast side
val broadcastArray = broadcast.executeBroadcast[Array[InternalRow]]()
val broadcastTerm = ctx.addReferenceObj("broadcastTerm", broadcastArray)

// Inline mutable state since not many join operations in a task
ctx.addMutableState("InternalRow[]", "buildRowArray",
val arrayTerm = ctx.addMutableState("InternalRow[]", "buildRowArray",
v => s"$v = (InternalRow[]) $broadcastTerm.value();", forceInline = true)
(broadcastArray.value, arrayTerm)
}

private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val buildRowArrayTerm = prepareBroadcast(ctx)
val (_, buildRowArrayTerm) = prepareBroadcast(ctx)
val (buildRow, checkCondition, buildVars) = getJoinCondition(ctx, input, streamed, broadcast)

val resultVars = buildSide match {
Expand All @@ -452,4 +457,50 @@ case class BroadcastNestedLoopJoinExec(
|}
""".stripMargin
}

private def codegenLeftExistence(
ctx: CodegenContext,
input: Seq[ExprCode],
exists: Boolean): String = {
val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx)
val numOutput = metricTerm(ctx, "numOutputRows")

if (condition.isEmpty) {
if (buildRowArray.nonEmpty == exists) {
// Return streamed side if join condition is empty and
// 1. build side is non-empty for LeftSemi join
// or
// 2. build side is empty for LeftAnti join.
s"""
|$numOutput.add(1);
|${consume(ctx, input)}
""".stripMargin
} else {
// Return nothing if join condition is empty and
// 1. build side is empty for LeftSemi join
// or
// 2. build side is non-empty for LeftAnti join.
""
}
} else {
val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast)
val foundMatch = ctx.freshName("foundMatch")
val arrayIndex = ctx.freshName("arrayIndex")

s"""
|boolean $foundMatch = false;
|for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
| UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
| $checkCondition {
| $foundMatch = true;
| break;
| }
|}
|if ($foundMatch == $exists) {
| $numOutput.add(1);
| ${consume(ctx, input)}
|}
""".stripMargin
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,42 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
}
}

test("Left semi/anti BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
val df3 = spark.range(2).select($"id".as("k3"))

Seq(true, false).foreach { codegenEnabled =>
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {
// test left semi join
val semiJoinDF = df1.join(df2, $"k1" + 1 <= $"k2", "left_semi")
var hasJoinInCodegen = semiJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : BroadcastNestedLoopJoinExec)) => true
}.size === 1
assert(hasJoinInCodegen == codegenEnabled)
checkAnswer(semiJoinDF, Seq(Row(0), Row(1)))

// test left anti join
val antiJoinDF = df1.join(df2, $"k1" + 1 <= $"k2", "left_anti")
hasJoinInCodegen = antiJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : BroadcastNestedLoopJoinExec)) => true
}.size === 1
assert(hasJoinInCodegen == codegenEnabled)
checkAnswer(antiJoinDF, Seq(Row(2), Row(3)))

// test a combination of left semi and left anti joins
val twoJoinsDF = df1.join(df2, $"k1" < $"k2", "left_semi")
.join(df3, $"k1" > $"k3", "left_anti")
hasJoinInCodegen = twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, BroadcastNestedLoopJoinExec(
_: BroadcastNestedLoopJoinExec, _, _, _, _))) => true
}.size === 1
assert(hasJoinInCodegen == codegenEnabled)
checkAnswer(twoJoinsDF, Seq(Row(0)))
}
}
}

test("Sort should be included in WholeStageCodegen") {
val df = spark.range(3, 0, -1).toDF().sort(col("id"))
val plan = df.queryExecution.executedPlan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
}
}

test(s"$testName using BroadcastNestedLoopJoin build right") {
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin build right") { _ =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements.apply(
Expand Down

0 comments on commit f8838fe

Please sign in to comment.