Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ConeyLiu committed Jul 6, 2023
1 parent eb87b00 commit 89424f7
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 141 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,18 @@
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expression.Operation;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.expressions.UnboundPredicate;
import org.apache.iceberg.expressions.UnboundTerm;
import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.spark.functions.SparkFunctions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
import org.apache.iceberg.util.NaNUtil;
import org.apache.iceberg.util.Pair;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
Expand All @@ -67,6 +68,9 @@

public class SparkV2Filters {

private static final Set<String> FUNCTIONS_SUPPORT_CONVERT =
ImmutableSet.of("years", "months", "days", "hours", "bucket", "truncate");

private static final String TRUE = "ALWAYS_TRUE";
private static final String FALSE = "ALWAYS_FALSE";
private static final String EQ = "=";
Expand Down Expand Up @@ -111,16 +115,16 @@ public static Expression convert(Predicate[] predicates) {
for (Predicate predicate : predicates) {
Expression converted = convert(predicate);
Preconditions.checkArgument(
converted != null, "Cannot convert predicate to Iceberg: %s", predicate);
converted != null, "Cannot convert Spark predicate to Iceberg expression: %s", predicate);
expression = Expressions.and(expression, converted);
}

return expression;
}

@SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"})
public static Expression convert(Predicate predicate) {
Operation op = FILTERS.get(predicate.name());
UnboundTerm<Object> term = null;
if (op != null) {
switch (op) {
case TRUE:
Expand All @@ -130,146 +134,114 @@ public static Expression convert(Predicate predicate) {
return Expressions.alwaysFalse();

case IS_NULL:
if (!couldConvert(child(predicate))) {
return null;
}

term = toTerm(child(predicate));
if (term == null) {
return null;
}

return isNull(term);
return couldConvert(child(predicate)) ? isNull(toTerm(child(predicate))) : null;

case NOT_NULL:
if (!couldConvert(child(predicate))) {
return null;
}

term = toTerm(child(predicate));
if (term == null) {
return null;
}

return notNull(term);
return couldConvert(child(predicate)) ? notNull(toTerm(child(predicate))) : null;

case LT:
PredicateChildren ltChildren = predicateChildren(predicate);
if (ltChildren == null) {
return null;
}

if (ltChildren.termOnLeft) {
return lessThan(ltChildren.term, ltChildren.value);
if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
UnboundTerm<Object> term = toTerm(leftChild(predicate));
return lessThan(term, convertLiteral(rightChild(predicate)));
} else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
UnboundTerm<Object> term = toTerm(rightChild(predicate));
return greaterThan(term, convertLiteral(leftChild(predicate)));
} else {
return greaterThan(ltChildren.term, ltChildren.value);
}

case LT_EQ:
PredicateChildren ltEqChildren = predicateChildren(predicate);
if (ltEqChildren == null) {
return null;
}

if (ltEqChildren.termOnLeft) {
return lessThanOrEqual(ltEqChildren.term, ltEqChildren.value);
case LT_EQ:
if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
UnboundTerm<Object> term = toTerm(leftChild(predicate));
return lessThanOrEqual(term, convertLiteral(rightChild(predicate)));
} else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
UnboundTerm<Object> term = toTerm(rightChild(predicate));
return greaterThanOrEqual(term, convertLiteral(leftChild(predicate)));
} else {
return greaterThanOrEqual(ltEqChildren.term, ltEqChildren.value);
}

case GT:
PredicateChildren gtChildren = predicateChildren(predicate);
if (gtChildren == null) {
return null;
}

if (gtChildren.termOnLeft) {
return greaterThan(gtChildren.term, gtChildren.value);
case GT:
if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
UnboundTerm<Object> term = toTerm(leftChild(predicate));
return greaterThan(term, convertLiteral(rightChild(predicate)));
} else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
UnboundTerm<Object> term = toTerm(rightChild(predicate));
return lessThan(term, convertLiteral(leftChild(predicate)));
} else {
return lessThan(gtChildren.term, gtChildren.value);
}

case GT_EQ:
PredicateChildren gtEqChildren = predicateChildren(predicate);
if (gtEqChildren == null) {
return null;
}

if (gtEqChildren.termOnLeft) {
return greaterThanOrEqual(gtEqChildren.term, gtEqChildren.value);
case GT_EQ:
if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
UnboundTerm<Object> term = toTerm(leftChild(predicate));
return greaterThanOrEqual(term, convertLiteral(rightChild(predicate)));
} else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
UnboundTerm<Object> term = toTerm(rightChild(predicate));
return lessThanOrEqual(term, convertLiteral(leftChild(predicate)));
} else {
return lessThanOrEqual(gtEqChildren.term, gtEqChildren.value);
return null;
}

