Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support split BroadcastNestedLoopJoin condition for AST and non-AST [databricks] #9702

Merged
merged 4 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest
from _pytest.mark.structures import ParameterSet
from pyspark.sql.functions import broadcast, col
from pyspark.sql.functions import array_contains, broadcast, col
from pyspark.sql.types import *
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_cpu_and_gpu_are_equal_collect_with_capture
from conftest import is_databricks_runtime, is_emr_runtime
Expand Down Expand Up @@ -397,17 +397,47 @@ def do_join(spark):
return left.join(broadcast(right), left.a > f.log(right.r_a), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Cross', 'Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet which is supposed to be extracted into child
# nodes. And this test doesn't cover other join types due to:
# (1) build right are not supported for Right
# (2) FullOuter: currently is not supported
# Those fallback reasons are not due to AST. Additionally, this test case changes test_broadcast_nested_loop_join_with_condition_fallback:
# (1) adapt double to integer since AST current doesn't support it.
# (2) switch to right side build to pass checks of 'Left', 'LeftSemi', 'LeftAnti' join types
return left.join(broadcast(right), f.round(left.a).cast('integer') > f.round(f.log(right.r_a).cast('integer')), join_type)
assert_gpu_and_cpu_are_equal_collect(do_join, conf={"spark.rapids.sql.castFloatToIntegralTypes.enabled": True})

@allow_non_gpu('BroadcastExchangeExec', 'BroadcastNestedLoopJoinExec', 'Cast', 'GreaterThan', 'Log')
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [IntegerGen(), LongGen(), pytest.param(FloatGen(), marks=[incompat]), pytest.param(DoubleGen(), marks=[incompat])], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_condition_fallback(data_gen, join_type):
def do_join(spark):
left, right = create_df(spark, data_gen, 50, 25)
# AST does not support cast or logarithm yet
# AST does not support double type which is not split-able into child nodes.
return broadcast(left).join(right, left.a > f.log(right.r_a), join_type)
assert_gpu_fallback_collect(do_join, 'BroadcastNestedLoopJoinExec')

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [byte_gen, short_gen, int_gen, long_gen,
float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen], ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'Right', 'FullOuter', 'LeftSemi', 'LeftAnti'], ids=idfn)
def test_broadcast_nested_loop_join_with_array_contains(data_gen, join_type):
arr_gen = ArrayGen(data_gen)
literal = with_cpu_session(lambda spark: gen_scalar(data_gen))
def do_join(spark):
left, right = create_df(spark, arr_gen, 50, 25)
# Array_contains will be pushed down into project child nodes
return broadcast(left).join(right, array_contains(left.a, literal.cast(data_gen.data_type)) < array_contains(right.r_a, literal.cast(data_gen.data_type)))
assert_gpu_and_cpu_are_equal_collect(do_join)

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
@pytest.mark.parametrize('join_type', ['Left', 'LeftSemi', 'LeftAnti'], ids=idfn)
Expand Down
122 changes: 122 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.expressions.{AttributeSeq, Expression, ExprId, NamedExpression}
import org.apache.spark.sql.rapids.catalyst.expressions.{GpuEquivalentExpressions, GpuExpressionEquals}


