Skip to content

Commit

Permalink
[SPARK-39857][SQL] V2ExpressionBuilder uses the wrong LiteralValue da…
Browse files Browse the repository at this point in the history
…ta type for In predicate (#535)

### What changes were proposed in this pull request?
When building the V2 `In` predicate in `V2ExpressionBuilder`, `InSet.dataType` (which is `BooleanType`) is currently used to build the `LiteralValue`; `InSet.child.dataType` should be used instead.

### Why are the changes needed?
bug fix

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
new test

Closes apache#37271 from huaxingao/inset.

Authored-by: huaxingao <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>

Signed-off-by: Dongjoon Hyun <[email protected]>
Co-authored-by: huaxingao <[email protected]>
  • Loading branch information
2 people authored and leejaywei committed Oct 18, 2022
1 parent 04f5c40 commit ff17bff
Showing 1 changed file with 226 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
val attrInts = Seq(
Expand Down Expand Up @@ -55,8 +56,37 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
"a.b.cint" // three level nested field
))

test("SPARK-39784: translate binary expression") { attrInts
.foreach { case (attrInt, intColName) =>
val attrStrs = {
  // Shorthand for the nullable leaf fields used by the struct schemas below.
  def strField(name: String): StructField = StructField(name, StringType, nullable = true)
  def intField(name: String): StructField = StructField(name, IntegerType, nullable = true)

  // String-typed attributes, ranging from a plain column to a three-level nested field.
  val stringExprs = Seq(
    $"cstr".string,
    $"c.str".string,
    GetStructField(
      $"a".struct(StructType(Seq(intField("cint"), strField("cstr")))), 1, None),
    GetStructField(
      $"a".struct(StructType(Seq(strField("c.str"), intField("cint")))), 0, None),
    GetStructField(
      $"a.b".struct(StructType(Seq(intField("cint1"), intField("cint2"), strField("cstr")))),
      2, None),
    GetStructField(
      $"a.b".struct(StructType(Seq(strField("c.str")))), 0, None),
    GetStructField(
      GetStructField(
        $"a".struct(StructType(Seq(
          intField("cint1"),
          StructField("b", StructType(Seq(strField("cstr"), intField("cint2"))))))),
        1, None),
      0, None)
  )

  // Column name paired with each expression above, in the same order.
  val columnNames = Seq(
    "cstr",
    "`c.str`", // single level field that contains `dot` in name
    "a.cstr", // two level nested field
    "a.`c.str`", // two level nested field, and nested level contains `dot`
    "`a.b`.cstr", // two level nested field, and top level contains `dot`
    "`a.b`.`c.str`", // two level nested field, and both levels contain `dot`
    "a.b.cstr" // three level nested field
  )

  stringExprs.zip(columnNames)
}

test("translate simple expression") { attrInts.zip(attrStrs)
.foreach { case ((attrInt, intColName), (attrStr, strColName)) =>
testTranslateFilter(EqualTo(attrInt, 1),
Some(new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
testTranslateFilter(EqualTo(1, attrInt),
Expand Down Expand Up @@ -86,6 +116,199 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
Some(new Predicate("<=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))
testTranslateFilter(LessThanOrEqual(1, attrInt),
Some(new Predicate(">=", Array(FieldReference(intColName), LiteralValue(1, IntegerType)))))

testTranslateFilter(IsNull(attrInt),
Some(new Predicate("IS_NULL", Array(FieldReference(intColName)))))
testTranslateFilter(IsNotNull(attrInt),
Some(new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))

testTranslateFilter(InSet(attrInt, Set(1, 2, 3)),
Some(new Predicate("IN", Array(FieldReference(intColName),
LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
LiteralValue(3, IntegerType)))))

testTranslateFilter(In(attrInt, Seq(1, 2, 3)),
Some(new Predicate("IN", Array(FieldReference(intColName),
LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
LiteralValue(3, IntegerType)))))

// cint > 1 AND cint < 10
testTranslateFilter(And(
GreaterThan(attrInt, 1),
LessThan(attrInt, 10)),
Some(new V2And(
new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType))))))

// cint >= 8 OR cint <= 2
testTranslateFilter(Or(
GreaterThanOrEqual(attrInt, 8),
LessThanOrEqual(attrInt, 2)),
Some(new V2Or(
new Predicate(">=", Array(FieldReference(intColName), LiteralValue(8, IntegerType))),
new Predicate("<=", Array(FieldReference(intColName), LiteralValue(2, IntegerType))))))

testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)),
Some(new V2Not(new Predicate(">=", Array(FieldReference(intColName),
LiteralValue(8, IntegerType))))))

testTranslateFilter(StartsWith(attrStr, "a"),
Some(new Predicate("STARTS_WITH", Array(FieldReference(strColName),
LiteralValue(UTF8String.fromString("a"), StringType)))))

testTranslateFilter(EndsWith(attrStr, "a"),
Some(new Predicate("ENDS_WITH", Array(FieldReference(strColName),
LiteralValue(UTF8String.fromString("a"), StringType)))))

testTranslateFilter(Contains(attrStr, "a"),
Some(new Predicate("CONTAINS", Array(FieldReference(strColName),
LiteralValue(UTF8String.fromString("a"), StringType)))))
}
}

// Exercises nested AND / OR / NOT filters over every integer attribute in `attrInts`.
// Each case passes a Catalyst filter and the expected result to `testTranslateFilter`:
// a V2 predicate tree when every leaf is a plain comparison or null check, and `None`
// whenever any leaf involves `Abs`.
// NOTE(review): `testTranslateFilter` is defined outside this view; presumably it asserts
// that the translated filter equals the expected Option -- confirm against the helper.
test("translate complex expression") {
attrInts.foreach { case (attrInt, intColName) =>

// ABS(cint) - 2 <= 1
testTranslateFilter(LessThanOrEqual(
// Expressions are not supported
// Functions such as 'Abs' are not supported
Subtract(Abs(attrInt), 2), 1), None)

// (cint > 1 AND cint < 10) OR (cint > 50 AND cint < 100)
testTranslateFilter(Or(
And(
GreaterThan(attrInt, 1),
LessThan(attrInt, 10)
),
And(
GreaterThan(attrInt, 50),
LessThan(attrInt, 100))),
Some(new V2Or(
new V2And(
new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))),
new V2And(
new Predicate(">", Array(FieldReference(intColName), LiteralValue(50, IntegerType))),
new Predicate("<", Array(FieldReference(intColName),
LiteralValue(100, IntegerType)))))
)
)

// (cint > 1 AND ABS(cint) < 10) OR (cint > 50 AND cint < 100)
testTranslateFilter(Or(
And(
GreaterThan(attrInt, 1),
// Functions such as 'Abs' are not supported
LessThan(Abs(attrInt), 10)
),
And(
GreaterThan(attrInt, 50),
LessThan(attrInt, 100))), None)

// NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100))
testTranslateFilter(Not(And(
Or(
LessThanOrEqual(attrInt, 1),
// Functions such as 'Abs' are not supported
GreaterThanOrEqual(Abs(attrInt), 10)
),
Or(
LessThanOrEqual(attrInt, 50),
GreaterThanOrEqual(attrInt, 100)))), None)

// (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10)
testTranslateFilter(Or(
Or(
EqualTo(attrInt, 1),
EqualTo(attrInt, 10)
),
Or(
GreaterThan(attrInt, 0),
LessThan(attrInt, -10))),
Some(new V2Or(
new V2Or(
new Predicate("=", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
new Predicate("=", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))),
new V2Or(
new Predicate(">", Array(FieldReference(intColName), LiteralValue(0, IntegerType))),
new Predicate("<", Array(FieldReference(intColName), LiteralValue(-10, IntegerType)))))
)
)

// (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
testTranslateFilter(Or(
Or(
EqualTo(attrInt, 1),
// Functions such as 'Abs' are not supported
EqualTo(Abs(attrInt), 10)
),
Or(
GreaterThan(attrInt, 0),
LessThan(attrInt, -10))), None)

// In end-to-end testing, conjunctive predicates should have been split
// before reaching DataSourceStrategy.translateFilter.
// This is for UT purpose to test each [[case]].
// (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL)
testTranslateFilter(And(
And(
GreaterThan(attrInt, 1),
LessThan(attrInt, 10)
),
And(
EqualTo(attrInt, 6),
IsNotNull(attrInt))),
Some(new V2And(
new V2And(
new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))),
new V2And(
new Predicate("=", Array(FieldReference(intColName), LiteralValue(6, IntegerType))),
new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
)
)

// (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
testTranslateFilter(And(
And(
GreaterThan(attrInt, 1),
LessThan(attrInt, 10)
),
And(
// Functions such as 'Abs' are not supported
EqualTo(Abs(attrInt), 6),
IsNotNull(attrInt))), None)

// (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
testTranslateFilter(And(
Or(
GreaterThan(attrInt, 1),
LessThan(attrInt, 10)
),
Or(
EqualTo(attrInt, 6),
IsNotNull(attrInt))),
Some(new V2And(
new V2Or(
new Predicate(">", Array(FieldReference(intColName), LiteralValue(1, IntegerType))),
new Predicate("<", Array(FieldReference(intColName), LiteralValue(10, IntegerType)))),
new V2Or(
new Predicate("=", Array(FieldReference(intColName), LiteralValue(6, IntegerType))),
new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
)
)

// (cint > 1 OR cint < 10) AND (ABS(cint) = 6 OR cint IS NOT NULL)
testTranslateFilter(And(
Or(
GreaterThan(attrInt, 1),
LessThan(attrInt, 10)
),
Or(
// Functions such as 'Abs' are not supported
EqualTo(Abs(attrInt), 6),
IsNotNull(attrInt))), None)
}
}

Expand Down

0 comments on commit ff17bff

Please sign in to comment.