case EQ: // used for both eq and null-safe-eq
PredicateChildren eqChildren = predicateChildren(predicate);
Pair<UnboundTerm<Object>, Object> eqChildren = predicateChildren(predicate);
if (eqChildren == null) {
return null;
}

if (predicate.name().equals(EQ)) {
// comparison with null in normal equality is always null. this is probably a mistake.
Preconditions.checkNotNull(
eqChildren.value,
eqChildren.second(),
"Expression is always false (eq is not null-safe): %s",
predicate);
}

return handleEqual(eqChildren.term, eqChildren.value);
return handleEqual(eqChildren.first(), eqChildren.second());

case NOT_EQ:
PredicateChildren notEqChildren = predicateChildren(predicate);
Pair<UnboundTerm<Object>, Object> notEqChildren = predicateChildren(predicate);
if (notEqChildren == null) {
return null;
}

// comparison with null in normal equality is always null. this is probably a mistake.
Preconditions.checkNotNull(
notEqChildren.value,
notEqChildren.second(),
"Expression is always false (notEq is not null-safe): %s",
predicate);

return handleNotEqual(notEqChildren.term, notEqChildren.value);
return handleNotEqual(notEqChildren.first(), notEqChildren.second());

case IN:
if (!isSupportedInPredicate(predicate)) {
return null;
}

