(1.5.x Cherry-pick) Spark 3.4: Fix system function pushdown in CoW row-level commands (#10119) #10171

Merged
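For context: system function pushdown lets calls to Iceberg's system.* functions (such as system.bucket) be translated into data source V2 filters instead of being evaluated row by row, and "CoW row-level commands" are copy-on-write DELETE, UPDATE, and MERGE. A minimal repro-style sketch of the kind of statement this fix targets; the table name, schema, and properties below are illustrative, not taken from the PR:

// Illustrative only: a copy-on-write DELETE whose condition calls an Iceberg
// system function. Per the diff below, the old rule rewrote such calls only in
// Filter conditions containing binary comparisons, so conditions placed
// elsewhere by the row-level rewrite (e.g. a MERGE join) or expressed as IN
// predicates were missed.
spark.sql(
  """CREATE TABLE db.tbl (id BIGINT, data STRING)
    |USING iceberg
    |PARTITIONED BY (bucket(16, id))
    |TBLPROPERTIES ('write.delete.mode' = 'copy-on-write')""".stripMargin)

spark.sql("DELETE FROM db.tbl WHERE system.bucket(16, id) = 3")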
ReplaceStaticInvoke.scala
@@ -22,12 +22,20 @@ import org.apache.iceberg.spark.functions.SparkFunctions
 import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
 import org.apache.spark.sql.catalyst.expressions.BinaryComparison
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.In
+import org.apache.spark.sql.catalyst.expressions.InSet
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
 import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
+import org.apache.spark.sql.catalyst.trees.TreePattern.IN
+import org.apache.spark.sql.catalyst.trees.TreePattern.INSET
+import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
@@ -40,21 +48,36 @@ import org.apache.spark.sql.types.StructType
 object ReplaceStaticInvoke extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan =
-    plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) {
-      case filter @ Filter(condition, _) =>
-        val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
-          case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
-            c.withNewChildren(Seq(replaceStaticInvoke(left), right))
-
-          case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
-            c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
-        }
-
-        if (newCondition fastEquals condition) {
-          filter
-        } else {
-          filter.copy(condition = newCondition)
-        }
-    }
+    plan.transformWithPruning (_.containsAnyPattern(COMMAND, FILTER, JOIN)) {
+      case join @ Join(_, _, _, Some(cond), _) =>
+        replaceStaticInvoke(join, cond, newCond => join.copy(condition = Some(newCond)))
+
+      case filter @ Filter(cond, _) =>
+        replaceStaticInvoke(filter, cond, newCond => filter.copy(condition = newCond))
+  }
+
+  private def replaceStaticInvoke[T <: LogicalPlan](
+      node: T,
+      condition: Expression,
+      copy: Expression => T): T = {
+    val newCondition = replaceStaticInvoke(condition)
+    if (newCondition fastEquals condition) node else copy(newCondition)
+  }
+
+  private def replaceStaticInvoke(condition: Expression): Expression = {
+    condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) {
+      case in @ In(value: StaticInvoke, _) if canReplace(value) =>
+        in.copy(value = replaceStaticInvoke(value))
+
+      case in @ InSet(value: StaticInvoke, _) if canReplace(value) =>
+        in.copy(child = replaceStaticInvoke(value))
+
+      case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
+        c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+
+      case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
+        c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
+    }
+  }
 
   private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
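Judging by the rule's name and imports (its replaceStaticInvoke(invoke: StaticInvoke) helper is collapsed in this view), the rule swaps StaticInvoke calls to Iceberg ScalarFunctions for ApplyFunctionExpression, which Spark's V2 scan pushdown can translate into a data source filter. The fix widens where the rule looks (Join conditions and In/InSet predicates, not just binary comparisons under a Filter) and factors the rewrite into helpers that rebuild the owning node only when the condition actually changed. A self-contained sketch of that rebuild-only-on-change pattern, with stand-in types rather than Catalyst's (all names below are illustrative):

// Plain Scala, no Spark dependency. Mirrors replaceStaticInvoke[T <: LogicalPlan].
object RebuildOnChange {
  sealed trait Expr
  case class Call(fn: String, args: List[Expr]) extends Expr   // stand-in for StaticInvoke
  case class Apply(fn: String, args: List[Expr]) extends Expr  // stand-in for ApplyFunctionExpression
  case class Eq(left: Expr, right: Expr) extends Expr
  case class Lit(v: Int) extends Expr

  case class Filter(condition: Expr)  // stand-in for the logical Filter node

  // Rewrite Call -> Apply on either side of an equality, like the
  // BinaryComparison cases in the rule.
  def rewrite(e: Expr): Expr = e match {
    case Eq(c: Call, r) => Eq(Apply(c.fn, c.args), r)
    case Eq(l, c: Call) => Eq(l, Apply(c.fn, c.args))
    case other => other
  }

  // Rebuild the owning node only if the condition changed, mirroring the
  // fastEquals short-circuit that avoids needless plan copies.
  def replaceIn(f: Filter): Filter = {
    val newCond = rewrite(f.condition)
    if (newCond == f.condition) f else f.copy(condition = newCond)
  }

  def main(args: Array[String]): Unit = {
    println(replaceIn(Filter(Eq(Call("bucket", List(Lit(16))), Lit(3)))))
    // prints: Filter(Eq(Apply(bucket,List(Lit(16))),Lit(3)))
  }
}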
SparkPlanUtil.java
@@ -20,12 +20,17 @@
 import static scala.collection.JavaConverters.seqAsJavaListConverter;
 
+import java.util.Collection;
 import java.util.List;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.execution.CommandResultExec;
 import org.apache.spark.sql.execution.SparkPlan;
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper;
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec;
+import scala.PartialFunction;
 import scala.collection.Seq;
 
 public class SparkPlanUtil {
@@ -53,6 +58,49 @@ private static SparkPlan actualPlan(SparkPlan plan) {
     }
   }
 
+  public static List<Expression> collectExprs(
+      SparkPlan sparkPlan, Predicate<Expression> predicate) {
+    Seq<List<Expression>> seq =
+        SPARK_HELPER.collect(
+            sparkPlan,
+            new PartialFunction<SparkPlan, List<Expression>>() {
+              @Override
+              public List<Expression> apply(SparkPlan plan) {
+                List<Expression> exprs = Lists.newArrayList();
+
+                for (Expression expr : toJavaList(plan.expressions())) {
+                  exprs.addAll(collectExprs(expr, predicate));
+                }
+
+                return exprs;
+              }
+
+              @Override
+              public boolean isDefinedAt(SparkPlan plan) {
+                return true;
+              }
+            });
+    return toJavaList(seq).stream().flatMap(Collection::stream).collect(Collectors.toList());
+  }
+
+  private static List<Expression> collectExprs(
+      Expression expression, Predicate<Expression> predicate) {
+    Seq<Expression> seq =
+        expression.collect(
+            new PartialFunction<Expression, Expression>() {
+              @Override
+              public Expression apply(Expression expr) {
+                return expr;
+              }
+
+              @Override
+              public boolean isDefinedAt(Expression expr) {
+                return predicate.test(expr);
+              }
+            });
+    return toJavaList(seq);
+  }
+
   private static <T> List<T> toJavaList(Seq<T> seq) {
     return seqAsJavaListConverter(seq).asJava();
   }
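The new collectExprs helper visits every node of a physical plan (its SparkPlan partial function is defined everywhere) and, within each node's expressions, collects the subexpressions matching the predicate. A hedged usage sketch, written in Scala for a Spark shell or test with an Iceberg catalog configured; the spark session, table name, and the assumption that SparkPlanUtil lives in org.apache.iceberg.spark are mine, not stated in this diff:

// Hypothetical test-style usage: verify that no raw StaticInvoke survives into
// the executed plan of a copy-on-write DELETE, i.e. every system function call
// was rewritten before pushdown.
import org.apache.iceberg.spark.SparkPlanUtil
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke

val qe = spark.sql("DELETE FROM db.tbl WHERE system.bucket(16, id) = 3").queryExecution
val invokes = SparkPlanUtil.collectExprs(qe.executedPlan, e => e.isInstanceOf[StaticInvoke])
assert(invokes.isEmpty, "StaticInvoke should have been replaced before pushdown")

Collecting at every plan node and filtering at the expression level keeps the helper generic: the same call can target StaticInvoke, ApplyFunctionExpression, or any other expression class a test cares about.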