diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index 70a1aabfdae..ee5ff460bb0 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -14,7 +14,7 @@ import pytest from _pytest.mark.structures import ParameterSet -from pyspark.sql.functions import broadcast, col +from pyspark.sql.functions import array_contains, broadcast, col from pyspark.sql.types import * from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture from conftest import is_databricks_runtime, is_emr_runtime @@ -397,6 +397,22 @@ def do_join(spark): return left.join(broadcast(right), left.a > f.log(right.r_a), join_type) assert_gpu_and_cpu_are_equal_collect(do_join) +@ignore_order(local=True) +@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn) +@pytest.mark.parametrize('join_type', ['Cross', 'Left', 'LeftSemi', 'LeftAnti'], ids=idfn) +def test_broadcast_nested_loop_join_with_condition(data_gen, join_type): + def do_join(spark): + left, right = create_df(spark, data_gen, 50, 25) + # AST does not support cast or logarithm yet which is supposed to be extracted into child + # nodes. And this test doesn't cover other join types due to: + # (1) build right are not supported for Right + # (2) FullOuter: currently is not supported + # Those fallback reasons are not due to AST. Additionally, this test case changes test_broadcast_nested_loop_join_with_condition_fallback: + # (1) adapt double to integer since AST current doesn't support it. + # (2) switch to right side build to pass checks of 'Left', 'LeftSemi', 'LeftAnti' join types + return left.join(broadcast(right), f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')), join_type) + assert_gpu_and_cpu_are_equal_collect(do_join, conf={"spark.rapids.sql.castFloatToIntegralTypes.enabled": True}) + @allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'Cast', 'GreaterThan', 'Log') @ignore_order(local=True) @pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn) @@ -404,10 +420,24 @@ def do_join(spark): def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type): def do_join(spark): left, right = create_df(spark, data_gen, 50, 25) - # AST does not support cast or logarithm yet + # AST does not support double type which is not split-able into child nodes. return broadcast(left).join(right, left.a > f.log(right.r_a), join_type) assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec') +@ignore_order(local=True) +@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen, + float_gen, double_gen, + string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn) +@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn) +def test_broadcast_nested_loop_join_with_array_contains(data_gen, join_type): + arr_gen = ArrayGen(data_gen) + literal = with_cpu_session(lambda spark: gen_scalar(data_gen)) + def do_join(spark): + left, right = create_df(spark, arr_gen, 50, 25) + # Array_contains will be pushed down into project child nodes + return broadcast(left).join(right, array_contains(left.a, literal.cast(data_gen.data_type)) < array_contains(right.r_a, literal.cast(data_gen.data_type))) + assert_gpu_and_cpu_are_equal_collect(do_join) + @ignore_order(local=True) @pytest.mark.parametrize('data_gen', all_gen, ids=idfn) @pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala new file mode 100644 index 00000000000..5062d8e4a99 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, Expression, ExprId, NamedExpression} +import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals} + + +object AstUtil { + + /** + * Check whether it can be split into non-ast sub-expression if needed + * + * @return true when: 1) If all ast-able in expr; 2) all non-ast-able tree nodes don't contain + * attributes from both join sides. In such case, it's not able + * to push down into single child. + */ + def canExtractNonAstConditionIfNeed(expr: BaseExprMeta[_], left: Seq[ExprId], + right: Seq[ExprId]): Boolean = { + if (!expr.canSelfBeAst) { + // It needs to be split since not ast-able. Check itself and childerns to ensure + // pushing-down can be made, which doesn't need attributions from both sides. + val exprRef = expr.wrapped.asInstanceOf[Expression] + val leftTree = exprRef.references.exists(r => left.contains(r.exprId)) + val rightTree = exprRef.references.exists(r => right.contains(r.exprId)) + // Can't extract a condition involving columns from both sides + !(rightTree && leftTree) + } else { + // Check whether any child contains the case not able to split + expr.childExprs.isEmpty || expr.childExprs.forall( + canExtractNonAstConditionIfNeed(_, left, right)) + } + } + + /** + * Extract non-AST functions from join conditions and update the original join condition. Based + * on the attributes, it decides which side the split condition belongs to. The replaced + * condition is wrapped with GpuAlias with new intermediate attributes. + * + * @param condition to be split if needed + * @param left attributions from left child + * @param right attributions from right child + * @param skipCheck whether skip split-able check + * @return a tuple of [[Expression]] for remained expressions, List of [[NamedExpression]] for + * left child if any, List of [[NamedExpression]] for right child if any + */ + def extractNonAstFromJoinCond(condition: Option[BaseExprMeta[_]], + left: AttributeSeq, right: AttributeSeq, skipCheck: Boolean): + (Option[Expression], List[NamedExpression], List[NamedExpression]) = { + // Choose side with smaller key size. Use expr ID to check the side which project expr + // belonging to. + val (exprIds, isLeft) = if (left.attrs.size < right.attrs.size) { + (left.attrs.map(_.exprId), true) + } else { + (right.attrs.map(_.exprId), false) + } + // List of expression pushing down to left side child + val leftExprs: ListBuffer[NamedExpression] = ListBuffer.empty + // List of expression pushing down to right side child + val rightExprs: ListBuffer[NamedExpression] = ListBuffer.empty + // Substitution map used to replace targeted expressions based on semantic equality + val substitutionMap = mutable.HashMap.empty[GpuExpressionEquals, Expression] + + // 1st step to construct 1) left expr list; 2) right expr list; 3) substitutionMap + // No need to consider common sub-expressions here since project node will use tiered execution + condition.foreach(c => + if (skipCheck || canExtractNonAstConditionIfNeed(c, left.attrs.map(_.exprId), right.attrs + .map(_.exprId))) { + splitNonAstInternal(c, exprIds, leftExprs, rightExprs, substitutionMap, isLeft) + }) + + // 2nd step to replace expression pushing down to child plans in depth first fashion + (condition.map( + _.convertToGpu().mapChildren( + GpuEquivalentExpressions.replaceWithSemanticCommonRef(_, + substitutionMap))), leftExprs.toList, rightExprs.toList) + } + + private[this] def splitNonAstInternal(condition: BaseExprMeta[_], childAtt: Seq[ExprId], + left: ListBuffer[NamedExpression], right: ListBuffer[NamedExpression], + substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression], isLeft: Boolean): Unit = { + for (child <- condition.childExprs) { + if (!child.canSelfBeAst) { + val exprRef = child.wrapped.asInstanceOf[Expression] + val gpuProj = child.convertToGpu() + val alias = substitutionMap.get(GpuExpressionEquals(gpuProj)) match { + case Some(_) => None + case None => + if (exprRef.references.exists(r => childAtt.contains(r.exprId)) ^ isLeft) { + val alias = GpuAlias(gpuProj, s"_agpu_non_ast_r_${left.size}")() + right += alias + Some(alias) + } else { + val alias = GpuAlias(gpuProj, s"_agpu_non_ast_l_${left.size}")() + left += alias + Some(alias) + } + } + alias.foreach(a => substitutionMap.put(GpuExpressionEquals(gpuProj), a.toAttribute)) + } else { + splitNonAstInternal(child, childAtt, left, right, substitutionMap, isLeft) + } + } + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index 445a99051b5..3d7c4a1ed67 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -1115,6 +1115,15 @@ abstract class BaseExprMeta[INPUT <: Expression]( childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty } + /** + * Check whether this node itself can be converted to AST. It will not recursively check its + * children. It's used to check join condition AST-ability in top-down fashion. + */ + lazy val canSelfBeAst = { + tagForAst() + cannotBeAstReasons.isEmpty + } + final def requireAstForGpu(): Unit = { tagForAst() cannotBeAstReasons.foreach { reason => diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuEquivalentExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuEquivalentExpressions.scala index 257c82eadd1..ed03da6af04 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuEquivalentExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuEquivalentExpressions.scala @@ -240,6 +240,22 @@ class GpuEquivalentExpressions { } object GpuEquivalentExpressions { + /** + * Recursively replaces semantic equal expression with its proxy expression in `substitutionMap`. + */ + def replaceWithSemanticCommonRef( + expr: Expression, + substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression]): Expression = { + expr match { + case e: AttributeReference => e + case _ => + substitutionMap.get(GpuExpressionEquals(expr)) match { + case Some(attr) => attr + case None => expr.mapChildren(replaceWithSemanticCommonRef(_, substitutionMap)) + } + } + } + /** * Recursively replaces expression with its proxy expression in `substitutionMap`. */ diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala index e20c84b2b88..c9a16003203 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExecBase.scala @@ -28,7 +28,7 @@ import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, IdentityBroadcastMode, UnspecifiedDistribution} import org.apache.spark.sql.execution.SparkPlan @@ -50,6 +50,28 @@ abstract class GpuBroadcastNestedLoopJoinMetaBase( val gpuBuildSide: GpuBuildSide = GpuJoinUtils.getGpuBuildSide(join.buildSide) + private var taggedForAstCheck = false + + // Avoid checking multiple times + private var isAstCond = false + + /** + * Check whether condition can be ast-able. It includes two cases: 1) all join conditions are + * ast-able; 2) join conditions are ast-able after split and push down to child plans. + */ + protected def canJoinCondAstAble(): Boolean = { + if (!taggedForAstCheck) { + val Seq(leftPlan, rightPlan) = childPlans + conditionMeta match { + case Some(e) => isAstCond = AstUtil.canExtractNonAstConditionIfNeed( + e, leftPlan.outputAttributes.map(_.exprId), rightPlan.outputAttributes.map(_.exprId)) + case None => isAstCond = true + } + taggedForAstCheck = true + } + isAstCond + } + override def namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] = JoinTypeChecks.nonEquiJoinMeta(conditionMeta) @@ -60,7 +82,9 @@ abstract class GpuBroadcastNestedLoopJoinMetaBase( join.joinType match { case _: InnerLike => case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - conditionMeta.foreach(requireAstForGpuOn) + // First to check whether can be split if not ast-able. If false, then check requireAst to + // send not-work-on-GPU reason if not replace-able. + conditionMeta.foreach(cond => if (!canJoinCondAstAble()) requireAstForGpuOn(cond)) case _ => willNotWorkOnGpu(s"${join.joinType} currently is not supported") } join.joinType match { @@ -334,7 +358,8 @@ object GpuBroadcastNestedLoopJoinExecBase { val streamBatch = streamSpillable.getBatch val existsCol: ColumnVector = if (builtBatch.numRows == 0) { withResource(Scalar.fromBool(false)) { falseScalar => - GpuColumnVector.from(cudf.ColumnVector.fromScalar(falseScalar, streamBatch.numRows), + GpuColumnVector.from( + cudf.ColumnVector.fromScalar(falseScalar, streamBatch.numRows), BooleanType) } } else { @@ -352,6 +377,21 @@ object GpuBroadcastNestedLoopJoinExecBase { } } + def output(joinType: JoinType, left: Seq[Attribute], right: Seq[Attribute]): Seq[Attribute] = { + joinType match { + case _: InnerLike => left ++ right + case LeftOuter => left ++ right.map(_.withNullability(true)) + case RightOuter => left.map(_.withNullability(true)) ++ right + case FullOuter => + left.map(_.withNullability(true)) ++ right.map(_.withNullability(true)) + case j: ExistenceJoin => left :+ j.exists + case LeftExistence(_) => left + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") + } + } + def divideIntoBatches( rowCounts: RDD[Long], targetSizeBytes: Long, @@ -383,12 +423,16 @@ object GpuBroadcastNestedLoopJoinExecBase { } } +// postBuildCondition is the post-broadcast project condition. It's used to re-construct a tiered +// project to handle pre-built batch. It will be removed after code refactor to decouple +// broadcast and nested loop join. abstract class GpuBroadcastNestedLoopJoinExecBase( left: SparkPlan, right: SparkPlan, joinType: JoinType, gpuBuildSide: GpuBuildSide, condition: Option[Expression], + postBuildCondition: List[NamedExpression], targetSizeBytes: Long) extends ShimBinaryExecNode with GpuExec { import GpuMetric._ @@ -411,7 +455,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase( case GpuBuildLeft => (right, left) } - def broadcastExchange: GpuBroadcastExchangeExecBase = buildPlan match { + def broadcastExchange: GpuBroadcastExchangeExecBase = getBroadcastPlan(buildPlan) match { case bqse: BroadcastQueryStageExec if bqse.plan.isInstanceOf[GpuBroadcastExchangeExecBase] => bqse.plan.asInstanceOf[GpuBroadcastExchangeExecBase] case bqse: BroadcastQueryStageExec if bqse.plan.isInstanceOf[ReusedExchangeExec] => @@ -420,6 +464,15 @@ abstract class GpuBroadcastNestedLoopJoinExecBase( case reused: ReusedExchangeExec => reused.child.asInstanceOf[GpuBroadcastExchangeExecBase] } + private[this] def getBroadcastPlan(plan: SparkPlan): SparkPlan = { + plan match { + // In case has post broadcast project. It happens when join condition contains non-AST + // expression which results in a project right after broadcast. + case plan: GpuProjectExec => plan.child + case _ => plan + } + } + override def requiredChildDistribution: Seq[Distribution] = gpuBuildSide match { case GpuBuildLeft => BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil @@ -428,23 +481,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase( } override def output: Seq[Attribute] = { - joinType match { - case _: InnerLike => - left.output ++ right.output - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case j: ExistenceJoin => - left.output :+ j.exists - case LeftExistence(_) => - left.output - case x => - throw new IllegalArgumentException( - s"BroadcastNestedLoopJoin should not take $x as the JoinType") - } + GpuBroadcastNestedLoopJoinExecBase.output(joinType, left.output, right.output) } protected def makeBroadcastBuiltBatch( @@ -468,7 +505,7 @@ abstract class GpuBroadcastNestedLoopJoinExecBase( } } - protected def makeBuiltBatch( + protected def makeBuiltBatchInternal( relation: Any, buildTime: GpuMetric, buildDataSize: GpuMetric): ColumnarBatch = { @@ -477,6 +514,24 @@ abstract class GpuBroadcastNestedLoopJoinExecBase( makeBroadcastBuiltBatch(broadcastRelation, buildTime, buildDataSize) } + final def makeBuiltBatch( + relation: Any, + buildTime: GpuMetric, + buildDataSize: GpuMetric): ColumnarBatch = { + buildPlan match { + case p: GpuProjectExec => + // Need to manually do project columnar execution other than calling child's + // internalDoExecuteColumnar. This is to workaround especial handle to build broadcast + // batch. + val proj = GpuBindReferences.bindGpuReferencesTiered( + postBuildCondition, p.child.output, true) + withResource(makeBuiltBatchInternal(relation, buildTime, buildDataSize)) { + cb => proj.project(cb) + } + case _ => makeBuiltBatchInternal(relation, buildTime, buildDataSize) + } + } + protected def computeBuildRowCount( relation: Any, buildTime: GpuMetric, diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala index b2140224b60..523f10f7028 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala @@ -36,7 +36,7 @@ package org.apache.spark.sql.rapids.execution import com.nvidia.spark.rapids._ -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec @@ -58,28 +58,63 @@ class GpuBroadcastNestedLoopJoinMeta( } verifyBuildSideWasReplaced(buildSide) - val condition = conditionMeta.map(_.convertToGpu()) - val isAstCondition = conditionMeta.forall(_.canThisBeAst) - join.joinType match { - case _: InnerLike => - case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => - throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") - case RightOuter if gpuBuildSide == GpuBuildRight => - throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") - case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - // Cannot post-filter these types of joins - assert(isAstCondition, s"Non-AST condition in ${join.joinType}") - case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}") - } + // If ast-able, try to split if needed. Otherwise, do post-filter + val isAstCondition = canJoinCondAstAble() + + if(isAstCondition){ + // Try to extract non-ast-able conditions from join conditions + val (remains, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond( + conditionMeta, left.output, right.output, true) + + // Reconstruct the childern with wrapped project node if needed. + val leftChild = + if (!leftExpr.isEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left + val rightChild = + if (!rightExpr.isEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right + val postBuildCondition = + if (gpuBuildSide == GpuBuildLeft) leftExpr ++ left.output else rightExpr ++ right.output - val joinExec = GpuBroadcastNestedLoopJoinExec( - left, right, - join.joinType, gpuBuildSide, - if (isAstCondition) condition else None, - conf.gpuTargetBatchSizeBytes) - if (isAstCondition) { - joinExec + // TODO: a code refactor is needed to skip passing in postBuildCondition as a parameter to + // instantiate GpuBroadcastNestedLoopJoinExec. This is because currently output columnar batch + // of broadcast side is handled inside GpuBroadcastNestedLoopJoinExec. Have to manually build + // a project node to build side batch. + val joinExec = GpuBroadcastNestedLoopJoinExec( + leftChild, rightChild, + join.joinType, gpuBuildSide, + remains, + postBuildCondition, + conf.gpuTargetBatchSizeBytes) + if (leftExpr.isEmpty && rightExpr.isEmpty) { + joinExec + } else { + // Remove the intermediate attributes from left and right side project nodes. Output + // attributes need to be updated based on types + GpuProjectExec( + GpuBroadcastNestedLoopJoinExecBase.output( + join.joinType, left.output, right.output).toList, + joinExec)(false) + } } else { + join.joinType match { + case _: InnerLike => + case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => + throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") + case RightOuter if gpuBuildSide == GpuBuildRight => + throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") + case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + // Cannot post-filter these types of joins + assert(isAstCondition, s"Non-AST condition in ${join.joinType}") + case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}") + } + val condition = conditionMeta.map(_.convertToGpu()) + + val joinExec = GpuBroadcastNestedLoopJoinExec( + left, right, + join.joinType, gpuBuildSide, + None, + List.empty, + conf.gpuTargetBatchSizeBytes) + // condition cannot be implemented via AST so fallback to a post-filter if necessary condition.map { // TODO: Restore batch coalescing logic here. @@ -94,13 +129,13 @@ class GpuBroadcastNestedLoopJoinMeta( } } - case class GpuBroadcastNestedLoopJoinExec( left: SparkPlan, right: SparkPlan, joinType: JoinType, gpuBuildSide: GpuBuildSide, condition: Option[Expression], + postBuildCondition: List[NamedExpression], targetSizeBytes: Long) extends GpuBroadcastNestedLoopJoinExecBase( - left, right, joinType, gpuBuildSide, condition, targetSizeBytes + left, right, joinType, gpuBuildSide, condition, postBuildCondition, targetSizeBytes ) diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala index 6c04a2aeb57..a5ffebd1aa6 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala @@ -28,7 +28,7 @@ import com.nvidia.spark.rapids.Arm.withResource import org.apache.spark.broadcast.Broadcast import org.apache.spark.rapids.shims.GpuShuffleExchangeExec import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.execution.{CoalescedPartitionSpec, SparkPlan} import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec @@ -53,29 +53,66 @@ class GpuBroadcastNestedLoopJoinMeta( } verifyBuildSideWasReplaced(buildSide) - val condition = conditionMeta.map(_.convertToGpu()) - val isAstCondition = conditionMeta.forall(_.canThisBeAst) - join.joinType match { - case _: InnerLike => - case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => - throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") - case RightOuter if gpuBuildSide == GpuBuildRight => - throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") - case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - // Cannot post-filter these types of joins - assert(isAstCondition, s"Non-AST condition in ${join.joinType}") - case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}") - } + // If ast-able, try to split if needed. Otherwise, do post-filter + val isAstCondition = canJoinCondAstAble() - val joinExec = GpuBroadcastNestedLoopJoinExec( - left, right, - join.joinType, gpuBuildSide, - if (isAstCondition) condition else None, - conf.gpuTargetBatchSizeBytes, - join.isExecutorBroadcast) if (isAstCondition) { - joinExec + // Try to extract non-ast-able conditions from join conditions + val (remains, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond(conditionMeta, + left.output, right.output, true) + + // Reconstruct the child with wrapped project node if needed. + val leftChild = + if (!leftExpr.isEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left + val rightChild = + if (!rightExpr.isEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right + val postBuildCondition = + if (gpuBuildSide == GpuBuildLeft) leftExpr ++ left.output else rightExpr ++ right.output + + // TODO: a code refactor is needed to skip passing in postBuildCondition as a parameter to + // instantiate GpuBroadcastNestedLoopJoinExec. This is because currently output columnar batch + // of broadcast side is handled inside GpuBroadcastNestedLoopJoinExec. Have to manually build + // a project node to build side batch. + val joinExec = GpuBroadcastNestedLoopJoinExec( + leftChild, rightChild, + join.joinType, gpuBuildSide, + remains, + postBuildCondition, + conf.gpuTargetBatchSizeBytes, + join.isExecutorBroadcast) + if (leftExpr.isEmpty && rightExpr.isEmpty) { + joinExec + } else { + // Remove the intermediate attributes from left and right side project nodes. Output + // attributes need to be updated based on types + GpuProjectExec( + GpuBroadcastNestedLoopJoinExecBase.output( + join.joinType, left.output, right.output).toList, + joinExec)(false) + } } else { + val condition = conditionMeta.map(_.convertToGpu()) + + join.joinType match { + case _: InnerLike => + case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => + throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") + case RightOuter if gpuBuildSide == GpuBuildRight => + throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") + case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + // Cannot post-filter these types of joins + assert(isAstCondition, s"Non-AST condition in ${join.joinType}") + case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}") + } + + val joinExec = GpuBroadcastNestedLoopJoinExec( + left, right, + join.joinType, gpuBuildSide, + None, + List.empty, + conf.gpuTargetBatchSizeBytes, + join.isExecutorBroadcast) + // condition cannot be implemented via AST so fallback to a post-filter if necessary condition.map { // TODO: Restore batch coalescing logic here. @@ -97,9 +134,10 @@ case class GpuBroadcastNestedLoopJoinExec( joinType: JoinType, gpuBuildSide: GpuBuildSide, condition: Option[Expression], + postBuildCondition: List[NamedExpression], targetSizeBytes: Long, executorBroadcast: Boolean) extends GpuBroadcastNestedLoopJoinExecBase( - left, right, joinType, gpuBuildSide, condition, targetSizeBytes + left, right, joinType, gpuBuildSide, condition, postBuildCondition, targetSizeBytes ) { import GpuMetric._ @@ -118,13 +156,31 @@ case class GpuBroadcastNestedLoopJoinExec( executorBroadcast } - def shuffleExchange: GpuShuffleExchangeExec = buildPlan match { - case bqse: ShuffleQueryStageExec if bqse.plan.isInstanceOf[GpuShuffleExchangeExec] => - bqse.plan.asInstanceOf[GpuShuffleExchangeExec] - case bqse: ShuffleQueryStageExec if bqse.plan.isInstanceOf[ReusedExchangeExec] => - bqse.plan.asInstanceOf[ReusedExchangeExec].child.asInstanceOf[GpuShuffleExchangeExec] - case gpu: GpuShuffleExchangeExec => gpu - case reused: ReusedExchangeExec => reused.child.asInstanceOf[GpuShuffleExchangeExec] + def shuffleExchange: GpuShuffleExchangeExec = { + def from(p: ShuffleQueryStageExec): GpuShuffleExchangeExec = p.plan match { + case g: GpuShuffleExchangeExec => g + case ReusedExchangeExec(_, g: GpuShuffleExchangeExec) => g + case _ => throw new IllegalStateException(s"cannot locate GPU shuffle in $p") + } + + getBroadcastPlan(buildPlan) match { + case gpu: GpuShuffleExchangeExec => gpu + case sqse: ShuffleQueryStageExec => from(sqse) + case reused: ReusedExchangeExec => reused.child.asInstanceOf[GpuShuffleExchangeExec] + case GpuShuffleCoalesceExec(GpuCustomShuffleReaderExec(sqse: ShuffleQueryStageExec, _), _) => + from(sqse) + case GpuShuffleCoalesceExec(sqse: ShuffleQueryStageExec, _) => from(sqse) + case GpuCustomShuffleReaderExec(sqse: ShuffleQueryStageExec, _) => from(sqse) + } + } + + private[this] def getBroadcastPlan(plan: SparkPlan): SparkPlan = { + plan match { + // In case has post broadcast project. It happens when join condition contains non-AST + // expression which results in a project right after broadcast. + case plan: GpuProjectExec => plan.child + case _ => plan + } } override def getBroadcastRelation(): Any = { @@ -149,8 +205,8 @@ case class GpuBroadcastNestedLoopJoinExec( val targetSize = RapidsConf.GPU_BATCH_SIZE_BYTES.get(conf) val metricsMap = allMetrics withResource(new NvtxWithMetrics("build join table", NvtxColor.GREEN, buildTime)) { _ => - val builtBatch = GpuExecutorBroadcastHelper.getExecutorBroadcastBatch(rdd, buildPlan.schema, - buildPlan.output, metricsMap, targetSize) + val builtBatch = GpuExecutorBroadcastHelper.getExecutorBroadcastBatch(rdd, getBroadcastPlan + (buildPlan).schema, getBroadcastPlan(buildPlan).output, metricsMap, targetSize) buildDataSize += GpuColumnVector.getTotalDeviceMemoryUsed(builtBatch) builtBatch } @@ -166,7 +222,7 @@ case class GpuBroadcastNestedLoopJoinExec( } } - override def makeBuiltBatch( + override def makeBuiltBatchInternal( relation: Any, buildTime: GpuMetric, buildDataSize: GpuMetric): ColumnarBatch = { diff --git a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala index 7499802e170..c2e5afdb5b9 100644 --- a/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala +++ b/sql-plugin/src/main/spark340/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala @@ -23,7 +23,7 @@ package org.apache.spark.sql.rapids.execution import com.nvidia.spark.rapids._ -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec @@ -45,28 +45,64 @@ class GpuBroadcastNestedLoopJoinMeta( } verifyBuildSideWasReplaced(buildSide) - val condition = conditionMeta.map(_.convertToGpu()) - val isAstCondition = conditionMeta.forall(_.canThisBeAst) - join.joinType match { - case _: InnerLike => - case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => - throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") - case RightOuter if gpuBuildSide == GpuBuildRight => - throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") - case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - // Cannot post-filter these types of joins - assert(isAstCondition, s"Non-AST condition in ${join.joinType}") - case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}") - } + // If ast-able, try to split if needed. Otherwise, do post-filter + val isAstCondition = canJoinCondAstAble() - val joinExec = GpuBroadcastNestedLoopJoinExec( - left, right, - join.joinType, gpuBuildSide, - if (isAstCondition) condition else None, - conf.gpuTargetBatchSizeBytes) if (isAstCondition) { - joinExec + // Try to extract non-ast-able conditions from join conditions + val (remains, leftExpr, rightExpr) = AstUtil.extractNonAstFromJoinCond(conditionMeta, + left.output, right.output, true) + + // Reconstruct the child with wrapped project node if needed. + val leftChild = + if (!leftExpr.isEmpty) GpuProjectExec(leftExpr ++ left.output, left)(true) else left + val rightChild = + if (!rightExpr.isEmpty) GpuProjectExec(rightExpr ++ right.output, right)(true) else right + val postBuildCondition = + if (gpuBuildSide == GpuBuildLeft) leftExpr ++ left.output else rightExpr ++ right.output + + // TODO: a code refactor is needed to skip passing in postBuildCondition as a parameter to + // instantiate GpuBroadcastNestedLoopJoinExec. This is because currently output columnar batch + // of broadcast side is handled inside GpuBroadcastNestedLoopJoinExec. Have to manually build + // a project node to build side batch. + val joinExec = GpuBroadcastNestedLoopJoinExec( + leftChild, rightChild, + join.joinType, gpuBuildSide, + remains, + postBuildCondition, + conf.gpuTargetBatchSizeBytes) + if (leftExpr.isEmpty && rightExpr.isEmpty) { + joinExec + } else { + // Remove the intermediate attributes from left and right side project nodes. Output + // attributes need to be updated based on types + GpuProjectExec( + GpuBroadcastNestedLoopJoinExecBase.output( + join.joinType, left.output, right.output).toList, + joinExec)(false) + } } else { + val condition = conditionMeta.map(_.convertToGpu()) + + join.joinType match { + case _: InnerLike => + case LeftOuter | LeftSemi | LeftAnti if gpuBuildSide == GpuBuildLeft => + throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") + case RightOuter if gpuBuildSide == GpuBuildRight => + throw new IllegalStateException(s"Unsupported build side for join type ${join.joinType}") + case LeftOuter | RightOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + // Cannot post-filter these types of joins + assert(isAstCondition, s"Non-AST condition in ${join.joinType}") + case _ => throw new IllegalStateException(s"Unsupported join type ${join.joinType}") + } + + val joinExec = GpuBroadcastNestedLoopJoinExec( + left, right, + join.joinType, gpuBuildSide, + None, + List.empty, + conf.gpuTargetBatchSizeBytes) + // condition cannot be implemented via AST so fallback to a post-filter if necessary condition.map { // TODO: Restore batch coalescing logic here. @@ -88,6 +124,7 @@ case class GpuBroadcastNestedLoopJoinExec( joinType: JoinType, gpuBuildSide: GpuBuildSide, condition: Option[Expression], + postBuildCondition: List[NamedExpression], targetSizeBytes: Long) extends GpuBroadcastNestedLoopJoinExecBase( - left, right, joinType, gpuBuildSide, condition, targetSizeBytes + left, right, joinType, gpuBuildSide, condition, postBuildCondition, targetSizeBytes ) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AstUtilSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AstUtilSuite.scala new file mode 100644 index 00000000000..52825910dda --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/AstUtilSuite.scala @@ -0,0 +1,254 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.mockito.Mockito.{mock, when} + +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression} +import org.apache.spark.sql.rapids.{GpuGreaterThan, GpuLength, GpuStringTrim} +import org.apache.spark.sql.types.StringType + + +class AstUtilSuite extends GpuUnitTests { + + private[this] def testSingleNode(containsNonAstAble: Boolean, crossMultiChildPlan: Boolean) + : Boolean = { + val l1 = AttributeReference("l1", StringType)() + val l2 = AttributeReference("l2", StringType)() + val r1 = AttributeReference("r1", StringType)() + val r2 = AttributeReference("r2", StringType)() + + val expr = mock(classOf[Expression]) + val attributeSet = if (crossMultiChildPlan) { + AttributeSet(Seq(l1, r1)) + } else { + AttributeSet(Seq(l1, l2)) + } + when(expr.references).thenReturn(attributeSet) + + val exprMeta = mock(classOf[BaseExprMeta[Expression]]) + when(exprMeta.childExprs).thenReturn(Seq.empty) + when(exprMeta.canSelfBeAst).thenReturn(!containsNonAstAble) + when(exprMeta.wrapped).thenReturn(expr) + + AstUtil.canExtractNonAstConditionIfNeed(exprMeta, Seq(l1, l2).map(_.exprId), Seq(r1, r2).map + (_.exprId)) + } + + private[this] def testMultiNodes(containsNonAstAble: Boolean, crossMultiChildPlan: Boolean) + : Boolean = { + val l1 = AttributeReference("l1", StringType)() + val l2 = AttributeReference("l2", StringType)() + val r1 = AttributeReference("r1", StringType)() + val r2 = AttributeReference("r2", StringType)() + + val attributeSet = if (crossMultiChildPlan) { + AttributeSet(Seq(l1, r1)) + } else { + AttributeSet(Seq(l1, l2)) + } + val leftExprMeta = buildLeaf(attributeSet, containsNonAstAble) + + val rightExprMeta = mock(classOf[BaseExprMeta[Expression]]) + when(rightExprMeta.childExprs).thenReturn(Seq.empty) + when(rightExprMeta.canSelfBeAst).thenReturn(true) + + val rootExprMeta = mock(classOf[BaseExprMeta[Expression]]) + when(rootExprMeta.childExprs).thenReturn(Seq(leftExprMeta, rightExprMeta)) + + when(rootExprMeta.canSelfBeAst).thenReturn(true) + + AstUtil.canExtractNonAstConditionIfNeed(rootExprMeta, Seq(l1, l2).map(_.exprId), Seq(r1, r2) + .map(_.exprId)) + } + + private[this] def buildLeaf(attributeSet: AttributeSet, containsNonAstAble: Boolean) + : BaseExprMeta[Expression] = { + val expr = mock(classOf[Expression]) + val exprMeta = mock(classOf[BaseExprMeta[Expression]]) + when(exprMeta.childExprs).thenReturn(Seq.empty) + when(exprMeta.canSelfBeAst).thenReturn(!containsNonAstAble) + + when(expr.references).thenReturn(attributeSet) + when(exprMeta.wrapped).thenReturn(expr) + exprMeta + } + + private[this] def testMultiNodes2(containsNonAstAble: Boolean, crossMultiChildPlan: Boolean) + : Boolean = { + val l1 = AttributeReference("l1", StringType)() + val l2 = AttributeReference("l2", StringType)() + val r1 = AttributeReference("r1", StringType)() + val r2 = AttributeReference("r2", StringType)() + + // Build left + val leftAttrSet = if (crossMultiChildPlan) { + AttributeSet(Seq(l1, r1)) + } else { + AttributeSet(Seq(l1, l2)) + } + val leftExprMeta = buildLeaf(leftAttrSet, containsNonAstAble) + + // Build right + val rightAttrSet = if (!crossMultiChildPlan) { + AttributeSet(Seq(l1, r1)) + } else { + AttributeSet(Seq(l1, l2)) + } + val rightExprMeta = buildLeaf(rightAttrSet, containsNonAstAble) + + // Build root + val rootExprMeta = mock(classOf[BaseExprMeta[Expression]]) + when(rootExprMeta.childExprs).thenReturn(Seq(leftExprMeta, rightExprMeta)) + when(rootExprMeta.canSelfBeAst).thenReturn(true) + + AstUtil.canExtractNonAstConditionIfNeed(rootExprMeta, Seq(l1, l2).map(_.exprId), Seq(r1, r2) + .map(_.exprId)) + } + + test("Single node tree for ast split if needed") { + for ((canAstSplitIfNeeded, containsNonAstAble, crossMultiChildPlan) <- Seq( + (false, true, true), (true, true, false), (true, false, true), (true, false, false))) { + assertResult( + canAstSplitIfNeeded)(testSingleNode(containsNonAstAble, crossMultiChildPlan)) + } + } + + test("Multi-nodes tree for ast split if needed") { + for ((canAstSplitIfNeeded, containsNonAstAble, crossMultiChildPlan) <- Seq( + (false, true, true), (true, true, false), (true, false, true), (true, false, false))) { + assertResult( + canAstSplitIfNeeded)(testMultiNodes(containsNonAstAble, crossMultiChildPlan)) + } + } + + test("Multi-nodes tree for ast split if needed complex case") { + for ((canAstSplitIfNeeded, containsNonAstAble, crossMultiChildPlan) <- Seq( + (false, true, true), (false, true, false), (true, false, true), (true, false, false))) { + assertResult( + canAstSplitIfNeeded)(testMultiNodes2(containsNonAstAble, crossMultiChildPlan)) + } + } + + // ======== test cases for AST split ======== + // Build a simple tree: string_trim(a:string). string_trim's AST-ability is controlled by + // astAble for different test purposes + private[this] def buildTree1(attSet: AttributeReference, astAble: Boolean) + : BaseExprMeta[Expression] = { + val expr = GpuStringTrim(attSet) + val rootMeta = mock(classOf[BaseExprMeta[Expression]]) + when(rootMeta.childExprs).thenReturn(Seq.empty) + when(rootMeta.canSelfBeAst).thenReturn(astAble) + when(rootMeta.convertToGpu).thenReturn(expr) + when(rootMeta.wrapped).thenReturn(expr) + rootMeta + } + + // Build a simple tree: length(string_trim(a:string)). string_length's AST-ability is + // controlled by astAble for different test purposes + private[this] def buildTree2(attSet: AttributeReference, astAble: Boolean) + : BaseExprMeta[Expression] = { + val expr = GpuLength(GpuStringTrim(attSet)) + val rootMeta = mock(classOf[BaseExprMeta[Expression]]) + val childExprs = Seq(buildTree1(attSet, astAble)) + when(rootMeta.childExprs).thenReturn(childExprs) + when(rootMeta.canSelfBeAst).thenReturn(astAble) + when(rootMeta.convertToGpu).thenReturn(expr) + when(rootMeta.wrapped).thenReturn(expr) + rootMeta + } + + // Build a complex tree: + // length(trim(a1:string)) > length(trim(a2:string)) + private[this] def buildTree3(attSet1: AttributeReference, attSet2: AttributeReference, + astAble: Boolean) + : BaseExprMeta[Expression] = { + val expr = GpuGreaterThan(GpuLength(GpuStringTrim(attSet1)), GpuLength(GpuStringTrim(attSet2))) + val rootMeta = mock(classOf[BaseExprMeta[Expression]]) + val childExprs = Seq(buildTree2(attSet1, astAble), buildTree2(attSet2, astAble)) + when(rootMeta.childExprs).thenReturn(childExprs) + when(rootMeta.canSelfBeAst).thenReturn(true) + when(rootMeta.convertToGpu).thenReturn(expr) + when(rootMeta.wrapped).thenReturn(expr) + rootMeta + } + + test("Non-Ast-able tree should not split"){ + val l1 = AttributeReference("l1", StringType)() + val l2 = AttributeReference("l2", StringType)() + val (e, l, r) = + AstUtil.extractNonAstFromJoinCond(Some(buildTree1(l1, false)), Seq(l1), Seq(l2), false) + assertResult(true)(l.isEmpty) + assertResult(true)(r.isEmpty) + assertResult(true)(e.get.isInstanceOf[GpuStringTrim]) + } + + test("Tree of single ast-able node should not split") { + val l1 = AttributeReference("l1", StringType)() + val l2 = AttributeReference("l2", StringType)() + val (e, l, r) = + AstUtil.extractNonAstFromJoinCond(Some(buildTree1(l1, true)), Seq(l1), Seq(l2), false) + assertResult(true)(l.isEmpty) + assertResult(true)(r.isEmpty) + assertResult(true)(e.get.isInstanceOf[GpuStringTrim]) + } + + test("Project pushing down to same child") { + val l1 = AttributeReference("l1", StringType)() + val l2 = AttributeReference("l2", StringType)() + val (e, l, r) = + AstUtil.extractNonAstFromJoinCond(Some(buildTree3(l1, l1, false)), Seq(l1), Seq(l2), false) + assertResult(true)(l.size == 1) + assertResult(true)(l.exists(checkEquals(_, GpuLength(GpuStringTrim(l1))))) + assertResult(true)(r.isEmpty) + assertResult(true)(l.exists(checkEquals(_, GpuLength(GpuStringTrim(l1))))) + assertResult(true)(checkEquals(e.get, GpuGreaterThan(l(0).toAttribute, l(0).toAttribute))) + } + + private def realExpr(expr: Expression): Expression = expr match { + case e: GpuAlias => e.child + case _ => expr + } + + private def checkEquals(expr: Expression, other: Expression): Boolean = { + realExpr(expr).semanticEquals(realExpr(other)) + } + + test("Project pushing down to different childern") { + val l1 = AttributeReference("l1", StringType)() + val l2 = AttributeReference("l2", StringType)() + val (e, l, r) = + AstUtil.extractNonAstFromJoinCond(Some(buildTree3(l1, l2, false)), Seq(l1), Seq(l2), false) + assertResult(true)(l.size == 1) + assertResult(true)(l.exists(checkEquals(_, GpuLength(GpuStringTrim(l1))))) + assertResult(true)(r.size == 1) + assertResult(true)(r.exists(checkEquals(_, GpuLength(GpuStringTrim(l2))))) + assertResult(true)( + checkEquals(e.get, GpuGreaterThan(l(0).toAttribute, r(0).toAttribute))) + } + + test("A tree with multiple ast-able childern should not split") { + val l1 = AttributeReference("l1", StringType)() + val l2 = AttributeReference("l2", StringType)() + val (e, l, r) = + AstUtil.extractNonAstFromJoinCond(Some(buildTree3(l1, l2, true)), Seq(l1), Seq(l2), false) + assertResult(true)(l.size == 0) + assertResult(true)(r.size == 0) + assertResult(true)(checkEquals(e.get, + GpuGreaterThan(GpuLength(GpuStringTrim(l1)), GpuLength(GpuStringTrim(l2))))) + } +}