term = toTerm(childAtIndex(predicate, 0));
if (term == null) {
if (isSupportedInPredicate(predicate)) {
return in(
toTerm(childAtIndex(predicate, 0)),
Arrays.stream(predicate.children())
.skip(1)
.map(val -> convertLiteral(((Literal<?>) val)))
.filter(Objects::nonNull)
.collect(Collectors.toList()));
} else {
return null;
}

return in(
term,
Arrays.stream(predicate.children())
.skip(1)
.map(val -> convertLiteral(((Literal<?>) val)))
.filter(Objects::nonNull)
.collect(Collectors.toList()));

case NOT:
Not notPredicate = (Not) predicate;
Predicate childPredicate = notPredicate.child();
if (childPredicate.name().equals(IN) && isSupportedInPredicate(childPredicate)) {
// infer an extra notNull predicate for Spark NOT IN filters
// as Iceberg expressions don't follow the 3-value SQL boolean logic
// col NOT IN (1, 2) in Spark is equal to notNull(col) && notIn(col, 1, 2) in Iceberg
term = toTerm(childAtIndex(childPredicate, 0));
if (term == null) {
return null;
}

Expression notIn =
notIn(
term,
toTerm(childAtIndex(childPredicate, 0)),
Arrays.stream(childPredicate.children())
.skip(1)
.map(val -> convertLiteral(((Literal<?>) val)))
.filter(Objects::nonNull)
.collect(Collectors.toList()));
return and(notNull(term), notIn);
return and(notNull(toTerm(childAtIndex(childPredicate, 0))), notIn);
} else if (hasNoInFilter(childPredicate)) {
Expression child = convert(childPredicate);
if (child != null) {
Expand Down Expand Up @@ -309,16 +281,16 @@ public static Expression convert(Predicate predicate) {
return null;
}

private static PredicateChildren predicateChildren(Predicate predicate) {
private static Pair<UnboundTerm<Object>, Object> predicateChildren(Predicate predicate) {
if (couldConvert(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
UnboundTerm<Object> term = toTerm(leftChild(predicate));
Object value = convertLiteral(rightChild(predicate));
return new PredicateChildren(term, value, true);
return Pair.of(term, value);

} else if (couldConvert(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
UnboundTerm<Object> term = toTerm(rightChild(predicate));
Object value = convertLiteral(leftChild(predicate));
return new PredicateChildren(term, value, false);
return Pair.of(term, value);

} else {
return null;
Expand Down Expand Up @@ -366,7 +338,7 @@ private static boolean isSystemFunc(org.apache.spark.sql.connector.expressions.E
if (expr instanceof UserDefinedScalarFunc) {
UserDefinedScalarFunc udf = (UserDefinedScalarFunc) expr;
return udf.canonicalName().startsWith("iceberg")
&& SparkFunctions.list().contains(udf.name())
&& FUNCTIONS_SUPPORT_CONVERT.contains(udf.name())
&& Arrays.stream(udf.children()).allMatch(child -> isLiteral(child) || isRef(child));
}

Expand Down Expand Up @@ -434,89 +406,66 @@ private static boolean isSupportedInPredicate(Predicate predicate) {
}
}

/** Should be called after {@link #couldConvert} passed */
private static <T> UnboundTerm<Object> toTerm(T input) {
if (input instanceof NamedReference) {
return Expressions.ref(SparkUtil.toColumnName((NamedReference) input));

} else if (input instanceof UserDefinedScalarFunc) {
return udfToTerm((UserDefinedScalarFunc) input);

} else {

return null;
return udfToTerm((UserDefinedScalarFunc) input);
}
}

@VisibleForTesting
@SuppressWarnings("unchecked")
static UnboundTerm<Object> udfToTerm(UserDefinedScalarFunc udf) {
private static UnboundTerm<Object> udfToTerm(UserDefinedScalarFunc udf) {
switch (udf.name().toLowerCase(Locale.ROOT)) {
case "years":
Preconditions.checkArgument(
udf.children().length == 1, "years function should have only one children (column)");
if (isRef(udf.children()[0])) {
return year(SparkUtil.toColumnName((NamedReference) udf.children()[0]));
}

return null;
Preconditions.checkArgument(
isRef(udf.children()[0]),
"The child of years function should be type of NamedReference");
return year(SparkUtil.toColumnName((NamedReference) udf.children()[0]));
case "months":
Preconditions.checkArgument(
udf.children().length == 1, "months function should have only one children (column)");
if (isRef(udf.children()[0])) {
return month(SparkUtil.toColumnName((NamedReference) udf.children()[0]));
}

return null;
Preconditions.checkArgument(
isRef(udf.children()[0]),
"The child of months function should be type of NamedReference");
return month(SparkUtil.toColumnName((NamedReference) udf.children()[0]));
case "days":
Preconditions.checkArgument(
udf.children().length == 1, "days function should have only one children (column)");
if (isRef(udf.children()[0])) {
return day(SparkUtil.toColumnName((NamedReference) udf.children()[0]));
}

return null;
Preconditions.checkArgument(
isRef(udf.children()[0]),
"The child of days function should be type of NamedReference");
return day(SparkUtil.toColumnName((NamedReference) udf.children()[0]));
case "hours":
Preconditions.checkArgument(
udf.children().length == 1, "hours function should have only one children (colum)");
if (isRef(udf.children()[0])) {
return hour(SparkUtil.toColumnName((NamedReference) udf.children()[0]));
}

return null;
Preconditions.checkArgument(
isRef(udf.children()[0]),
"The child of hours function should be type of NamedReference");
return hour(SparkUtil.toColumnName((NamedReference) udf.children()[0]));
case "bucket":
Preconditions.checkArgument(
udf.children().length == 2,
"bucket function should have two children (numBuckets and column)");
if (isLiteral(udf.children()[0]) && isRef(udf.children()[1])) {
int numBuckets = (Integer) convertLiteral((Literal<?>) udf.children()[0]);
return bucket(SparkUtil.toColumnName((NamedReference) udf.children()[1]), numBuckets);
}

return null;
Preconditions.checkArgument(
isLiteral(udf.children()[0]) && isRef(udf.children()[1]),
"The children's type of bucket function should be Literal and NamedReference");
int numBuckets = (Integer) convertLiteral((Literal<?>) udf.children()[0]);
return bucket(SparkUtil.toColumnName((NamedReference) udf.children()[1]), numBuckets);
case "truncate":
Preconditions.checkArgument(
udf.children().length == 2,
"truncate function should have two children (width and column)");
if (isLiteral(udf.children()[0]) && isRef(udf.children()[1])) {
int width = (Integer) convertLiteral((Literal<?>) udf.children()[0]);
return truncate(SparkUtil.toColumnName((NamedReference) udf.children()[1]), width);
}

return null;
Preconditions.checkArgument(
isLiteral(udf.children()[0]) && isRef(udf.children()[1]),
"The children's type of truncate function should be Literal and NamedReference");
int width = (Integer) convertLiteral((Literal<?>) udf.children()[0]);
return truncate(SparkUtil.toColumnName((NamedReference) udf.children()[1]), width);
default:
return null;
}
}

private static class PredicateChildren {
private final UnboundTerm<Object> term;
private final Object value;
private final boolean termOnLeft;

PredicateChildren(UnboundTerm<Object> term, Object value, boolean termOnLeft) {
this.term = term;
this.value = value;
this.termOnLeft = termOnLeft;
// Should not reach here
throw new RuntimeException("Unsupported system function: " + udf.canonicalName());
}
}
}
Loading

0 comments on commit 89424f7

Please sign in to comment.