Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ConeyLiu committed Jul 1, 2023
1 parent 87fcc11 commit b8cf269
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,19 +252,19 @@ public static Expression convert(Predicate predicate) {
}

/**
 * Extracts the (term, literal) pair from a binary comparison predicate.
 *
 * <p>Handles both argument orders: {@code col op literal} and {@code literal op col}. For the
 * reversed order the caller is responsible for flipping the comparison operator.
 *
 * @param predicate a Spark V2 predicate with exactly two children
 * @return the unbound term and the converted literal value, or {@code null} when the predicate is
 *     not a comparison between a single column reference and a literal
 */
private static Pair<UnboundTerm<Object>, Object> predicateChildren(Predicate predicate) {
  if (isRef(leftChild(predicate)) && isLiteral(rightChild(predicate))) {
    // column op literal
    UnboundTerm<Object> term = ref(SparkUtil.toColumnName(leftChild(predicate)));
    Object value = convertLiteral(rightChild(predicate));
    return Pair.of(term, value);

  } else if (isRef(rightChild(predicate)) && isLiteral(leftChild(predicate))) {
    // literal op column
    UnboundTerm<Object> term = ref(SparkUtil.toColumnName(rightChild(predicate)));
    Object value = convertLiteral(leftChild(predicate));
    return Pair.of(term, value);

  } else {
    // Not a ref-vs-literal comparison (e.g. column op column) — cannot convert.
    return null;
  }
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -314,9 +314,7 @@ private static Object convertLiteral(Literal<?> literal) {
private static UnboundPredicate<Object> handleEqual(UnboundTerm<Object> term, Object value) {
if (value == null) {
return isNull(term);
}

if (NaNUtil.isNaN(value)) {
} else if (NaNUtil.isNaN(value)) {
return isNaN(term);
} else {
return equal(term, value);
Expand All @@ -326,9 +324,7 @@ private static UnboundPredicate<Object> handleEqual(UnboundTerm<Object> term, Ob
private static UnboundPredicate<Object> handleNotEqual(UnboundTerm<Object> term, Object value) {
if (value == null) {
return notNull(term);
}

if (NaNUtil.isNaN(value)) {
} else if (NaNUtil.isNaN(value)) {
return notNaN(term);
} else {
return notEqual(term, value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@

public class TestSparkV2Filters {

@SuppressWarnings("checkstyle:MethodLength")
@Test
public void testV2Filters() {
Map<String, String> attrMap = Maps.newHashMap();
Expand All @@ -49,157 +48,150 @@ public void testV2Filters() {
attrMap.put("`d`.b.`dd```", "d.b.dd`");
attrMap.put("a.`aa```.c", "a.aa`.c");

attrMap.forEach(
(quoted, unquoted) -> {
NamedReference namedReference = FieldReference.apply(quoted);
org.apache.spark.sql.connector.expressions.Expression[] attrOnly =
new org.apache.spark.sql.connector.expressions.Expression[] {namedReference};

LiteralValue value = new LiteralValue(1, DataTypes.IntegerType);
org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, value};
org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
new org.apache.spark.sql.connector.expressions.Expression[] {value, namedReference};

Predicate isNull = new Predicate("IS_NULL", attrOnly);
Expression expectedIsNull = Expressions.isNull(unquoted);
Expression actualIsNull = SparkV2Filters.convert(isNull);
Assert.assertEquals(
"IsNull must match", expectedIsNull.toString(), actualIsNull.toString());

Predicate isNotNull = new Predicate("IS_NOT_NULL", attrOnly);
Expression expectedIsNotNull = Expressions.notNull(unquoted);
Expression actualIsNotNull = SparkV2Filters.convert(isNotNull);
Assert.assertEquals(
"IsNotNull must match", expectedIsNotNull.toString(), actualIsNotNull.toString());

Predicate lt1 = new Predicate("<", attrAndValue);
Expression expectedLt1 = Expressions.lessThan(unquoted, 1);
Expression actualLt1 = SparkV2Filters.convert(lt1);
Assert.assertEquals("LessThan must match", expectedLt1.toString(), actualLt1.toString());

Predicate lt2 = new Predicate("<", valueAndAttr);
Expression expectedLt2 = Expressions.greaterThan(unquoted, 1);
Expression actualLt2 = SparkV2Filters.convert(lt2);
Assert.assertEquals("LessThan must match", expectedLt2.toString(), actualLt2.toString());

Predicate ltEq1 = new Predicate("<=", attrAndValue);
Expression expectedLtEq1 = Expressions.lessThanOrEqual(unquoted, 1);
Expression actualLtEq1 = SparkV2Filters.convert(ltEq1);
Assert.assertEquals(
"LessThanOrEqual must match", expectedLtEq1.toString(), actualLtEq1.toString());

Predicate ltEq2 = new Predicate("<=", valueAndAttr);
Expression expectedLtEq2 = Expressions.greaterThanOrEqual(unquoted, 1);
Expression actualLtEq2 = SparkV2Filters.convert(ltEq2);
Assert.assertEquals(
"LessThanOrEqual must match", expectedLtEq2.toString(), actualLtEq2.toString());

Predicate gt1 = new Predicate(">", attrAndValue);
Expression expectedGt1 = Expressions.greaterThan(unquoted, 1);
Expression actualGt1 = SparkV2Filters.convert(gt1);
Assert.assertEquals(
"GreaterThan must match", expectedGt1.toString(), actualGt1.toString());

Predicate gt2 = new Predicate(">", valueAndAttr);
Expression expectedGt2 = Expressions.lessThan(unquoted, 1);
Expression actualGt2 = SparkV2Filters.convert(gt2);
Assert.assertEquals(
"GreaterThan must match", expectedGt2.toString(), actualGt2.toString());

Predicate gtEq1 = new Predicate(">=", attrAndValue);
Expression expectedGtEq1 = Expressions.greaterThanOrEqual(unquoted, 1);
Expression actualGtEq1 = SparkV2Filters.convert(gtEq1);
Assert.assertEquals(
"GreaterThanOrEqual must match", expectedGtEq1.toString(), actualGtEq1.toString());

Predicate gtEq2 = new Predicate(">=", valueAndAttr);
Expression expectedGtEq2 = Expressions.lessThanOrEqual(unquoted, 1);
Expression actualGtEq2 = SparkV2Filters.convert(gtEq2);
Assert.assertEquals(
"GreaterThanOrEqual must match", expectedGtEq2.toString(), actualGtEq2.toString());

Predicate eq1 = new Predicate("=", attrAndValue);
Expression expectedEq1 = Expressions.equal(unquoted, 1);
Expression actualEq1 = SparkV2Filters.convert(eq1);
Assert.assertEquals("EqualTo must match", expectedEq1.toString(), actualEq1.toString());

Predicate eq2 = new Predicate("=", valueAndAttr);
Expression expectedEq2 = Expressions.equal(unquoted, 1);
Expression actualEq2 = SparkV2Filters.convert(eq2);
Assert.assertEquals("EqualTo must match", expectedEq2.toString(), actualEq2.toString());

Predicate notEq1 = new Predicate("<>", attrAndValue);
Expression expectedNotEq1 = Expressions.notEqual(unquoted, 1);
Expression actualNotEq1 = SparkV2Filters.convert(notEq1);
Assert.assertEquals(
"NotEqualTo must match", expectedNotEq1.toString(), actualNotEq1.toString());

Predicate notEq2 = new Predicate("<>", valueAndAttr);
Expression expectedNotEq2 = Expressions.notEqual(unquoted, 1);
Expression actualNotEq2 = SparkV2Filters.convert(notEq2);
Assert.assertEquals(
"NotEqualTo must match", expectedNotEq2.toString(), actualNotEq2.toString());

Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue);
Expression expectedEqNullSafe1 = Expressions.equal(unquoted, 1);
Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1);
Assert.assertEquals(
"EqualNullSafe must match",
expectedEqNullSafe1.toString(),
actualEqNullSafe1.toString());

Predicate eqNullSafe2 = new Predicate("<=>", valueAndAttr);
Expression expectedEqNullSafe2 = Expressions.equal(unquoted, 1);
Expression actualEqNullSafe2 = SparkV2Filters.convert(eqNullSafe2);
Assert.assertEquals(
"EqualNullSafe must match",
expectedEqNullSafe2.toString(),
actualEqNullSafe2.toString());

LiteralValue str =
new LiteralValue(UTF8String.fromString("iceberg"), DataTypes.StringType);
org.apache.spark.sql.connector.expressions.Expression[] attrAndStr =
new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, str};
Predicate startsWith = new Predicate("STARTS_WITH", attrAndStr);
Expression expectedStartsWith = Expressions.startsWith(unquoted, "iceberg");
Expression actualStartsWith = SparkV2Filters.convert(startsWith);
Assert.assertEquals(
"StartsWith must match", expectedStartsWith.toString(), actualStartsWith.toString());

Predicate in = new Predicate("IN", attrAndValue);
Expression expectedIn = Expressions.in(unquoted, 1);
Expression actualIn = SparkV2Filters.convert(in);
Assert.assertEquals("In must match", expectedIn.toString(), actualIn.toString());

Predicate and = new And(lt1, eq1);
Expression expectedAnd = Expressions.and(expectedLt1, expectedEq1);
Expression actualAnd = SparkV2Filters.convert(and);
Assert.assertEquals("And must match", expectedAnd.toString(), actualAnd.toString());

org.apache.spark.sql.connector.expressions.Expression[] attrAndAttr =
new org.apache.spark.sql.connector.expressions.Expression[] {
namedReference, namedReference
};
Predicate invalid = new Predicate("<", attrAndAttr);
Predicate andWithInvalidLeft = new And(invalid, eq1);
Expression convertedAnd = SparkV2Filters.convert(andWithInvalidLeft);
Assert.assertEquals("And must match", convertedAnd, null);

Predicate or = new Or(lt1, eq1);
Expression expectedOr = Expressions.or(expectedLt1, expectedEq1);
Expression actualOr = SparkV2Filters.convert(or);
Assert.assertEquals("Or must match", expectedOr.toString(), actualOr.toString());

Predicate orWithInvalidLeft = new Or(invalid, eq1);
Expression convertedOr = SparkV2Filters.convert(orWithInvalidLeft);
Assert.assertEquals("Or must match", convertedOr, null);

Predicate not = new Not(lt1);
Expression expectedNot = Expressions.not(expectedLt1);
Expression actualNot = SparkV2Filters.convert(not);
Assert.assertEquals("Not must match", expectedNot.toString(), actualNot.toString());
});
attrMap.forEach(this::testV2Filter);
}

/**
 * Asserts that every supported Spark V2 predicate over {@code quoted} converts to the expected
 * Iceberg expression over {@code unquoted}.
 *
 * <p>Covers null checks, all comparison operators in both argument orders (column-first and
 * literal-first, which flips the operator), null-safe equality, STARTS_WITH, IN, and the AND/OR/NOT
 * combinators, including the case where one side of AND/OR is unconvertible.
 *
 * @param quoted the attribute name with back-tick quoting, fed to Spark's {@code FieldReference}
 * @param unquoted the plain column name expected in the converted Iceberg expression
 */
private void testV2Filter(String quoted, String unquoted) {
  NamedReference namedReference = FieldReference.apply(quoted);
  org.apache.spark.sql.connector.expressions.Expression[] attrOnly =
      new org.apache.spark.sql.connector.expressions.Expression[] {namedReference};

  LiteralValue value = new LiteralValue(1, DataTypes.IntegerType);
  org.apache.spark.sql.connector.expressions.Expression[] attrAndValue =
      new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, value};
  org.apache.spark.sql.connector.expressions.Expression[] valueAndAttr =
      new org.apache.spark.sql.connector.expressions.Expression[] {value, namedReference};

  Predicate isNull = new Predicate("IS_NULL", attrOnly);
  Expression expectedIsNull = Expressions.isNull(unquoted);
  Expression actualIsNull = SparkV2Filters.convert(isNull);
  Assert.assertEquals("IsNull must match", expectedIsNull.toString(), actualIsNull.toString());

  Predicate isNotNull = new Predicate("IS_NOT_NULL", attrOnly);
  Expression expectedIsNotNull = Expressions.notNull(unquoted);
  Expression actualIsNotNull = SparkV2Filters.convert(isNotNull);
  Assert.assertEquals(
      "IsNotNull must match", expectedIsNotNull.toString(), actualIsNotNull.toString());

  Predicate lt1 = new Predicate("<", attrAndValue);
  Expression expectedLt1 = Expressions.lessThan(unquoted, 1);
  Expression actualLt1 = SparkV2Filters.convert(lt1);
  Assert.assertEquals("LessThan must match", expectedLt1.toString(), actualLt1.toString());

  // literal-first order flips the operator: 1 < col  ==>  col > 1
  Predicate lt2 = new Predicate("<", valueAndAttr);
  Expression expectedLt2 = Expressions.greaterThan(unquoted, 1);
  Expression actualLt2 = SparkV2Filters.convert(lt2);
  Assert.assertEquals("LessThan must match", expectedLt2.toString(), actualLt2.toString());

  Predicate ltEq1 = new Predicate("<=", attrAndValue);
  Expression expectedLtEq1 = Expressions.lessThanOrEqual(unquoted, 1);
  Expression actualLtEq1 = SparkV2Filters.convert(ltEq1);
  Assert.assertEquals(
      "LessThanOrEqual must match", expectedLtEq1.toString(), actualLtEq1.toString());

  Predicate ltEq2 = new Predicate("<=", valueAndAttr);
  Expression expectedLtEq2 = Expressions.greaterThanOrEqual(unquoted, 1);
  Expression actualLtEq2 = SparkV2Filters.convert(ltEq2);
  Assert.assertEquals(
      "LessThanOrEqual must match", expectedLtEq2.toString(), actualLtEq2.toString());

  Predicate gt1 = new Predicate(">", attrAndValue);
  Expression expectedGt1 = Expressions.greaterThan(unquoted, 1);
  Expression actualGt1 = SparkV2Filters.convert(gt1);
  Assert.assertEquals("GreaterThan must match", expectedGt1.toString(), actualGt1.toString());

  Predicate gt2 = new Predicate(">", valueAndAttr);
  Expression expectedGt2 = Expressions.lessThan(unquoted, 1);
  Expression actualGt2 = SparkV2Filters.convert(gt2);
  Assert.assertEquals("GreaterThan must match", expectedGt2.toString(), actualGt2.toString());

  Predicate gtEq1 = new Predicate(">=", attrAndValue);
  Expression expectedGtEq1 = Expressions.greaterThanOrEqual(unquoted, 1);
  Expression actualGtEq1 = SparkV2Filters.convert(gtEq1);
  Assert.assertEquals(
      "GreaterThanOrEqual must match", expectedGtEq1.toString(), actualGtEq1.toString());

  Predicate gtEq2 = new Predicate(">=", valueAndAttr);
  Expression expectedGtEq2 = Expressions.lessThanOrEqual(unquoted, 1);
  Expression actualGtEq2 = SparkV2Filters.convert(gtEq2);
  Assert.assertEquals(
      "GreaterThanOrEqual must match", expectedGtEq2.toString(), actualGtEq2.toString());

  Predicate eq1 = new Predicate("=", attrAndValue);
  Expression expectedEq1 = Expressions.equal(unquoted, 1);
  Expression actualEq1 = SparkV2Filters.convert(eq1);
  Assert.assertEquals("EqualTo must match", expectedEq1.toString(), actualEq1.toString());

  Predicate eq2 = new Predicate("=", valueAndAttr);
  Expression expectedEq2 = Expressions.equal(unquoted, 1);
  Expression actualEq2 = SparkV2Filters.convert(eq2);
  Assert.assertEquals("EqualTo must match", expectedEq2.toString(), actualEq2.toString());

  Predicate notEq1 = new Predicate("<>", attrAndValue);
  Expression expectedNotEq1 = Expressions.notEqual(unquoted, 1);
  Expression actualNotEq1 = SparkV2Filters.convert(notEq1);
  Assert.assertEquals(
      "NotEqualTo must match", expectedNotEq1.toString(), actualNotEq1.toString());

  Predicate notEq2 = new Predicate("<>", valueAndAttr);
  Expression expectedNotEq2 = Expressions.notEqual(unquoted, 1);
  Expression actualNotEq2 = SparkV2Filters.convert(notEq2);
  Assert.assertEquals(
      "NotEqualTo must match", expectedNotEq2.toString(), actualNotEq2.toString());

  Predicate eqNullSafe1 = new Predicate("<=>", attrAndValue);
  Expression expectedEqNullSafe1 = Expressions.equal(unquoted, 1);
  Expression actualEqNullSafe1 = SparkV2Filters.convert(eqNullSafe1);
  Assert.assertEquals(
      "EqualNullSafe must match", expectedEqNullSafe1.toString(), actualEqNullSafe1.toString());

  Predicate eqNullSafe2 = new Predicate("<=>", valueAndAttr);
  Expression expectedEqNullSafe2 = Expressions.equal(unquoted, 1);
  Expression actualEqNullSafe2 = SparkV2Filters.convert(eqNullSafe2);
  Assert.assertEquals(
      "EqualNullSafe must match", expectedEqNullSafe2.toString(), actualEqNullSafe2.toString());

  LiteralValue str = new LiteralValue(UTF8String.fromString("iceberg"), DataTypes.StringType);
  org.apache.spark.sql.connector.expressions.Expression[] attrAndStr =
      new org.apache.spark.sql.connector.expressions.Expression[] {namedReference, str};
  Predicate startsWith = new Predicate("STARTS_WITH", attrAndStr);
  Expression expectedStartsWith = Expressions.startsWith(unquoted, "iceberg");
  Expression actualStartsWith = SparkV2Filters.convert(startsWith);
  Assert.assertEquals(
      "StartsWith must match", expectedStartsWith.toString(), actualStartsWith.toString());

  Predicate in = new Predicate("IN", attrAndValue);
  Expression expectedIn = Expressions.in(unquoted, 1);
  Expression actualIn = SparkV2Filters.convert(in);
  Assert.assertEquals("In must match", expectedIn.toString(), actualIn.toString());

  Predicate and = new And(lt1, eq1);
  Expression expectedAnd = Expressions.and(expectedLt1, expectedEq1);
  Expression actualAnd = SparkV2Filters.convert(and);
  Assert.assertEquals("And must match", expectedAnd.toString(), actualAnd.toString());

  // A column-vs-column comparison is unconvertible; AND/OR containing it must convert to null.
  org.apache.spark.sql.connector.expressions.Expression[] attrAndAttr =
      new org.apache.spark.sql.connector.expressions.Expression[] {
        namedReference, namedReference
      };
  Predicate invalid = new Predicate("<", attrAndAttr);
  Predicate andWithInvalidLeft = new And(invalid, eq1);
  Expression convertedAnd = SparkV2Filters.convert(andWithInvalidLeft);
  // Fixed: use assertNull instead of assertEquals with swapped (actual, expected) arguments.
  Assert.assertNull("And must match", convertedAnd);

  Predicate or = new Or(lt1, eq1);
  Expression expectedOr = Expressions.or(expectedLt1, expectedEq1);
  Expression actualOr = SparkV2Filters.convert(or);
  Assert.assertEquals("Or must match", expectedOr.toString(), actualOr.toString());

  Predicate orWithInvalidLeft = new Or(invalid, eq1);
  Expression convertedOr = SparkV2Filters.convert(orWithInvalidLeft);
  // Fixed: use assertNull instead of assertEquals with swapped (actual, expected) arguments.
  Assert.assertNull("Or must match", convertedOr);

  Predicate not = new Not(lt1);
  Expression expectedNot = Expressions.not(expectedLt1);
  Expression actualNot = SparkV2Filters.convert(not);
  Assert.assertEquals("Not must match", expectedNot.toString(), actualNot.toString());
}

@Test
Expand Down

0 comments on commit b8cf269

Please sign in to comment.