Skip to content

Commit

Permalink
Add more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 23, 2016
1 parent c8fb736 commit 81c46c7
Showing 1 changed file with 58 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.{DoubleType, LongType}

class ConstraintPropagationSuite extends SparkFunSuite {

Expand Down Expand Up @@ -234,4 +234,61 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "e")),
IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
}

test("infer isnotnull constraints from compound expressions") {
val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
verifyConstraints(
tr.where('a.attr + 'b.attr === 'c.attr &&
IsNotNull(
Cast(
Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
Cast(resolveColumn(tr, "c"), LongType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "e")),
IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))

verifyConstraints(
tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
Cast(Cast(10, LongType), DoubleType) ===
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))

verifyConstraints(
tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
ExpressionSet(Seq(
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
Cast(Cast(10, LongType), DoubleType) <
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))

verifyConstraints(
tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
ExpressionSet(Seq(
(Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
(Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
Cast(resolveColumn(tr, "e") * 1000, LongType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "d")),
IsNotNull(resolveColumn(tr, "e")))))
}
}

0 comments on commit 81c46c7

Please sign in to comment.