Skip to content

Commit

Permalink
Fully remove non equi non pit condition (#10)
Browse files Browse the repository at this point in the history
* Remove `condition` from exec code

* Fail physical planning if there are non-equi conditions

* Add test

* Throw an exception
  • Loading branch information
Tom-Newton authored Jan 17, 2024
1 parent 832dbe8 commit 67a454e
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 69 deletions.
61 changes: 1 addition & 60 deletions scala/src/main/scala/execution/PITJoinExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -553,32 +553,6 @@ protected[pit] case class PITJoinExec(
(leftRow, matched)
}

/** Splits variables based on whether it's used by condition or not, returns
* the code to create these variables before the condition and after the
* condition.
*
* Only a few columns are used by condition, then we can skip the accessing
* of those columns that are not used by condition also filtered out by
* condition.
*/
private def splitVarsByCondition(
attributes: Seq[Attribute],
variables: Seq[ExprCode]
): (String, String) = {
if (condition.isDefined) {
val condRefs = condition.get.references
val (used, notUsed) =
attributes.zip(variables).partition { case (a, ev) =>
condRefs.contains(a)
}
val beforeCond = evaluateVariables(used.map(_._2))
val afterCond = evaluateVariables(notUsed.map(_._2))
(beforeCond, afterCond)
} else {
(evaluateVariables(variables), "")
}
}

override def needCopyResult: Boolean = true

override protected def doProduce(ctx: CodegenContext): String = {
Expand All @@ -605,38 +579,7 @@ protected[pit] case class PITJoinExec(

val numOutput = metricTerm(ctx, "numOutputRows")

val (beforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
val loaded = ctx.freshName("loaded")
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
val (rightBefore, rightAfter) =
splitVarsByCondition(right.output, rightVars)
// Generate code for condition
ctx.currentVars = leftVars ++ rightVars
val cond =
BindReferences.bindReference(condition.get, output).genCode(ctx)
// evaluate the columns those used by condition before loop
val before =
s"""
|boolean $loaded = false;
|$leftBefore
""".stripMargin

val checking =
s"""
|$rightBefore
|${cond.code}
|if (${cond.isNull}|| !${cond.value}) continue;
|if (!$loaded) {
| $loaded = true;
| $leftAfter
|}
|$rightAfter
""".stripMargin
(before, checking)
} else {
(evaluateVariables(leftVars), "")
}
val beforeLoop = evaluateVariables(leftVars)

val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();"
Expand All @@ -647,7 +590,6 @@ protected[pit] case class PITJoinExec(
| ${leftVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| InternalRow $rightRow = (InternalRow) $matched;
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, leftVars ++ rightVars)};
| if (shouldStop()) return;
Expand All @@ -659,7 +601,6 @@ protected[pit] case class PITJoinExec(
| ${leftVarDecl.mkString("\n")}
| ${beforeLoop.trim}
| InternalRow $rightRow = (InternalRow) $matched;
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, leftVars ++ rightVars)}
| if (shouldStop()) return;
Expand Down
13 changes: 5 additions & 8 deletions scala/src/main/scala/execution/Patterns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,10 @@ object PITJoinExtractEquality extends ExtractEqualityKeys {
// These need to be sortable in order to make the algorithm work optimized
val equiJoinKeys = getEquiJoinKeys(predicates, join.left, join.right)

val otherPredicates = predicates.filterNot {
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty =>
false
case Equality(l, r) =>
canEvaluate(l, join.left) && canEvaluate(r, join.right) ||
canEvaluate(l, join.right) && canEvaluate(r, join.left)
case _ => false
if (predicates.length != equiJoinKeys.length) {
throw new IllegalArgumentException(
"Besides the PIT key, only equi-conditions are supported for PIT joins"
)
}
val leftPitKey =
if (canEvaluate(join.pitCondition.children.head, join.left))
Expand All @@ -135,7 +132,7 @@ object PITJoinExtractEquality extends ExtractEqualityKeys {
rightPitKey,
leftEquiKeys,
rightEquiKeys,
otherPredicates.reduceOption(And),
None,
join.returnNulls,
join.tolerance,
join.left,
Expand Down
17 changes: 16 additions & 1 deletion scala/src/test/scala/EarlyStopMergeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class EarlyStopMergeTests extends AnyFlatSpec with SparkSessionTestWrapper {
smallData.fg1_with_key_nulls,
smallData.fg3_with_key_nulls,
smallData.PIT_1_3_WITH_KEY_NULLS,
// It could be argued the correct schema would be `smallData.PIT_2_schema`,
// It could be argued the correct schema would be `smallData.PIT_2_schema`,
// i.e. with join key columns made non-nullable. However, the normal spark inner
// join does not do this, so we don't either.
smallData.PIT_2_NULLABLE_KEYS_schema,
Expand Down Expand Up @@ -431,4 +431,19 @@ class EarlyStopMergeTests extends AnyFlatSpec with SparkSessionTestWrapper {
testBothCodegenAndInterpreted("left_join_three_dataframes") {
testJoiningThreeDataframes("left", smallData.PIT_3_OUTER_schema)
}

testBothCodegenAndInterpreted("fail_during_planning_for_non_equi_condition") {
val fg1 = smallData.fg1
val fg2 = smallData.fg2

val pitJoin =
fg1.join(
fg2,
pit(fg1("ts"), fg2("ts"), lit(0)) && fg1("id") === fg2("id") && fg1("value") > fg2("value"),
"inner"
)
intercept[IllegalArgumentException] {
pitJoin.explain()
}
}
}

0 comments on commit 67a454e

Please sign in to comment.