
Spark 3.5: Fix system function pushdown in CoW row-level commands (#9873) (#10170)

Co-authored-by: Anton Okolnychyi <[email protected]>
amogh-jahagirdar and aokolnychyi authored Apr 18, 2024
1 parent 5d73e6b commit 3ed0597
Showing 11 changed files with 485 additions and 30 deletions.
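
The fix, in brief: in copy-on-write DELETE, UPDATE, and MERGE, the command condition is carried by the ReplaceData node and by the row-level rewrite join rather than by a Filter, so the previous rule (which matched only Filter) left Iceberg system function calls as StaticInvoke expressions that could not be pushed down to the scan. A sketch of the statements affected, assuming a Spark session with an Iceberg-enabled catalog and a hypothetical table db.tbl:

    // Hypothetical setup: a bucket-partitioned Iceberg table with copy-on-write deletes.
    spark.sql(
      """CREATE TABLE db.tbl (id BIGINT, data STRING)
        |USING iceberg
        |PARTITIONED BY (bucket(8, id))
        |TBLPROPERTIES ('write.delete.mode' = 'copy-on-write')""".stripMargin)

    // The predicate calls an Iceberg system function. With this fix, the StaticInvoke
    // produced for system.bucket is converted inside the ReplaceData command, so the
    // predicate can be pushed down and only matching files are rewritten.
    spark.sql("DELETE FROM db.tbl WHERE system.bucket(8, id) = 3")

    // IN predicates are now handled too, via the new In/InSet cases.
    spark.sql("DELETE FROM db.tbl WHERE system.bucket(8, id) IN (1, 2)")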
ReplaceStaticInvoke.scala
@@ -22,12 +22,20 @@ import org.apache.iceberg.spark.functions.SparkFunctions
import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
import org.apache.spark.sql.catalyst.expressions.BinaryComparison
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.In
+import org.apache.spark.sql.catalyst.expressions.InSet
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
+import org.apache.spark.sql.catalyst.trees.TreePattern.IN
+import org.apache.spark.sql.catalyst.trees.TreePattern.INSET
+import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
@@ -40,21 +48,39 @@ import org.apache.spark.sql.types.StructType
object ReplaceStaticInvoke extends Rule[LogicalPlan] {

  override def apply(plan: LogicalPlan): LogicalPlan =
-    plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) {
-      case filter @ Filter(condition, _) =>
-        val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
-          case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
-            c.withNewChildren(Seq(replaceStaticInvoke(left), right))
-
-          case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
-            c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
-        }
-
-        if (newCondition fastEquals condition) {
-          filter
-        } else {
-          filter.copy(condition = newCondition)
-        }
+    plan.transformWithPruning (_.containsAnyPattern(COMMAND, FILTER, JOIN)) {
+      case replace @ ReplaceData(_, cond, _, _, _, _) =>
+        replaceStaticInvoke(replace, cond, newCond => replace.copy(condition = newCond))
+
+      case join @ Join(_, _, _, Some(cond), _) =>
+        replaceStaticInvoke(join, cond, newCond => join.copy(condition = Some(newCond)))
+
+      case filter @ Filter(cond, _) =>
+        replaceStaticInvoke(filter, cond, newCond => filter.copy(condition = newCond))
+    }
+
+  private def replaceStaticInvoke[T <: LogicalPlan](
+      node: T,
+      condition: Expression,
+      copy: Expression => T): T = {
+    val newCondition = replaceStaticInvoke(condition)
+    if (newCondition fastEquals condition) node else copy(newCondition)
+  }
+
+  private def replaceStaticInvoke(condition: Expression): Expression = {
+    condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) {
+      case in @ In(value: StaticInvoke, _) if canReplace(value) =>
+        in.copy(value = replaceStaticInvoke(value))
+
+      case in @ InSet(value: StaticInvoke, _) if canReplace(value) =>
+        in.copy(child = replaceStaticInvoke(value))
+
+      case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
+        c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+
+      case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
+        c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
+    }
+  }

  private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
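
The body of the StaticInvoke overload is truncated above. Judging from this file's imports (SparkFunctions, ScalarFunction, ApplyFunctionExpression, StructField, StructType), a minimal sketch of what the conversion plausibly looks like; details of the real method may differ:

    // Sketch only: load the Iceberg function registered for the invoked static class,
    // bind it to the argument types, and replace the StaticInvoke with an
    // ApplyFunctionExpression, which V2 predicate pushdown understands.
    private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
      val unbound = SparkFunctions.loadFunctionByClass(invoke.staticObject)
      if (unbound == null) {
        return invoke
      }

      val inputType = StructType(invoke.arguments.zipWithIndex.map {
        case (arg, pos) => StructField(s"_$pos", arg.dataType, arg.nullable)
      })

      try {
        unbound.bind(inputType) match {
          case scalar: ScalarFunction[_] => ApplyFunctionExpression(scalar, invoke.arguments)
          case _ => invoke // not a scalar function, leave the invocation unchanged
        }
      } catch {
        case _: Exception => invoke // binding failed, keep the original expression
      }
    }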
SparkPlanUtil.java
@@ -20,12 +20,17 @@

import static scala.collection.JavaConverters.seqAsJavaListConverter;

+import java.util.Collection;
import java.util.List;
+import java.util.function.Predicate;
import java.util.stream.Collectors;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.execution.CommandResultExec;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper;
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec;
+import scala.PartialFunction;
import scala.collection.Seq;

public class SparkPlanUtil {
Expand Down Expand Up @@ -53,6 +58,49 @@ private static SparkPlan actualPlan(SparkPlan plan) {
    }
  }

+  public static List<Expression> collectExprs(
+      SparkPlan sparkPlan, Predicate<Expression> predicate) {
+    Seq<List<Expression>> seq =
+        SPARK_HELPER.collect(
+            sparkPlan,
+            new PartialFunction<SparkPlan, List<Expression>>() {
+              @Override
+              public List<Expression> apply(SparkPlan plan) {
+                List<Expression> exprs = Lists.newArrayList();
+
+                for (Expression expr : toJavaList(plan.expressions())) {
+                  exprs.addAll(collectExprs(expr, predicate));
+                }
+
+                return exprs;
+              }
+
+              @Override
+              public boolean isDefinedAt(SparkPlan plan) {
+                return true;
+              }
+            });
+    return toJavaList(seq).stream().flatMap(Collection::stream).collect(Collectors.toList());
+  }
+
+  private static List<Expression> collectExprs(
+      Expression expression, Predicate<Expression> predicate) {
+    Seq<Expression> seq =
+        expression.collect(
+            new PartialFunction<Expression, Expression>() {
+              @Override
+              public Expression apply(Expression expr) {
+                return expr;
+              }
+
+              @Override
+              public boolean isDefinedAt(Expression expr) {
+                return predicate.test(expr);
+              }
+            });
+    return toJavaList(seq);
+  }
+
  private static <T> List<T> toJavaList(Seq<T> seq) {
    return seqAsJavaListConverter(seq).asJava();
  }
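
A usage sketch for the new helper (illustrative only; the repository's tests are Java, and the session and table name here are assumed): after running a copy-on-write DELETE, assert that no unconverted StaticInvoke expressions survive in the executed plan:

    import org.apache.iceberg.spark.SparkPlanUtil
    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke

    // Run a CoW DELETE whose predicate calls a system function, then inspect the plan.
    val df = spark.sql("DELETE FROM db.tbl WHERE system.bucket(8, id) = 3")
    val plan = df.queryExecution.executedPlan

    // collectExprs walks every node's expressions and keeps those matching the predicate.
    val staticInvokes =
      SparkPlanUtil.collectExprs(plan, expr => expr.isInstanceOf[StaticInvoke])
    assert(staticInvokes.isEmpty, "system function calls should have been converted")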
(Diff for the remaining 9 changed files is not shown.)