object AstUtil {

/**
* Check whether it can be split into non-ast sub-expression if needed
*
* @return true when: 1) If all ast-able in expr; 2) all non-ast-able tree nodes don't contain
* attributes from both join sides. In such case, it's not able
* to push down into single child.
*/
def canExtractNonAstConditionIfNeed(expr: BaseExprMeta[_], left: Seq[ExprId],
right: Seq[ExprId]): Boolean = {
if (!expr.canSelfBeAst) {
// It needs to be split since not ast-able. Check itself and childerns to ensure
// pushing-down can be made, which doesn't need attributions from both sides.
val exprRef = expr.wrapped.asInstanceOf[Expression]
val leftTree = exprRef.references.exists(r => left.contains(r.exprId))
val rightTree = exprRef.references.exists(r => right.contains(r.exprId))
// Can't extract a condition involving columns from both sides
!(rightTree && leftTree)
} else {
// Check whether any child contains the case not able to split
expr.childExprs.isEmpty || expr.childExprs.forall(
canExtractNonAstConditionIfNeed(_, left, right))
}
}

/**
* Extract non-AST functions from join conditions and update the original join condition. Based
* on the attributes, it decides which side the split condition belongs to. The replaced
* condition is wrapped with GpuAlias with new intermediate attributes.
*
* @param condition to be split if needed
* @param left attributions from left child
* @param right attributions from right child
* @param skipCheck whether skip split-able check
* @return a tuple of [[Expression]] for remained expressions, List of [[NamedExpression]] for
* left child if any, List of [[NamedExpression]] for right child if any
*/
def extractNonAstFromJoinCond(condition: Option[BaseExprMeta[_]],
left: AttributeSeq, right: AttributeSeq, skipCheck: Boolean):
(Option[Expression], List[NamedExpression], List[NamedExpression]) = {
// Choose side with smaller key size. Use expr ID to check the side which project expr
// belonging to.
val (exprIds, isLeft) = if (left.attrs.size < right.attrs.size) {
(left.attrs.map(_.exprId), true)
} else {
(right.attrs.map(_.exprId), false)
}
// List of expression pushing down to left side child
val leftExprs: ListBuffer[NamedExpression] = ListBuffer.empty
// List of expression pushing down to right side child
val rightExprs: ListBuffer[NamedExpression] = ListBuffer.empty
// Substitution map used to replace targeted expressions based on semantic equality
val substitutionMap = mutable.HashMap.empty[GpuExpressionEquals, Expression]

// 1st step to construct 1) left expr list; 2) right expr list; 3) substitutionMap
// No need to consider common sub-expressions here since project node will use tiered execution
condition.foreach(c =>
if (skipCheck || canExtractNonAstConditionIfNeed(c, left.attrs.map(_.exprId), right.attrs
.map(_.exprId))) {
splitNonAstInternal(c, exprIds, leftExprs, rightExprs, substitutionMap, isLeft)
})

// 2nd step to replace expression pushing down to child plans in depth first fashion
(condition.map(
_.convertToGpu().mapChildren(
GpuEquivalentExpressions.replaceWithSemanticCommonRef(_,
substitutionMap))), leftExprs.toList, rightExprs.toList)
}

private[this] def splitNonAstInternal(condition: BaseExprMeta[_], childAtt: Seq[ExprId],
left: ListBuffer[NamedExpression], right: ListBuffer[NamedExpression],
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression], isLeft: Boolean): Unit = {
for (child <- condition.childExprs) {
if (!child.canSelfBeAst) {
val exprRef = child.wrapped.asInstanceOf[Expression]
val gpuProj = child.convertToGpu()
val alias = substitutionMap.get(GpuExpressionEquals(gpuProj)) match {
case Some(_) => None
case None =>
if (exprRef.references.exists(r => childAtt.contains(r.exprId)) ^ isLeft) {
val alias = GpuAlias(gpuProj, s"_agpu_non_ast_r_${left.size}")()
right += alias
Some(alias)
} else {
val alias = GpuAlias(gpuProj, s"_agpu_non_ast_l_${left.size}")()
left += alias
Some(alias)
}
}
alias.foreach(a => substitutionMap.put(GpuExpressionEquals(gpuProj), a.toAttribute))
} else {
splitNonAstInternal(child, childAtt, left, right, substitutionMap, isLeft)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,15 @@ abstract class BaseExprMeta[INPUT <: Expression](
childExprs.forall(_.canThisBeAst) && cannotBeAstReasons.isEmpty
}

/**
* Check whether this node itself can be converted to AST. It will not recursively check its
* children. It's used to check join condition AST-ability in top-down fashion.
*/
lazy val canSelfBeAst = {
tagForAst()
cannotBeAstReasons.isEmpty
}

final def requireAstForGpu(): Unit = {
tagForAst()
cannotBeAstReasons.foreach { reason =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,22 @@ class GpuEquivalentExpressions {
}

object GpuEquivalentExpressions {
/**
* Recursively replaces semantic equal expression with its proxy expression in `substitutionMap`.
*/
def replaceWithSemanticCommonRef(
expr: Expression,
substitutionMap: mutable.HashMap[GpuExpressionEquals, Expression]): Expression = {
expr match {
case e: AttributeReference => e
case _ =>
substitutionMap.get(GpuExpressionEquals(expr)) match {
case Some(attr) => attr
case None => expr.mapChildren(replaceWithSemanticCommonRef(_, substitutionMap))
}
}
}

/**
* Recursively replaces expression with its proxy expression in `substitutionMap`.
*/
Expand Down
Loading