From 1f4e4c812a9dc6d7e35631c1663c1ba6f6d9b721 Mon Sep 17 00:00:00 2001 From: Abhishek Somani Date: Wed, 23 Mar 2022 09:57:28 +0800 Subject: [PATCH] [SPARK-32268][SQL] Row-level Runtime Filtering ### What changes were proposed in this pull request? This PR proposes row-level runtime filters in Spark to reduce intermediate data volume for operators like shuffle, join and aggregate, and hence improve performance. We propose two mechanisms to do this: semi-join filters or bloom filters, and both mechanisms are proposed to co-exist side-by-side behind feature configs. [Design Doc](https://docs.google.com/document/d/16IEuyLeQlubQkH8YuVuXWKo2-grVIoDJqQpHZrE7q04/edit?usp=sharing) with more details. ### Why are the changes needed? With Semi-Join, we see 9 queries improve for the TPC DS 3TB benchmark, and no regressions. With Bloom Filter, we see 10 queries improve for the TPC DS 3TB benchmark, and no regressions. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests Closes #35789 from somani/rf. Lead-authored-by: Abhishek Somani Co-authored-by: Abhishek Somani Co-authored-by: Yuming Wang Signed-off-by: Wenchen Fan --- .../apache/spark/util/sketch/BloomFilter.java | 7 + .../spark/util/sketch/BloomFilterImpl.java | 5 + .../expressions/BloomFilterMightContain.scala | 113 ++++ .../aggregate/BloomFilterAggregate.scala | 179 +++++++ .../expressions/objects/objects.scala | 2 + .../sql/catalyst/expressions/predicates.scala | 16 + .../expressions/regexpExpressions.scala | 5 +- .../optimizer/InjectRuntimeFilter.scala | 303 +++++++++++ .../sql/catalyst/trees/TreePatterns.scala | 3 + .../apache/spark/sql/internal/SQLConf.scala | 80 +++ .../spark/sql/execution/SparkOptimizer.scala | 2 + .../dynamicpruning/PartitionPruning.scala | 15 - .../sql/BloomFilterAggregateQuerySuite.scala | 215 ++++++++ .../spark/sql/InjectRuntimeFilterSuite.scala | 503 ++++++++++++++++++ 14 files changed, 1432 insertions(+), 16 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java index c53987ecf6e25..2a6e270a91267 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java @@ -163,6 +163,13 @@ int getVersionNumber() { */ public abstract void writeTo(OutputStream out) throws IOException; + /** + * @return the number of set bits in this {@link BloomFilter}. + */ + public long cardinality() { + throw new UnsupportedOperationException("Not implemented"); + } + /** * Reads in a {@link BloomFilter} from an input stream. It is the caller's responsibility to close * the stream. 
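As background for the new `cardinality()` API above, here is a minimal sketch (not part of the patch) of how the existing `org.apache.spark.util.sketch.BloomFilter` sketch API is used; the sizes happen to match the expectedNumItems/numBits config defaults added later in this patch:

```scala
import org.apache.spark.util.sketch.BloomFilter

// Build a filter the same way BloomFilterAggregate does: create(expectedNumItems, numBits).
val bf = BloomFilter.create(1000000L, 8388608L)
(1L to 10000L).foreach(i => bf.putLong(i))

bf.mightContainLong(42L)  // true: 42 was inserted
bf.mightContainLong(-1L)  // usually false; "might contain" only admits false positives
bf.cardinality()          // new in this patch: number of set bits, 0 only if nothing was added
```

`cardinality()` is what lets `BloomFilterAggregate.eval` further down distinguish an empty filter (which evaluates to null) from a populated one.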
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java index e7766ee903480..ccf1833af9945 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java @@ -207,6 +207,11 @@ public BloomFilter intersectInPlace(BloomFilter other) throws IncompatibleMergeE return this; } + @Override + public long cardinality() { + return this.bits.cardinality(); + } + private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other) throws IncompatibleMergeException { // Duplicates the logic of `isCompatible` here to provide better error message. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala new file mode 100644 index 0000000000000..cf052f865ea90 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.io.ByteArrayInputStream + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper +import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.BloomFilter + +/** + * An internal scalar function that returns the membership check result (either true or false) + * for values of `valueExpression` in the Bloom filter represented by `bloomFilterExpression`. + * Note that since the function is "might contain", always returning true regardless is not + * wrong. + * Note that this expression requires that `bloomFilterExpression` is either a constant value or + * an uncorrelated scalar subquery. This is sufficient for the Bloom filter join rewrite. + * + * @param bloomFilterExpression the Binary data of Bloom filter. + * @param valueExpression the Long value to be tested for the membership of `bloomFilterExpression`.
+ */ +case class BloomFilterMightContain( + bloomFilterExpression: Expression, + valueExpression: Expression) extends BinaryExpression { + + override def nullable: Boolean = true + override def left: Expression = bloomFilterExpression + override def right: Expression = valueExpression + override def prettyName: String = "might_contain" + override def dataType: DataType = BooleanType + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) | + (BinaryType, LongType) => + bloomFilterExpression match { + case e : Expression if e.foldable => TypeCheckResult.TypeCheckSuccess + case subquery : PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " + + "should be either a constant value or a scalar subquery expression") + } + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been ${BinaryType.simpleString} followed by a value with ${LongType.simpleString}, " + + s"but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].") + } + } + + override protected def withNewChildrenInternal( + newBloomFilterExpression: Expression, + newValueExpression: Expression): BloomFilterMightContain = + copy(bloomFilterExpression = newBloomFilterExpression, + valueExpression = newValueExpression) + + // The bloom filter created from `bloomFilterExpression`. + @transient private lazy val bloomFilter = { + val bytes = bloomFilterExpression.eval().asInstanceOf[Array[Byte]] + if (bytes == null) null else deserialize(bytes) + } + + override def eval(input: InternalRow): Any = { + if (bloomFilter == null) { + null + } else { + val value = valueExpression.eval(input) + if (value == null) null else bloomFilter.mightContainLong(value.asInstanceOf[Long]) + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (bloomFilter == null) { + ev.copy(isNull = TrueLiteral, value = JavaCode.defaultLiteral(dataType)) + } else { + val bf = ctx.addReferenceObj("bloomFilter", bloomFilter, classOf[BloomFilter].getName) + val valueEval = valueExpression.genCode(ctx) + ev.copy(code = code""" + ${valueEval.code} + boolean ${ev.isNull} = ${valueEval.isNull}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $bf.mightContainLong((Long)${valueEval.value}); + }""") + } + } + + final def deserialize(bytes: Array[Byte]): BloomFilter = { + val in = new ByteArrayInputStream(bytes) + val bloomFilter = BloomFilter.readFrom(in) + in.close() + bloomFilter + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala new file mode 100644 index 0000000000000..c734bca3ef8d0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
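To make the eval semantics of `BloomFilterMightContain` above concrete, a small sketch (not part of the patch) that hand-builds a foldable binary literal, one of the two accepted input forms:

```scala
import java.io.ByteArrayOutputStream

import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Literal}
import org.apache.spark.sql.types.{BinaryType, LongType}
import org.apache.spark.util.sketch.BloomFilter

// Serialize a small Bloom filter into the binary form the expression expects.
val bf = BloomFilter.create(100L, 800L)
Seq(1L, 2L, 3L).foreach(v => bf.putLong(v))
val out = new ByteArrayOutputStream()
bf.writeTo(out)
val bytes = out.toByteArray

// A foldable binary first argument passes checkInputDataTypes, so eval() can run directly.
BloomFilterMightContain(Literal(bytes), Literal(2L)).eval()             // true
BloomFilterMightContain(Literal(bytes), Literal(999L)).eval()           // most likely false
BloomFilterMightContain(Literal(null, BinaryType), Literal(2L)).eval()  // null (null filter)
BloomFilterMightContain(Literal(bytes), Literal(null, LongType)).eval() // null (null value)
```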
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.sketch.BloomFilter + +/** + * An internal aggregate function that creates a Bloom filter from input values. + * + * @param child Child expression of Long values for creating a Bloom filter. + * @param estimatedNumItemsExpression The number of estimated distinct items (optional). + * @param numBitsExpression The number of bits to use (optional). + */ +case class BloomFilterAggregate( + child: Expression, + estimatedNumItemsExpression: Expression, + numBitsExpression: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[BloomFilter] with TernaryLike[Expression] { + + def this(child: Expression, estimatedNumItemsExpression: Expression, + numBitsExpression: Expression) = { + this(child, estimatedNumItemsExpression, numBitsExpression, 0, 0) + } + + def this(child: Expression, estimatedNumItemsExpression: Expression) = { + this(child, estimatedNumItemsExpression, + // 1 byte per item. 
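Editorial aside: "1 byte per item" means numBits = 8 × estimatedNumItems (the Multiply just below). With the textbook Bloom filter formulas (not Spark code) that corresponds to roughly a 2% false-positive rate, a sketch:

```scala
// When the creation side row count n is known, this constructor path sets numBits = 8 * n
// ("1 byte per item"); the value is later capped by
// spark.sql.optimizer.runtime.bloomFilter.maxNumBits (default 67108864).
val n = 1000000L        // example row count
val m = n * 8L          // bits

// Textbook estimates: k hash functions and false-positive probability p.
val k = math.round((m.toDouble / n) * math.log(2)).toInt      // 6
val p = math.pow(1 - math.exp(-k.toDouble * n / m), k)        // ~= 0.022
```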
+ Multiply(estimatedNumItemsExpression, Literal(8L))) + } + + def this(child: Expression) = { + this(child, Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS)), + Literal(SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_NUM_BITS))) + } + + override def checkInputDataTypes(): TypeCheckResult = { + (first.dataType, second.dataType, third.dataType) match { + case (_, NullType, _) | (_, _, NullType) => + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments") + case (LongType, LongType, LongType) => + if (!estimatedNumItemsExpression.foldable) { + TypeCheckFailure("The estimated number of items provided must be a constant literal") + } else if (estimatedNumItems <= 0L) { + TypeCheckFailure("The estimated number of items must be a positive value " + + s" (current value = $estimatedNumItems)") + } else if (!numBitsExpression.foldable) { + TypeCheckFailure("The number of bits provided must be a constant literal") + } else if (numBits <= 0L) { + TypeCheckFailure("The number of bits must be a positive value " + + s" (current value = $numBits)") + } else { + require(estimatedNumItems <= + SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS)) + require(numBits <= SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) + TypeCheckSuccess + } + case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + + s"been a ${LongType.simpleString} value followed with two ${LongType.simpleString} size " + + s"arguments, but it's [${first.dataType.catalogString}, " + + s"${second.dataType.catalogString}, ${third.dataType.catalogString}]") + } + } + override def nullable: Boolean = true + + override def dataType: DataType = BinaryType + + override def prettyName: String = "bloom_filter_agg" + + // Mark as lazy so that `estimatedNumItems` is not evaluated during tree transformation. + private lazy val estimatedNumItems: Long = + Math.min(estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue, + SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS)) + + // Mark as lazy so that `numBits` is not evaluated during tree transformation. + private lazy val numBits: Long = + Math.min(numBitsExpression.eval().asInstanceOf[Number].longValue, + SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) + + override def first: Expression = child + + override def second: Expression = estimatedNumItemsExpression + + override def third: Expression = numBitsExpression + + override protected def withNewChildrenInternal( + newChild: Expression, + newEstimatedNumItemsExpression: Expression, + newNumBitsExpression: Expression): BloomFilterAggregate = { + copy(child = newChild, estimatedNumItemsExpression = newEstimatedNumItemsExpression, + numBitsExpression = newNumBitsExpression) + } + + override def createAggregationBuffer(): BloomFilter = { + BloomFilter.create(estimatedNumItems, numBits) + } + + override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = { + val value = child.eval(inputRow) + // Ignore null values. + if (value == null) { + return buffer + } + buffer.putLong(value.asInstanceOf[Long]) + buffer + } + + override def merge(buffer: BloomFilter, other: BloomFilter): BloomFilter = { + buffer.mergeInPlace(other) + } + + override def eval(buffer: BloomFilter): Any = { + if (buffer.cardinality() == 0) { + // There's no set bit in the Bloom filter and hence no not-null value is processed. 
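As an aside (not part of the patch), the TypedImperativeAggregate lifecycle implemented above can be driven by hand; a sketch assuming the SQLConf defaults for sizing:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.types.LongType

// Aggregate over a single LongType input column bound at ordinal 0.
val agg = new BloomFilterAggregate(BoundReference(0, LongType, nullable = true))

var buffer = agg.createAggregationBuffer()            // BloomFilter.create(estimatedNumItems, numBits)
Seq(1L, 2L, 3L).foreach(v => buffer = agg.update(buffer, InternalRow(v)))
buffer = agg.update(buffer, InternalRow(null: Any))   // null inputs are ignored

agg.eval(buffer)                                      // serialized filter bytes; null for an empty buffer
```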
+ return null + } + serialize(buffer) + } + + override def withNewMutableAggBufferOffset(newOffset: Int): BloomFilterAggregate = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): BloomFilterAggregate = + copy(inputAggBufferOffset = newOffset) + + override def serialize(obj: BloomFilter): Array[Byte] = { + BloomFilterAggregate.serialize(obj) + } + + override def deserialize(bytes: Array[Byte]): BloomFilter = { + BloomFilterAggregate.deserialize(bytes) + } +} + +object BloomFilterAggregate { + final def serialize(obj: BloomFilter): Array[Byte] = { + // BloomFilterImpl.writeTo() writes 2 integers (version number and num hash functions), hence + // the +8 + val size = (obj.bitSize() / 8) + 8 + require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size") + val out = new ByteArrayOutputStream(size.intValue()) + obj.writeTo(out) + out.close() + out.toByteArray + } + + final def deserialize(bytes: Array[Byte]): BloomFilter = { + val in = new ByteArrayInputStream(bytes) + val bloomFilter = BloomFilter.readFrom(in) + in.close() + bloomFilter + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6974ada8735c3..2c879beeed623 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -360,6 +360,8 @@ case class Invoke( lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) + final override val nodePatterns: Seq[TreePattern] = Seq(INVOKE) + override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable override def children: Seq[Expression] = targetObject +: arguments override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a2fd668f495e0..d16e09c5ed95c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -287,6 +287,22 @@ trait PredicateHelper extends AliasHelper with Logging { } } } + + /** + * Returns whether an expression is likely to be selective + */ + def isLikelySelective(e: Expression): Boolean = e match { + case Not(expr) => isLikelySelective(expr) + case And(l, r) => isLikelySelective(l) || isLikelySelective(r) + case Or(l, r) => isLikelySelective(l) && isLikelySelective(r) + case _: StringRegexExpression => true + case _: BinaryComparison => true + case _: In | _: InSet => true + case _: StringPredicate => true + case BinaryPredicate(_) => true + case _: MultiLikeBase => true + case _ => false + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 368cbfd6be641..bfaaba514462f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -31,7 +31,7 @@ import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -627,6 +627,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio @transient private var lastReplacementInUTF8: UTF8String = _ // result buffer write by Matcher @transient private lazy val result: StringBuffer = new StringBuffer + final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_REPLACE) override def nullSafeEval(s: Any, p: Any, r: Any, i: Any): Any = { if (!p.equals(lastRegex)) { @@ -751,6 +752,8 @@ abstract class RegExpExtractBase // last regex pattern, we cache it for performance concern @transient private var pattern: Pattern = _ + final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) override def first: Expression = subject override def second: Expression = regexp diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala new file mode 100644 index 0000000000000..35d0189f64651 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, Complete} +import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * Insert a filter on one side of the join if the other side has a selective predicate. 
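"Selective predicate" here is decided by the `isLikelySelective` helper that this patch moves into `PredicateHelper` above. A quick sketch of that heuristic (assuming the catalyst expression DSL; not part of the patch):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{IsNull, Literal, Not, PredicateHelper}

object Check extends PredicateHelper
val a = Symbol("a").int
val b = Symbol("b").int

Check.isLikelySelective(a === 1)                       // true: BinaryComparison
Check.isLikelySelective(a.in(Literal(1), Literal(2)))  // true: In
Check.isLikelySelective(Not(a === 1))                  // true: Not over a selective child
Check.isLikelySelective(a === 1 || IsNull(b))          // false: Or needs both sides selective
Check.isLikelySelective(IsNull(a))                     // false: not in the selective set
```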
+ * The filter could be an IN subquery (converted to a semi join), a bloom filter, or something + * else in the future. + */ +object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper { + + // Wraps `expr` with a hash function if its byte size is larger than an integer. + private def mayWrapWithHash(expr: Expression): Expression = { + if (expr.dataType.defaultSize > IntegerType.defaultSize) { + new Murmur3Hash(Seq(expr)) + } else { + expr + } + } + + private def injectFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(conf.runtimeFilterBloomFilterEnabled || conf.runtimeFilterSemiJoinReductionEnabled) + if (conf.runtimeFilterBloomFilterEnabled) { + injectBloomFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } else { + injectInSubqueryFilter( + filterApplicationSideExp, + filterApplicationSidePlan, + filterCreationSideExp, + filterCreationSidePlan + ) + } + } + + private def injectBloomFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + // Skip if the filter creation side is too big + if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) { + return filterApplicationSidePlan + } + val rowCount = filterCreationSidePlan.stats.rowCount + val bloomFilterAgg = + if (rowCount.isDefined && rowCount.get.longValue > 0L) { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp)), + Literal(rowCount.get.longValue)) + } else { + new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideExp))) + } + val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None) + val alias = Alias(aggExp, "bloomFilter")() + val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan)) + val bloomFilterSubquery = ScalarSubquery(aggregate, Nil) + val filter = BloomFilterMightContain(bloomFilterSubquery, + new XxHash64(Seq(filterApplicationSideExp))) + Filter(filter, filterApplicationSidePlan) + } + + private def injectInSubqueryFilter( + filterApplicationSideExp: Expression, + filterApplicationSidePlan: LogicalPlan, + filterCreationSideExp: Expression, + filterCreationSidePlan: LogicalPlan): LogicalPlan = { + require(filterApplicationSideExp.dataType == filterCreationSideExp.dataType) + val actualFilterKeyExpr = mayWrapWithHash(filterCreationSideExp) + val alias = Alias(actualFilterKeyExpr, actualFilterKeyExpr.toString)() + val aggregate = Aggregate(Seq(alias), Seq(alias), filterCreationSidePlan) + if (!canBroadcastBySize(aggregate, conf)) { + // Skip the InSubquery filter if the size of `aggregate` is beyond broadcast join threshold, + // i.e., the semi-join will be a shuffled join, which is not worthwhile. 
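End to end, enabling the Bloom filter path and inspecting a plan shows what injectBloomFilter above produces; a hedged sketch (the `fact`/`dim` tables and column names are made up):

```scala
// Assumes an active SparkSession with two saved tables: a large `fact` and a small `dim`.
spark.conf.set("spark.sql.optimizer.runtime.bloomFilter.enabled", "true")

val df = spark.sql(
  "SELECT * FROM fact JOIN dim ON fact.k = dim.k WHERE dim.d = 2000")

// When the heuristics of this rule are satisfied, the optimized plan gains a Filter on the
// fact side of the form
//   BloomFilterMightContain(scalar-subquery[bloom_filter_agg(xxhash64(dim.k))], xxhash64(fact.k))
println(df.queryExecution.optimizedPlan.numberedTreeString)
```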
+ return filterApplicationSidePlan + } + val filter = InSubquery(Seq(mayWrapWithHash(filterApplicationSideExp)), + ListQuery(aggregate, childOutputs = aggregate.output)) + Filter(filter, filterApplicationSidePlan) + } + + /** + * Returns whether the plan is a simple filter over scan and the filter is likely selective + * Also check if the plan only has simple expressions (attribute reference, literals) so that we + * do not add a subquery that might have an expensive computation + */ + private def isSelectiveFilterOverScan(plan: LogicalPlan): Boolean = { + val ret = plan match { + case PhysicalOperation(_, filters, child) if child.isInstanceOf[LeafNode] => + filters.forall(isSimpleExpression) && + filters.exists(isLikelySelective) + case _ => false + } + !plan.isStreaming && ret + } + + private def isSimpleExpression(e: Expression): Boolean = { + !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, + REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE) + } + + private def canFilterLeft(joinType: JoinType): Boolean = joinType match { + case Inner | RightOuter => true + case _ => false + } + + private def canFilterRight(joinType: JoinType): Boolean = joinType match { + case Inner | LeftOuter => true + case _ => false + } + + private def isProbablyShuffleJoin(left: LogicalPlan, + right: LogicalPlan, hint: JoinHint): Boolean = { + !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) && + !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf) + } + + private def probablyHasShuffle(plan: LogicalPlan): Boolean = { + plan.collectFirst { + case j@Join(left, right, _, _, hint) + if isProbablyShuffleJoin(left, right, hint) => j + case a: Aggregate => a + }.nonEmpty + } + + // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`. + private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = { + val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES) + filterApplicationSide.collect({ + case leaf: LeafNode => leaf + }).map(scan => { + // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid + // creating a Bloom filter when the filter application side is very small, so using 0 + // as the byte size when the actual size is unknown can avoid regression by applying BF + // on a small table. + if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes + }).max + } + + // Returns true if `filterApplicationSide` satisfies the byte size requirement to apply a + // Bloom filter; false otherwise. + private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = { + // In case `filterApplicationSide` is a union of many small tables, disseminating the Bloom + // filter to each small task might be more costly than scanning them itself. Thus, we use max + // rather than sum here. 
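The size heuristics in this rule compare plan statistics against configs added later in this patch; for reference, a sketch of the relevant knobs shown with their default values:

```scala
// The creation side must be small enough that building the filter is cheap ...
spark.conf.set("spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold", "10MB")
// ... and the largest scan on the application side must be big enough for the filter to pay off.
spark.conf.set("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizethreshold", "10GB")
// At most this many runtime filters (semi-join or Bloom filter) are injected per query.
spark.conf.set("spark.sql.optimizer.runtimeFilter.number.threshold", "10")
```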
+ val maxScanSize = maxScanByteSize(filterApplicationSide) + maxScanSize >= + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD) + } + + /** + * Check that: + * - The filterApplicationSideJoinExp can be pushed down through joins and aggregates (ie the + * expression references originate from a single leaf node) + * - The filter creation side has a selective predicate + * - The current join is a shuffle join or a broadcast join that has a shuffle below it + * - The max filterApplicationSide scan size is greater than a configurable threshold + */ + private def filteringHasBenefit( + filterApplicationSide: LogicalPlan, + filterCreationSide: LogicalPlan, + filterApplicationSideExp: Expression, + hint: JoinHint): Boolean = { + findExpressionAndTrackLineageDown(filterApplicationSideExp, + filterApplicationSide).isDefined && isSelectiveFilterOverScan(filterCreationSide) && + (isProbablyShuffleJoin(filterApplicationSide, filterCreationSide, hint) || + probablyHasShuffle(filterApplicationSide)) && + satisfyByteSizeRequirement(filterApplicationSide) + } + + def hasRuntimeFilter(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + if (conf.runtimeFilterBloomFilterEnabled) { + hasBloomFilter(left, right, leftKey, rightKey) + } else { + hasInSubquery(left, right, leftKey, rightKey) + } + } + + // This checks if there is already a DPP filter, as this rule is called just after DPP. + def hasDynamicPruningSubquery( + left: LogicalPlan, + right: LogicalPlan, + leftKey: Expression, + rightKey: Expression): Boolean = { + (left, right) match { + case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan), _) => + pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey) + case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _), plan)) => + pruningKey.fastEquals(rightKey) || + hasDynamicPruningSubquery(left, plan, leftKey, rightKey) + case _ => false + } + } + + def hasBloomFilter( + left: LogicalPlan, + right: LogicalPlan, + leftKey: Expression, + rightKey: Expression): Boolean = { + findBloomFilterWithExp(left, leftKey) || findBloomFilterWithExp(right, rightKey) + } + + private def findBloomFilterWithExp(plan: LogicalPlan, key: Expression): Boolean = { + plan.find { + case Filter(condition, _) => + splitConjunctivePredicates(condition).exists { + case BloomFilterMightContain(_, XxHash64(Seq(valueExpression), _)) + if valueExpression.fastEquals(key) => true + case _ => false + } + case _ => false + }.isDefined + } + + def hasInSubquery(left: LogicalPlan, right: LogicalPlan, leftKey: Expression, + rightKey: Expression): Boolean = { + (left, right) match { + case (Filter(InSubquery(Seq(key), + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _), _) => + key.fastEquals(leftKey) || key.fastEquals(new Murmur3Hash(Seq(leftKey))) + case (_, Filter(InSubquery(Seq(key), + ListQuery(Aggregate(Seq(Alias(_, _)), Seq(Alias(_, _)), _), _, _, _, _)), _)) => + key.fastEquals(rightKey) || key.fastEquals(new Murmur3Hash(Seq(rightKey))) + case _ => false + } + } + + private def tryInjectRuntimeFilter(plan: LogicalPlan): LogicalPlan = { + var filterCounter = 0 + val numFilterThreshold = conf.getConf(SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD) + plan transformUp { + case join @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, hint) => + var newLeft = left + var newRight = right + (leftKeys, rightKeys).zipped.foreach((l, r) => { + // Check if: + // 1. 
There is already a DPP filter on the key + // 2. There is already a runtime filter (Bloom filter or IN subquery) on the key + // 3. The keys are simple cheap expressions + if (filterCounter < numFilterThreshold && + !hasDynamicPruningSubquery(left, right, l, r) && + !hasRuntimeFilter(newLeft, newRight, l, r) && + isSimpleExpression(l) && isSimpleExpression(r)) { + val oldLeft = newLeft + val oldRight = newRight + if (canFilterLeft(joinType) && filteringHasBenefit(left, right, l, hint)) { + newLeft = injectFilter(l, newLeft, r, right) + } + // Did we actually inject on the left? If not, try on the right + if (newLeft.fastEquals(oldLeft) && canFilterRight(joinType) && + filteringHasBenefit(right, left, r, hint)) { + newRight = injectFilter(r, newRight, l, left) + } + if (!newLeft.fastEquals(oldLeft) || !newRight.fastEquals(oldRight)) { + filterCounter = filterCounter + 1 + } + } + }) + join.withNewChildren(Seq(newLeft, newRight)) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case s: Subquery if s.correlated => plan + case _ if !conf.runtimeFilterSemiJoinReductionEnabled && + !conf.runtimeFilterBloomFilterEnabled => plan + case _ => tryInjectRuntimeFilter(plan) + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index b595966bcc235..3cf45d5f79f00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -54,6 +54,7 @@ object TreePattern extends Enumeration { val IN_SUBQUERY: Value = Value val INSET: Value = Value val INTERSECT: Value = Value + val INVOKE: Value = Value val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value @@ -72,6 +73,8 @@ object TreePattern extends Enumeration { val PIVOT: Value = Value val PLAN_EXPRESSION: Value = Value val PYTHON_UDF: Value = Value + val REGEXP_EXTRACT_FAMILY: Value = Value + val REGEXP_REPLACE: Value = Value val RUNTIME_REPLACEABLE: Value = Value val SCALAR_SUBQUERY: Value = Value val SCALA_UDF: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3314dd1916498..1bba8b6d866a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -341,6 +341,77 @@ object SQLConf { .booleanConf .createWithDefault(true) + val RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED = + buildConf("spark.sql.optimizer.runtimeFilter.semiJoinReduction.enabled") + .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + + "to insert a semi join in the other side to reduce the amount of shuffle data.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val RUNTIME_FILTER_NUMBER_THRESHOLD = + buildConf("spark.sql.optimizer.runtimeFilter.number.threshold") + .doc("The total number of injected runtime filters (non-DPP) for a single " + + "query. 
This is to prevent driver OOMs with too many Bloom filters.") + .version("3.3.0") + .intConf + .checkValue(threshold => threshold >= 0, "The threshold should be >= 0") + .createWithDefault(10) + + val RUNTIME_BLOOM_FILTER_ENABLED = + buildConf("spark.sql.optimizer.runtime.bloomFilter.enabled") + .doc("When true and if one side of a shuffle join has a selective predicate, we attempt " + + "to insert a bloom filter in the other side to reduce the amount of shuffle data.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD = + buildConf("spark.sql.optimizer.runtime.bloomFilter.creationSideThreshold") + .doc("Size threshold of the bloom filter creation side plan. Estimated size needs to be " + + "under this value to try to inject bloom filter.") + .version("3.3.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("10MB") + + val RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD = + buildConf("spark.sql.optimizer.runtime.bloomFilter.applicationSideScanSizethreshold") + .doc("Byte size threshold of the Bloom filter application side plan's aggregated scan " + + "size. Aggregated scan byte size of the Bloom filter application side needs to be over " + + "this value to inject a bloom filter.") + .version("3.3.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("10GB") + + val RUNTIME_BLOOM_FILTER_EXPECTED_NUM_ITEMS = + buildConf("spark.sql.optimizer.runtime.bloomFilter.expectedNumItems") + .doc("The default number of expected items for the runtime bloomfilter") + .version("3.3.0") + .longConf + .createWithDefault(1000000L) + + val RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS = + buildConf("spark.sql.optimizer.runtime.bloomFilter.maxNumItems") + .doc("The max allowed number of expected items for the runtime bloom filter") + .version("3.3.0") + .longConf + .createWithDefault(4000000L) + + + val RUNTIME_BLOOM_FILTER_NUM_BITS = + buildConf("spark.sql.optimizer.runtime.bloomFilter.numBits") + .doc("The default number of bits to use for the runtime bloom filter") + .version("3.3.0") + .longConf + .createWithDefault(8388608L) + + val RUNTIME_BLOOM_FILTER_MAX_NUM_BITS = + buildConf("spark.sql.optimizer.runtime.bloomFilter.maxNumBits") + .doc("The max number of bits to use for the runtime bloom filter") + .version("3.3.0") + .longConf + .createWithDefault(67108864L) + val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed") .doc("When set to true Spark SQL will automatically select a compression codec for each " + "column based on statistics of the data.") @@ -3750,6 +3821,15 @@ class SQLConf extends Serializable with Logging { def dynamicPartitionPruningReuseBroadcastOnly: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY) + def runtimeFilterSemiJoinReductionEnabled: Boolean = + getConf(RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED) + + def runtimeFilterBloomFilterEnabled: Boolean = + getConf(RUNTIME_BLOOM_FILTER_ENABLED) + + def runtimeFilterCreationSideThreshold: Long = + getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 7e8fb4a157262..743cb591b306f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -43,6 +43,8 @@ class SparkOptimizer( Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("PartitionPruning", Once, PartitionPruning) :+ + Batch("InjectRuntimeFilter", FixedPoint(1), + InjectRuntimeFilter) :+ Batch("Pushdown Filters from PartitionPruning", fixedPoint, PushDownPredicates) :+ Batch("Cleanup filters that cannot be pushed down", Once, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index 3b5fc4aea5d8b..89d66034f06cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -194,21 +194,6 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { scanOverhead + cachedOverhead } - /** - * Returns whether an expression is likely to be selective - */ - private def isLikelySelective(e: Expression): Boolean = e match { - case Not(expr) => isLikelySelective(expr) - case And(l, r) => isLikelySelective(l) || isLikelySelective(r) - case Or(l, r) => isLikelySelective(l) && isLikelySelective(r) - case _: StringRegexExpression => true - case _: BinaryComparison => true - case _: In | _: InSet => true - case _: StringPredicate => true - case BinaryPredicate(_) => true - case _: MultiLikeBase => true - case _ => false - } /** * Search a filtering predicate in a given logical plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala new file mode 100644 index 0000000000000..025593be4c959 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Query tests for the Bloom filter aggregate and filter function. 
+ */ +class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg") + val funcId_might_contain = new FunctionIdentifier("might_contain") + + // Register 'bloom_filter_agg' to builtin. + FunctionRegistry.builtin.registerFunction(funcId_bloom_filter_agg, + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) + + // Register 'might_contain' to builtin. + FunctionRegistry.builtin.registerFunction(funcId_might_contain, + new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"), + (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1))) + + override def afterAll(): Unit = { + FunctionRegistry.builtin.dropFunction(funcId_bloom_filter_agg) + FunctionRegistry.builtin.dropFunction(funcId_might_contain) + super.afterAll() + } + + test("Test bloom_filter_agg and might_contain") { + val conf = SQLConf.get + val table = "bloom_filter_test" + for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) { + for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue, + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))) { + val sqlString = s""" + |SELECT every(might_contain( + | (SELECT bloom_filter_agg(col, + | cast($numEstimatedItems as long), + | cast($numBits as long)) + | FROM $table), + | col)) positive_membership_test, + | every(might_contain( + | (SELECT bloom_filter_agg(col, + | cast($numEstimatedItems as long), + | cast($numBits as long)) + | FROM values (-1L), (100001L), (20000L) as t(col)), + | col)) negative_membership_test + |FROM $table + """.stripMargin + withTempView(table) { + (Seq(Long.MinValue, 0, Long.MaxValue) ++ (1L to 10000L)) + .toDF("col").createOrReplaceTempView(table) + // Validate error messages as well as answers when there's no error. 
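For one valid parameter pair (e.g. numEstimatedItems = numBits = 4096), the sqlString template above expands to the query below, which is expected to return Row(true, false). A sketch only: it runs solely inside this suite, where the functions are registered and the bloom_filter_test temp view exists.

```scala
spark.sql("""
  |SELECT every(might_contain(
  |         (SELECT bloom_filter_agg(col, cast(4096 as long), cast(4096 as long))
  |          FROM bloom_filter_test),
  |         col)) positive_membership_test,
  |       every(might_contain(
  |         (SELECT bloom_filter_agg(col, cast(4096 as long), cast(4096 as long))
  |          FROM values (-1L), (100001L), (20000L) as t(col)),
  |         col)) negative_membership_test
  |FROM bloom_filter_test
  """.stripMargin)
```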
+ if (numEstimatedItems <= 0) { + val exception = intercept[AnalysisException] { + spark.sql(sqlString) + } + assert(exception.getMessage.contains( + "The estimated number of items must be a positive value")) + } else if (numBits <= 0) { + val exception = intercept[AnalysisException] { + spark.sql(sqlString) + } + assert(exception.getMessage.contains("The number of bits must be a positive value")) + } else { + checkAnswer(spark.sql(sqlString), Row(true, false)) + } + } + } + } + } + + test("Test that bloom_filter_agg errors out disallowed input value types") { + val exception1 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a) + |FROM values (1.2), (2.5) as t(a)""" + .stripMargin) + } + assert(exception1.getMessage.contains( + "Input to function bloom_filter_agg should have been a bigint value")) + + val exception2 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a, 2) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception2.getMessage.contains( + "function bloom_filter_agg should have been a bigint value followed with two bigint")) + + val exception3 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a, cast(2 as long), 5) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception3.getMessage.contains( + "function bloom_filter_agg should have been a bigint value followed with two bigint")) + + val exception4 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a, null, 5) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception4.getMessage.contains("Null typed values cannot be used as size arguments")) + + val exception5 = intercept[AnalysisException] { + spark.sql(""" + |SELECT bloom_filter_agg(a, 5, null) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception5.getMessage.contains("Null typed values cannot be used as size arguments")) + } + + test("Test that might_contain errors out disallowed input value types") { + val exception1 = intercept[AnalysisException] { + spark.sql("""|SELECT might_contain(1.0, 1L)""" + .stripMargin) + } + assert(exception1.getMessage.contains( + "Input to function might_contain should have been binary followed by a value with bigint")) + + val exception2 = intercept[AnalysisException] { + spark.sql("""|SELECT might_contain(NULL, 0.1)""" + .stripMargin) + } + assert(exception2.getMessage.contains( + "Input to function might_contain should have been binary followed by a value with bigint")) + } + + test("Test that might_contain errors out non-constant Bloom filter") { + val exception1 = intercept[AnalysisException] { + spark.sql(""" + |SELECT might_contain(cast(a as binary), cast(5 as long)) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception1.getMessage.contains( + "The Bloom filter binary input to might_contain should be either a constant value or " + + "a scalar subquery expression")) + + val exception2 = intercept[AnalysisException] { + spark.sql(""" + |SELECT might_contain((select cast(a as binary)), cast(5 as long)) + |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)""" + .stripMargin) + } + assert(exception2.getMessage.contains( + "The Bloom filter binary input to might_contain should be either a constant value or " + + "a scalar subquery expression")) + } + + test("Test that might_contain can take a constant value 
input") { + checkAnswer(spark.sql( + """SELECT might_contain( + |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', + |cast(201 as long))""".stripMargin), + Row(false)) + } + + test("Test that bloom_filter_agg produces a NULL with empty input") { + checkAnswer(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1)"""), + Row(null)) + } + + test("Test NULL inputs for might_contain") { + checkAnswer(spark.sql( + s""" + |SELECT might_contain(null, null) both_null, + | might_contain(null, 1L) null_bf, + | might_contain((SELECT bloom_filter_agg(cast(id as long)) from range(1, 10000)), + | null) null_value + """.stripMargin), + Row(null, null, null)) + } + + test("Test that a query with bloom_filter_agg has partial aggregates") { + assert(spark.sql("""SELECT bloom_filter_agg(cast(id as long)) from range(1, 1000000)""") + .queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].inputPlan + .collect({case agg: BaseAggregateExec => agg}).size == 2) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala new file mode 100644 index 0000000000000..a5e27fbfda42a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -0,0 +1,503 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} +import org.apache.spark.sql.types.{IntegerType, StructType} + +class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSparkSession { + + protected override def beforeAll(): Unit = { + super.beforeAll() + val schema = new StructType().add("a1", IntegerType, nullable = true) + .add("b1", IntegerType, nullable = true) + .add("c1", IntegerType, nullable = true) + .add("d1", IntegerType, nullable = true) + .add("e1", IntegerType, nullable = true) + .add("f1", IntegerType, nullable = true) + + val data1 = Seq(Seq(null, 47, null, 4, 6, 48), + Seq(73, 63, null, 92, null, null), + Seq(76, 10, 74, 98, 37, 5), + Seq(0, 63, null, null, null, null), + Seq(15, 77, null, null, null, null), + Seq(null, 57, 33, 55, null, 58), + Seq(4, 0, 86, null, 96, 14), + Seq(28, 16, 58, null, null, null), + Seq(1, 88, null, 8, null, 79), + Seq(59, null, null, null, 20, 25), + Seq(1, 50, null, 94, 94, null), + Seq(null, null, null, 67, 51, 57), + Seq(77, 50, 8, 90, 16, 21), + Seq(34, 28, null, 5, null, 64), + Seq(null, null, 88, 11, 63, 79), + Seq(92, 94, 23, 1, null, 64), + Seq(57, 56, null, 83, null, null), + Seq(null, 35, 8, 35, null, 70), + Seq(null, 8, null, 35, null, 87), + Seq(9, null, null, 60, null, 5), + Seq(null, 15, 66, null, 83, null)) + val rdd1 = spark.sparkContext.parallelize(data1) + val rddRow1 = rdd1.map(s => Row.fromSeq(s)) + spark.createDataFrame(rddRow1, schema).write.saveAsTable("bf1") + + val schema2 = new StructType().add("a2", IntegerType, nullable = true) + .add("b2", IntegerType, nullable = true) + .add("c2", IntegerType, nullable = true) + .add("d2", IntegerType, nullable = true) + .add("e2", IntegerType, nullable = true) + .add("f2", IntegerType, nullable = true) + + + val data2 = Seq(Seq(67, 17, 45, 91, null, null), + Seq(98, 63, 0, 89, null, 40), + Seq(null, 76, 68, 75, 20, 19), + Seq(8, null, null, null, 78, null), + Seq(48, 62, null, null, 11, 98), + Seq(84, null, 99, 65, 66, 51), + Seq(98, null, null, null, 42, 51), + Seq(10, 3, 29, null, 68, 8), + Seq(85, 36, 41, null, 28, 71), + Seq(89, null, 94, 95, 67, 21), + Seq(44, null, 24, 33, null, 6), + Seq(null, 6, 78, 31, null, 69), + Seq(59, 2, 63, 9, 66, 20), + Seq(5, 23, 10, 86, 68, null), + Seq(null, 63, 99, 55, 9, 65), + Seq(57, 62, 68, 5, null, 0), + Seq(75, null, 15, null, 81, null), + Seq(53, null, 6, 68, 28, 13), + Seq(null, null, null, null, 89, 23), + Seq(36, 73, 40, null, 8, null), + Seq(24, null, null, 40, null, null)) + val rdd2 = spark.sparkContext.parallelize(data2) + val rddRow2 = rdd2.map(s => Row.fromSeq(s)) + spark.createDataFrame(rddRow2, schema2).write.saveAsTable("bf2") + + val schema3 = new StructType().add("a3", IntegerType, nullable = true) + .add("b3", IntegerType, nullable = true) + .add("c3", IntegerType, nullable = true) + .add("d3", IntegerType, nullable = true) + .add("e3", IntegerType, nullable = true) + .add("f3", IntegerType, nullable = true) + + val data3 = Seq(Seq(67, 17, 45, 91, null, null), + Seq(98, 63, 0, 89, null, 40), + Seq(null, 76, 68, 75, 20, 19), + Seq(8, null, null, null, 78, null), + Seq(48, 62, null, null, 11, 98), + Seq(84, null, 99, 65, 
66, 51),
+      Seq(98, null, null, null, 42, 51),
+      Seq(10, 3, 29, null, 68, 8),
+      Seq(85, 36, 41, null, 28, 71),
+      Seq(89, null, 94, 95, 67, 21),
+      Seq(44, null, 24, 33, null, 6),
+      Seq(null, 6, 78, 31, null, 69),
+      Seq(59, 2, 63, 9, 66, 20),
+      Seq(5, 23, 10, 86, 68, null),
+      Seq(null, 63, 99, 55, 9, 65),
+      Seq(57, 62, 68, 5, null, 0),
+      Seq(75, null, 15, null, 81, null),
+      Seq(53, null, 6, 68, 28, 13),
+      Seq(null, null, null, null, 89, 23),
+      Seq(36, 73, 40, null, 8, null),
+      Seq(24, null, null, 40, null, null))
+    val rdd3 = spark.sparkContext.parallelize(data3)
+    val rddRow3 = rdd3.map(s => Row.fromSeq(s))
+    spark.createDataFrame(rddRow3, schema3).write.saveAsTable("bf3")
+
+
+    val schema4 = new StructType().add("a4", IntegerType, nullable = true)
+      .add("b4", IntegerType, nullable = true)
+      .add("c4", IntegerType, nullable = true)
+      .add("d4", IntegerType, nullable = true)
+      .add("e4", IntegerType, nullable = true)
+      .add("f4", IntegerType, nullable = true)
+
+    val data4 = Seq(Seq(67, 17, 45, 91, null, null),
+      Seq(98, 63, 0, 89, null, 40),
+      Seq(null, 76, 68, 75, 20, 19),
+      Seq(8, null, null, null, 78, null),
+      Seq(48, 62, null, null, 11, 98),
+      Seq(84, null, 99, 65, 66, 51),
+      Seq(98, null, null, null, 42, 51),
+      Seq(10, 3, 29, null, 68, 8),
+      Seq(85, 36, 41, null, 28, 71),
+      Seq(89, null, 94, 95, 67, 21),
+      Seq(44, null, 24, 33, null, 6),
+      Seq(null, 6, 78, 31, null, 69),
+      Seq(59, 2, 63, 9, 66, 20),
+      Seq(5, 23, 10, 86, 68, null),
+      Seq(null, 63, 99, 55, 9, 65),
+      Seq(57, 62, 68, 5, null, 0),
+      Seq(75, null, 15, null, 81, null),
+      Seq(53, null, 6, 68, 28, 13),
+      Seq(null, null, null, null, 89, 23),
+      Seq(36, 73, 40, null, 8, null),
+      Seq(24, null, null, 40, null, null))
+    val rdd4 = spark.sparkContext.parallelize(data4)
+    val rddRow4 = rdd4.map(s => Row.fromSeq(s))
+    spark.createDataFrame(rddRow4, schema4).write.saveAsTable("bf4")
+
+    val schema5part = new StructType().add("a5", IntegerType, nullable = true)
+      .add("b5", IntegerType, nullable = true)
+      .add("c5", IntegerType, nullable = true)
+      .add("d5", IntegerType, nullable = true)
+      .add("e5", IntegerType, nullable = true)
+      .add("f5", IntegerType, nullable = true)
+
+    val data5part = Seq(Seq(67, 17, 45, 91, null, null),
+      Seq(98, 63, 0, 89, null, 40),
+      Seq(null, 76, 68, 75, 20, 19),
+      Seq(8, null, null, null, 78, null),
+      Seq(48, 62, null, null, 11, 98),
+      Seq(84, null, 99, 65, 66, 51),
+      Seq(98, null, null, null, 42, 51),
+      Seq(10, 3, 29, null, 68, 8),
+      Seq(85, 36, 41, null, 28, 71),
+      Seq(89, null, 94, 95, 67, 21),
+      Seq(44, null, 24, 33, null, 6),
+      Seq(null, 6, 78, 31, null, 69),
+      Seq(59, 2, 63, 9, 66, 20),
+      Seq(5, 23, 10, 86, 68, null),
+      Seq(null, 63, 99, 55, 9, 65),
+      Seq(57, 62, 68, 5, null, 0),
+      Seq(75, null, 15, null, 81, null),
+      Seq(53, null, 6, 68, 28, 13),
+      Seq(null, null, null, null, 89, 23),
+      Seq(36, 73, 40, null, 8, null),
+      Seq(24, null, null, 40, null, null))
+    val rdd5part = spark.sparkContext.parallelize(data5part)
+    val rddRow5part = rdd5part.map(s => Row.fromSeq(s))
+    spark.createDataFrame(rddRow5part, schema5part).write.partitionBy("f5")
+      .saveAsTable("bf5part")
+    spark.createDataFrame(rddRow5part, schema5part).filter("a5 > 30")
+      .write.partitionBy("f5")
+      .saveAsTable("bf5filtered")
+
+    sql("analyze table bf1 compute statistics for columns a1, b1, c1, d1, e1, f1")
+    sql("analyze table bf2 compute statistics for columns a2, b2, c2, d2, e2, f2")
+    sql("analyze table bf3 compute statistics for columns a3, b3, c3, d3, e3, f3")
+    sql("analyze table bf4 compute statistics for columns a4, b4, c4, d4, e4, f4")
+    sql("analyze table bf5part compute statistics for columns a5, b5, c5, d5, e5, f5")
+    sql("analyze table bf5filtered compute statistics for columns a5, b5, c5, d5, e5, f5")
+  }
+
+  protected override def afterAll(): Unit = try {
+    sql("DROP TABLE IF EXISTS bf1")
+    sql("DROP TABLE IF EXISTS bf2")
+    sql("DROP TABLE IF EXISTS bf3")
+    sql("DROP TABLE IF EXISTS bf4")
+    sql("DROP TABLE IF EXISTS bf5part")
+    sql("DROP TABLE IF EXISTS bf5filtered")
+  } finally {
+    super.afterAll()
+  }
+
+  def checkWithAndWithoutFeatureEnabled(query: String, testSemiJoin: Boolean,
+      shouldReplace: Boolean): Unit = {
+    var planDisabled: LogicalPlan = null
+    var planEnabled: LogicalPlan = null
+    var expectedAnswer: Array[Row] = null
+
+    withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+      SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+      planDisabled = sql(query).queryExecution.optimizedPlan
+      expectedAnswer = sql(query).collect()
+    }
+
+    if (testSemiJoin) {
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "true",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        planEnabled = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), expectedAnswer)
+      }
+      if (shouldReplace) {
+        val normalizedEnabled = normalizePlan(normalizeExprIds(planEnabled))
+        val normalizedDisabled = normalizePlan(normalizeExprIds(planDisabled))
+        assert(normalizedEnabled != normalizedDisabled)
+      } else {
+        comparePlans(planDisabled, planEnabled)
+      }
+    } else {
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+        planEnabled = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), expectedAnswer)
+        if (shouldReplace) {
+          assert(getNumBloomFilters(planEnabled) > getNumBloomFilters(planDisabled))
+        } else {
+          assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled))
+        }
+      }
+    }
+  }
+
+  def getNumBloomFilters(plan: LogicalPlan): Integer = {
+    val numBloomFilterAggs = plan.collect {
+      case Filter(condition, _) => condition.collect {
+        case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+          => subquery.plan.collect {
+            case Aggregate(_, aggregateExpressions, _) =>
+              aggregateExpressions.map {
+                case Alias(AggregateExpression(bfAgg : BloomFilterAggregate, _, _, _, _),
+                  _) =>
+                  assert(bfAgg.estimatedNumItemsExpression.isInstanceOf[Literal])
+                  assert(bfAgg.numBitsExpression.isInstanceOf[Literal])
+                  1
+              }.sum
+          }.sum
+      }.sum
+    }.sum
+    val numMightContains = plan.collect {
+      case Filter(condition, _) => condition.collect {
+        case BloomFilterMightContain(_, _) => 1
+      }.sum
+    }.sum
+    assert(numBloomFilterAggs == numMightContains)
+    numMightContains
+  }
+
+  def assertRewroteSemiJoin(query: String): Unit = {
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true)
+  }
+
+  def assertDidNotRewriteSemiJoin(query: String): Unit = {
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = false)
+  }
+
+  def assertRewroteWithBloomFilter(query: String): Unit = {
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = true)
+  }
+
+  def assertDidNotRewriteWithBloomFilter(query: String): Unit = {
+    checkWithAndWithoutFeatureEnabled(query, testSemiJoin = false, shouldReplace = false)
+  }
+
+  test("Runtime semi join reduction: simple") {
+    // Filter creation side is 3409 bytes
+    // Filter application side scan is 3362 bytes
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62")
+      assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on bf1.c1 = bf2.c2")
+    }
+  }
+
+  test("Runtime semi join reduction: two joins") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " +
+        "and bf3.c3 = bf2.c2 where bf2.a2 = 5")
+    }
+  }
+
+  test("Runtime semi join reduction: three joins") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 join bf4 on " +
+        "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5")
+    }
+  }
+
+  test("Runtime semi join reduction: simple expressions only") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      val squared = (s: Long) => {
+        s * s
+      }
+      spark.udf.register("square", squared)
+      assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " +
+        "bf1.c1 = bf2.c2 where square(bf2.a2) = 62")
+      assertDidNotRewriteSemiJoin("select * from bf1 join bf2 on " +
+        "bf1.c1 = square(bf2.c2) where bf2.a2 = 62")
+    }
+  }
+
+  test("Runtime bloom filter join: simple") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+        "where bf2.a2 = 62")
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2")
+    }
+  }
+
+  test("Runtime bloom filter join: two filters single join") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      var planDisabled: LogicalPlan = null
+      var planEnabled: LogicalPlan = null
+      var expectedAnswer: Array[Row] = null
+
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        planDisabled = sql(query).queryExecution.optimizedPlan
+        expectedAnswer = sql(query).collect()
+      }
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+        planEnabled = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), expectedAnswer)
+      }
+      assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
+    }
+  }
+
+  test("Runtime bloom filter join: test the number of filter threshold") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      var planDisabled: LogicalPlan = null
+      var planEnabled: LogicalPlan = null
+      var expectedAnswer: Array[Row] = null
+
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.b1 = bf2.b2 where bf2.a2 = 62"
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        planDisabled = sql(query).queryExecution.optimizedPlan
+        expectedAnswer = sql(query).collect()
+      }
+
+      for (numFilterThreshold <- 0 to 3) {
+        withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+          SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true",
+          SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD.key -> numFilterThreshold.toString) {
+          planEnabled = sql(query).queryExecution.optimizedPlan
+          checkAnswer(sql(query), expectedAnswer)
+        }
+        if (numFilterThreshold < 3) {
+          assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled)
+            + numFilterThreshold)
+        } else {
+          assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 2)
+        }
+      }
+    }
+  }
+
+  test("Runtime bloom filter join: insert one bloom filter per column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      var planDisabled: LogicalPlan = null
+      var planEnabled: LogicalPlan = null
+      var expectedAnswer: Array[Row] = null
+
+      val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 and " +
+        "bf1.c1 = bf2.b2 where bf2.a2 = 62"
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "false") {
+        planDisabled = sql(query).queryExecution.optimizedPlan
+        expectedAnswer = sql(query).collect()
+      }
+
+      withSQLConf(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key -> "false",
+        SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true") {
+        planEnabled = sql(query).queryExecution.optimizedPlan
+        checkAnswer(sql(query), expectedAnswer)
+      }
+      assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled) + 1)
+    }
+  }
+
+  test("Runtime bloom filter join: do not add bloom filter if dpp filter exists " +
+    "on the same column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertDidNotRewriteWithBloomFilter("select * from bf5part join bf2 on " +
+        "bf5part.f5 = bf2.c2 where bf2.a2 = 62")
+    }
+  }
+
+  test("Runtime bloom filter join: add bloom filter if dpp filter exists on " +
+    "a different column") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      assertRewroteWithBloomFilter("select * from bf5part join bf2 on " +
+        "bf5part.c5 = bf2.c2 and bf5part.f5 = bf2.f2 where bf2.a2 = 62")
+    }
+  }
+
+  test("Runtime bloom filter join: BF rewrite triggering threshold test") {
+    // Filter creation side data size is 3409 bytes. On the filter application side, an individual
+    // scan's byte size is 3362.
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000"
+    ) {
+      assertRewroteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+        "where bf2.a2 = 62")
+    }
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "50"
+    ) {
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+        "where bf2.a2 = 62")
+    }
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "3000",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000"
+    ) {
+      // Rewrite should not be triggered as the Bloom filter application side scan size is small.
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on bf1.c1 = bf2.c2 " +
+        "where bf2.a2 = 62")
+    }
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "32",
+      SQLConf.RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD.key -> "4000") {
+      // Test that the max scan size rather than an individual scan size on the filter
+      // application side matters. `bf5filtered` has 14168 bytes and `bf2` has 3409 bytes.
+      withSQLConf(
+        SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "5000") {
+        assertRewroteWithBloomFilter("select * from " +
+          "(select * from bf5filtered union all select * from bf2) t " +
+          "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5")
+      }
+      withSQLConf(
+        SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "15000") {
+        assertDidNotRewriteWithBloomFilter("select * from " +
+          "(select * from bf5filtered union all select * from bf2) t " +
+          "join bf3 on t.c5 = bf3.c3 where bf3.a3 = 5")
+      }
+    }
+  }
+
+  test("Runtime bloom filter join: simple expressions only") {
+    withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      val squared = (s: Long) => {
+        s * s
+      }
+      spark.udf.register("square", squared)
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " +
+        "bf1.c1 = bf2.c2 where square(bf2.a2) = 62")
+      assertDidNotRewriteWithBloomFilter("select * from bf1 join bf2 on " +
+        "bf1.c1 = square(bf2.c2) where bf2.a2 = 62")
+    }
+  }
+}
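Editor's note (illustrative, not part of the patch): every test above pins `RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD` and `AUTO_BROADCASTJOIN_THRESHOLD` so that the few-kilobyte test tables still clear the size checks and the joins are not planned as broadcast joins. Outside the test harness, the same switches added by this PR can be flipped on an ordinary `SparkSession` through their `SQLConf` keys. The sketch below is a minimal usage example under assumptions: the tables `t1`/`t2`, their columns, and the session setup are hypothetical; only the `SQLConf` constants come from this patch.

```scala
// Minimal sketch, assuming pre-existing tables t1/t2 (hypothetical) and a local session.
// The SQLConf entries referenced here are the ones introduced by this PR.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object RuntimeFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("runtime-filter-sketch")
      .master("local[*]")
      .getOrCreate()

    // Enable the bloom filter rewrite and leave semi-join reduction off,
    // mirroring the assertRewroteWithBloomFilter tests above.
    spark.conf.set(SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key, "true")
    spark.conf.set(SQLConf.RUNTIME_FILTER_SEMI_JOIN_REDUCTION_ENABLED.key, "false")

    // A selective predicate on the creation side (t2) is what makes the rewrite
    // worthwhile: the injected filter prunes t1 rows before the join's shuffle.
    val df = spark.sql("select * from t1 join t2 on t1.c1 = t2.c2 where t2.a2 = 62")

    // When the creation-side and application-side size thresholds are met, the
    // optimized plan should contain a bloom filter aggregate subquery on t2 and a
    // corresponding might-contain check on t1.
    df.explain(true)

    spark.stop()
  }
}
```

With the default thresholds the rewrite targets large application-side scans, which is why the suite lowers the thresholds to match the small sizes of `bf1` through `bf5filtered`.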