diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java index 6120fdffadb8..9970b44dc2fd 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkV2Filters.java @@ -45,17 +45,18 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; import org.apache.iceberg.expressions.Expression; import org.apache.iceberg.expressions.Expression.Operation; import org.apache.iceberg.expressions.Expressions; import org.apache.iceberg.expressions.UnboundPredicate; import org.apache.iceberg.expressions.UnboundTerm; -import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.spark.functions.SparkFunctions; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.util.NaNUtil; +import org.apache.iceberg.util.Pair; import org.apache.spark.sql.connector.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc; @@ -67,6 +68,9 @@ public class SparkV2Filters { + private static final Set FUNCTIONS_SUPPORT_CONVERT = + ImmutableSet.of("years", "months", "days", "hours", "bucket", "truncate"); + private static final String TRUE = "ALWAYS_TRUE"; private static final String FALSE = "ALWAYS_FALSE"; private static final String EQ = "="; @@ -111,16 +115,16 @@ public static Expression convert(Predicate[] predicates) { for (Predicate predicate : predicates) { Expression converted = convert(predicate); Preconditions.checkArgument( - converted != null, "Cannot convert predicate to Iceberg: %s", predicate); + converted != null, "Cannot convert Spark predicate to Iceberg expression: %s", predicate); expression = Expressions.and(expression, converted); } + return expression; } @SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"}) public static Expression convert(Predicate predicate) { Operation op = FILTERS.get(predicate.name()); - UnboundTerm term = null; if (op != null) { switch (op) { case TRUE: @@ -130,79 +134,57 @@ public static Expression convert(Predicate predicate) { return Expressions.alwaysFalse(); case IS_NULL: - if (!couldConvert(child(predicate))) { - return null; - } - - term = toTerm(child(predicate)); - if (term == null) { - return null; - } - - return isNull(term); + return couldConvert(child(predicate)) ? isNull(toTerm(child(predicate))) : null; case NOT_NULL: - if (!couldConvert(child(predicate))) { - return null; - } - - term = toTerm(child(predicate)); - if (term == null) { - return null; - } - - return notNull(term); + return couldConvert(child(predicate)) ? notNull(toTerm(child(predicate))) : null; case LT: - PredicateChildren ltChildren = predicateChildren(predicate); - if (ltChildren == null) { - return null; - } - - if (ltChildren.termOnLeft) { - return lessThan(ltChildren.term, ltChildren.value); + if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return lessThan(term, convertLiteral(rightChild(predicate))); + } else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return greaterThan(term, convertLiteral(leftChild(predicate))); } else { - return greaterThan(ltChildren.term, ltChildren.value); - } - - case LT_EQ: - PredicateChildren ltEqChildren = predicateChildren(predicate); - if (ltEqChildren == null) { return null; } - if (ltEqChildren.termOnLeft) { - return lessThanOrEqual(ltEqChildren.term, ltEqChildren.value); + case LT_EQ: + if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return lessThanOrEqual(term, convertLiteral(rightChild(predicate))); + } else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return greaterThanOrEqual(term, convertLiteral(leftChild(predicate))); } else { - return greaterThanOrEqual(ltEqChildren.term, ltEqChildren.value); - } - - case GT: - PredicateChildren gtChildren = predicateChildren(predicate); - if (gtChildren == null) { return null; } - if (gtChildren.termOnLeft) { - return greaterThan(gtChildren.term, gtChildren.value); + case GT: + if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return greaterThan(term, convertLiteral(rightChild(predicate))); + } else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return lessThan(term, convertLiteral(leftChild(predicate))); } else { - return lessThan(gtChildren.term, gtChildren.value); - } - - case GT_EQ: - PredicateChildren gtEqChildren = predicateChildren(predicate); - if (gtEqChildren == null) { return null; } - if (gtEqChildren.termOnLeft) { - return greaterThanOrEqual(gtEqChildren.term, gtEqChildren.value); + case GT_EQ: + if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) { + UnboundTerm term = toTerm(leftChild(predicate)); + return greaterThanOrEqual(term, convertLiteral(rightChild(predicate))); + } else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) { + UnboundTerm term = toTerm(rightChild(predicate)); + return lessThanOrEqual(term, convertLiteral(leftChild(predicate))); } else { - return lessThanOrEqual(gtEqChildren.term, gtEqChildren.value); + return null; } case EQ: // used for both eq and null-safe-eq - PredicateChildren eqChildren = predicateChildren(predicate); + Pair, Object> eqChildren = predicateChildren(predicate); if (eqChildren == null) { return null; } @@ -210,45 +192,40 @@ public static Expression convert(Predicate predicate) { if (predicate.name().equals(EQ)) { // comparison with null in normal equality is always null. this is probably a mistake. Preconditions.checkNotNull( - eqChildren.value, + eqChildren.second(), "Expression is always false (eq is not null-safe): %s", predicate); } - return handleEqual(eqChildren.term, eqChildren.value); + return handleEqual(eqChildren.first(), eqChildren.second()); case NOT_EQ: - PredicateChildren notEqChildren = predicateChildren(predicate); + Pair, Object> notEqChildren = predicateChildren(predicate); if (notEqChildren == null) { return null; } // comparison with null in normal equality is always null. this is probably a mistake. Preconditions.checkNotNull( - notEqChildren.value, + notEqChildren.second(), "Expression is always false (notEq is not null-safe): %s", predicate); - return handleNotEqual(notEqChildren.term, notEqChildren.value); + return handleNotEqual(notEqChildren.first(), notEqChildren.second()); case IN: - if (!isSupportedInPredicate(predicate)) { - return null; - } - - term = toTerm(childAtIndex(predicate, 0)); - if (term == null) { + if (isSupportedInPredicate(predicate)) { + return in( + toTerm(childAtIndex(predicate, 0)), + Arrays.stream(predicate.children()) + .skip(1) + .map(val -> convertLiteral(((Literal) val))) + .filter(Objects::nonNull) + .collect(Collectors.toList())); + } else { return null; } - return in( - term, - Arrays.stream(predicate.children()) - .skip(1) - .map(val -> convertLiteral(((Literal) val))) - .filter(Objects::nonNull) - .collect(Collectors.toList())); - case NOT: Not notPredicate = (Not) predicate; Predicate childPredicate = notPredicate.child(); @@ -256,20 +233,15 @@ public static Expression convert(Predicate predicate) { // infer an extra notNull predicate for Spark NOT IN filters // as Iceberg expressions don't follow the 3-value SQL boolean logic // col NOT IN (1, 2) in Spark is equal to notNull(col) && notIn(col, 1, 2) in Iceberg - term = toTerm(childAtIndex(childPredicate, 0)); - if (term == null) { - return null; - } - Expression notIn = notIn( - term, + toTerm(childAtIndex(childPredicate, 0)), Arrays.stream(childPredicate.children()) .skip(1) .map(val -> convertLiteral(((Literal) val))) .filter(Objects::nonNull) .collect(Collectors.toList())); - return and(notNull(term), notIn); + return and(notNull(toTerm(childAtIndex(childPredicate, 0))), notIn); } else if (hasNoInFilter(childPredicate)) { Expression child = convert(childPredicate); if (child != null) { @@ -309,16 +281,16 @@ public static Expression convert(Predicate predicate) { return null; } - private static PredicateChildren predicateChildren(Predicate predicate) { + private static Pair, Object> predicateChildren(Predicate predicate) { if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) { UnboundTerm term = toTerm(leftChild(predicate)); Object value = convertLiteral(rightChild(predicate)); - return new PredicateChildren(term, value, true); + return Pair.of(term, value); } else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) { UnboundTerm term = toTerm(rightChild(predicate)); Object value = convertLiteral(leftChild(predicate)); - return new PredicateChildren(term, value, false); + return Pair.of(term, value); } else { return null; @@ -366,7 +338,7 @@ private static boolean isSystemFunc(org.apache.spark.sql.connector.expressions.E if (expr instanceof UserDefinedScalarFunc) { UserDefinedScalarFunc udf = (UserDefinedScalarFunc) expr; return udf.canonicalName().startsWith("iceberg") - && SparkFunctions.list().contains(udf.name()) + && FUNCTIONS_SUPPORT_CONVERT.contains(udf.name()) && Arrays.stream(udf.children()).allMatch(child -> isLiteral(child) || isRef(child)); } @@ -434,89 +406,66 @@ private static boolean isSupportedInPredicate(Predicate predicate) { } } + /** Should be called after {@link #couldConvert} passed */ private static UnboundTerm toTerm(T input) { if (input instanceof NamedReference) { return Expressions.ref(SparkUtil.toColumnName((NamedReference) input)); - - } else if (input instanceof UserDefinedScalarFunc) { - return udfToTerm((UserDefinedScalarFunc) input); - } else { - - return null; + return udfToTerm((UserDefinedScalarFunc) input); } } - @VisibleForTesting - @SuppressWarnings("unchecked") - static UnboundTerm udfToTerm(UserDefinedScalarFunc udf) { + private static UnboundTerm udfToTerm(UserDefinedScalarFunc udf) { switch (udf.name().toLowerCase(Locale.ROOT)) { case "years": Preconditions.checkArgument( udf.children().length == 1, "years function should have only one children (column)"); - if (isRef(udf.children()[0])) { - return year(SparkUtil.toColumnName((NamedReference) udf.children()[0])); - } - - return null; + Preconditions.checkArgument( + isRef(udf.children()[0]), + "The child of years function should be type of NamedReference"); + return year(SparkUtil.toColumnName((NamedReference) udf.children()[0])); case "months": Preconditions.checkArgument( udf.children().length == 1, "months function should have only one children (column)"); - if (isRef(udf.children()[0])) { - return month(SparkUtil.toColumnName((NamedReference) udf.children()[0])); - } - - return null; + Preconditions.checkArgument( + isRef(udf.children()[0]), + "The child of months function should be type of NamedReference"); + return month(SparkUtil.toColumnName((NamedReference) udf.children()[0])); case "days": Preconditions.checkArgument( udf.children().length == 1, "days function should have only one children (column)"); - if (isRef(udf.children()[0])) { - return day(SparkUtil.toColumnName((NamedReference) udf.children()[0])); - } - - return null; + Preconditions.checkArgument( + isRef(udf.children()[0]), + "The child of days function should be type of NamedReference"); + return day(SparkUtil.toColumnName((NamedReference) udf.children()[0])); case "hours": Preconditions.checkArgument( udf.children().length == 1, "hours function should have only one children (colum)"); - if (isRef(udf.children()[0])) { - return hour(SparkUtil.toColumnName((NamedReference) udf.children()[0])); - } - - return null; + Preconditions.checkArgument( + isRef(udf.children()[0]), + "The child of hours function should be type of NamedReference"); + return hour(SparkUtil.toColumnName((NamedReference) udf.children()[0])); case "bucket": Preconditions.checkArgument( udf.children().length == 2, "bucket function should have two children (numBuckets and column)"); - if (isLiteral(udf.children()[0]) && isRef(udf.children()[1])) { - int numBuckets = (Integer) convertLiteral((Literal) udf.children()[0]); - return bucket(SparkUtil.toColumnName((NamedReference) udf.children()[1]), numBuckets); - } - - return null; + Preconditions.checkArgument( + isLiteral(udf.children()[0]) && isRef(udf.children()[1]), + "The children's type of bucket function should be Literal and NamedReference"); + int numBuckets = (Integer) convertLiteral((Literal) udf.children()[0]); + return bucket(SparkUtil.toColumnName((NamedReference) udf.children()[1]), numBuckets); case "truncate": Preconditions.checkArgument( udf.children().length == 2, "truncate function should have two children (width and column)"); - if (isLiteral(udf.children()[0]) && isRef(udf.children()[1])) { - int width = (Integer) convertLiteral((Literal) udf.children()[0]); - return truncate(SparkUtil.toColumnName((NamedReference) udf.children()[1]), width); - } - - return null; + Preconditions.checkArgument( + isLiteral(udf.children()[0]) && isRef(udf.children()[1]), + "The children's type of truncate function should be Literal and NamedReference"); + int width = (Integer) convertLiteral((Literal) udf.children()[0]); + return truncate(SparkUtil.toColumnName((NamedReference) udf.children()[1]), width); default: - return null; - } - } - - private static class PredicateChildren { - private final UnboundTerm term; - private final Object value; - private final boolean termOnLeft; - - PredicateChildren(UnboundTerm term, Object value, boolean termOnLeft) { - this.term = term; - this.value = value; - this.termOnLeft = termOnLeft; + // Should not reach here + throw new RuntimeException("Unsupported system function: " + udf.canonicalName()); } } } diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java index 835efbad0211..16a3be5d2bff 100644 --- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkV2Filters.java @@ -29,6 +29,7 @@ import org.apache.iceberg.spark.functions.BucketFunction; import org.apache.iceberg.spark.functions.DaysFunction; import org.apache.iceberg.spark.functions.HoursFunction; +import org.apache.iceberg.spark.functions.IcebergVersionFunction; import org.apache.iceberg.spark.functions.MonthsFunction; import org.apache.iceberg.spark.functions.TruncateFunction; import org.apache.iceberg.spark.functions.YearsFunction; @@ -44,6 +45,7 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.UTF8String; import org.assertj.core.api.Assertions; import org.junit.Assert; @@ -562,6 +564,22 @@ public void testTruncate() { testUDF(udf, Expressions.truncate("col1", 6), "prefix", DataTypes.StringType); } + @Test + public void testUnsupportedUDFConvert() { + ScalarFunction icebergVersionFunc = + (ScalarFunction) new IcebergVersionFunction().bind(new StructType()); + UserDefinedScalarFunc udf = + new UserDefinedScalarFunc( + icebergVersionFunc.name(), + icebergVersionFunc.canonicalName(), + new org.apache.spark.sql.connector.expressions.Expression[] {}); + LiteralValue literalValue = new LiteralValue("1.3.0", DataTypes.StringType); + Predicate predicate = new Predicate("=", expressions(udf, literalValue)); + + Expression icebergExpr = SparkV2Filters.convert(predicate); + Assertions.assertThat(icebergExpr).isNull(); + } + private void testUDF( org.apache.spark.sql.connector.expressions.Expression expression, UnboundTerm item,