From bd5050b74dc3e6d436952498e4af41598bf7a0ff Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Wed, 17 Jan 2024 00:37:56 +0000 Subject: [PATCH 1/4] Remove `condition` from exec code --- .../main/scala/execution/PITJoinExec.scala | 61 +------------------ scala/src/main/scala/execution/Patterns.scala | 2 +- 2 files changed, 2 insertions(+), 61 deletions(-) diff --git a/scala/src/main/scala/execution/PITJoinExec.scala b/scala/src/main/scala/execution/PITJoinExec.scala index fc750f3..872e6cb 100644 --- a/scala/src/main/scala/execution/PITJoinExec.scala +++ b/scala/src/main/scala/execution/PITJoinExec.scala @@ -553,32 +553,6 @@ protected[pit] case class PITJoinExec( (leftRow, matched) } - /** Splits variables based on whether it's used by condition or not, returns - * the code to create these variables before the condition and after the - * condition. - * - * Only a few columns are used by condition, then we can skip the accessing - * of those columns that are not used by condition also filtered out by - * condition. - */ - private def splitVarsByCondition( - attributes: Seq[Attribute], - variables: Seq[ExprCode] - ): (String, String) = { - if (condition.isDefined) { - val condRefs = condition.get.references - val (used, notUsed) = - attributes.zip(variables).partition { case (a, ev) => - condRefs.contains(a) - } - val beforeCond = evaluateVariables(used.map(_._2)) - val afterCond = evaluateVariables(notUsed.map(_._2)) - (beforeCond, afterCond) - } else { - (evaluateVariables(variables), "") - } - } - override def needCopyResult: Boolean = true override protected def doProduce(ctx: CodegenContext): String = { @@ -605,38 +579,7 @@ protected[pit] case class PITJoinExec( val numOutput = metricTerm(ctx, "numOutputRows") - val (beforeLoop, condCheck) = if (condition.isDefined) { - // Split the code of creating variables based on whether it's used by condition or not. - val loaded = ctx.freshName("loaded") - val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) - val (rightBefore, rightAfter) = - splitVarsByCondition(right.output, rightVars) - // Generate code for condition - ctx.currentVars = leftVars ++ rightVars - val cond = - BindReferences.bindReference(condition.get, output).genCode(ctx) - // evaluate the columns those used by condition before loop - val before = - s""" - |boolean $loaded = false; - |$leftBefore - """.stripMargin - - val checking = - s""" - |$rightBefore - |${cond.code} - |if (${cond.isNull}|| !${cond.value}) continue; - |if (!$loaded) { - | $loaded = true; - | $leftAfter - |} - |$rightAfter - """.stripMargin - (before, checking) - } else { - (evaluateVariables(leftVars), "") - } + val beforeLoop = evaluateVariables(leftVars) val thisPlan = ctx.addReferenceObj("plan", this) val eagerCleanup = s"$thisPlan.cleanupResources();" @@ -647,7 +590,6 @@ protected[pit] case class PITJoinExec( | ${leftVarDecl.mkString("\n")} | ${beforeLoop.trim} | InternalRow $rightRow = (InternalRow) $matched; - | ${condCheck.trim} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)}; | if (shouldStop()) return; @@ -659,7 +601,6 @@ protected[pit] case class PITJoinExec( | ${leftVarDecl.mkString("\n")} | ${beforeLoop.trim} | InternalRow $rightRow = (InternalRow) $matched; - | ${condCheck.trim} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)} | if (shouldStop()) return; diff --git a/scala/src/main/scala/execution/Patterns.scala b/scala/src/main/scala/execution/Patterns.scala index 15cb092..42879db 100644 --- a/scala/src/main/scala/execution/Patterns.scala +++ b/scala/src/main/scala/execution/Patterns.scala @@ -135,7 +135,7 @@ object PITJoinExtractEquality extends ExtractEqualityKeys { rightPitKey, leftEquiKeys, rightEquiKeys, - otherPredicates.reduceOption(And), + None, join.returnNulls, join.tolerance, join.left, From 0a8322fe31bfa8bfcf539c7933a9d19c8d3a477e Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Wed, 17 Jan 2024 00:39:28 +0000 Subject: [PATCH 2/4] Fail physical planning if there are non-equi conditions --- scala/src/main/scala/execution/Patterns.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/scala/src/main/scala/execution/Patterns.scala b/scala/src/main/scala/execution/Patterns.scala index 42879db..462ec6d 100644 --- a/scala/src/main/scala/execution/Patterns.scala +++ b/scala/src/main/scala/execution/Patterns.scala @@ -108,13 +108,11 @@ object PITJoinExtractEquality extends ExtractEqualityKeys { // These need to be sortable in order to make the algorithm work optimized val equiJoinKeys = getEquiJoinKeys(predicates, join.left, join.right) - val otherPredicates = predicates.filterNot { - case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => - false - case Equality(l, r) => - canEvaluate(l, join.left) && canEvaluate(r, join.right) || - canEvaluate(l, join.right) && canEvaluate(r, join.left) - case _ => false + if (predicates.length != equiJoinKeys.length) { + logDebug( + s"Could not extract all equi-join keys from join condition: ${join.condition}" + ) + return None } val leftPitKey = if (canEvaluate(join.pitCondition.children.head, join.left)) From 0966a285f05bbe3c39c11d867b149d385ed005c5 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Wed, 17 Jan 2024 11:43:49 +0000 Subject: [PATCH 3/4] Add test --- scala/src/test/scala/EarlyStopMergeTests.scala | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/scala/src/test/scala/EarlyStopMergeTests.scala b/scala/src/test/scala/EarlyStopMergeTests.scala index 816023f..47fbc56 100644 --- a/scala/src/test/scala/EarlyStopMergeTests.scala +++ b/scala/src/test/scala/EarlyStopMergeTests.scala @@ -352,7 +352,7 @@ class EarlyStopMergeTests extends AnyFlatSpec with SparkSessionTestWrapper { smallData.fg1_with_key_nulls, smallData.fg3_with_key_nulls, smallData.PIT_1_3_WITH_KEY_NULLS, - // It could be argued the correct schema would be `smallData.PIT_2_schema`, + // It could be argued the correct schema would be `smallData.PIT_2_schema`, // i.e. with join key columns made non-nullable. However, the normal spark inner // join does not do this, so we don't either. smallData.PIT_2_NULLABLE_KEYS_schema, @@ -431,4 +431,19 @@ class EarlyStopMergeTests extends AnyFlatSpec with SparkSessionTestWrapper { testBothCodegenAndInterpreted("left_join_three_dataframes") { testJoiningThreeDataframes("left", smallData.PIT_3_OUTER_schema) } + + testBothCodegenAndInterpreted("fail_during_planning_for_non_equi_condition") { + val fg1 = smallData.fg1 + val fg2 = smallData.fg2 + + val pitJoin = + fg1.join( + fg2, + pit(fg1("ts"), fg2("ts"), lit(0)) && fg1("id") === fg2("id") && fg1("value") > fg2("value"), + "inner" + ) + intercept[IllegalArgumentException] { + pitJoin.explain() + } + } } From bf3c5367ed28a89d35ff7d1770dbea6ebae9c660 Mon Sep 17 00:00:00 2001 From: Thomas Newton Date: Wed, 17 Jan 2024 11:45:26 +0000 Subject: [PATCH 4/4] Throw an exception --- scala/src/main/scala/execution/Patterns.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scala/src/main/scala/execution/Patterns.scala b/scala/src/main/scala/execution/Patterns.scala index 462ec6d..9e62d2e 100644 --- a/scala/src/main/scala/execution/Patterns.scala +++ b/scala/src/main/scala/execution/Patterns.scala @@ -109,10 +109,9 @@ object PITJoinExtractEquality extends ExtractEqualityKeys { val equiJoinKeys = getEquiJoinKeys(predicates, join.left, join.right) if (predicates.length != equiJoinKeys.length) { - logDebug( - s"Could not extract all equi-join keys from join condition: ${join.condition}" + throw new IllegalArgumentException( + "Besides the PIT key, only equi-conditions are supported for PIT joins" ) - return None } val leftPitKey = if (canEvaluate(join.pitCondition.children.head, join.left))