Skip to content

Commit

Permalink
catalyst module
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin committed Apr 8, 2015
1 parent 8d2a36c commit 04ec7ac
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DistributionSuite extends FunSuite {
inputPartitioning: Partitioning,
requiredDistribution: Distribution,
satisfied: Boolean) {
if (inputPartitioning.satisfies(requiredDistribution) != satisfied)
if (inputPartitioning.satisfies(requiredDistribution) != satisfied) {
fail(
s"""
|== Input Partitioning ==
Expand All @@ -40,6 +40,7 @@ class DistributionSuite extends FunSuite {
|== Does input partitioning satisfy required distribution? ==
|Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)}
""".stripMargin)
}
}

test("HashPartitioning is the output partitioning") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

import scala.collection.immutable

class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
val caseInsensitiveCatalog = new SimpleCatalog(false)
Expand All @@ -41,10 +43,10 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
}


def caseSensitiveAnalyze(plan: LogicalPlan) =
def caseSensitiveAnalyze(plan: LogicalPlan): Unit =
caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan))

def caseInsensitiveAnalyze(plan: LogicalPlan) =
def caseInsensitiveAnalyze(plan: LogicalPlan): Unit =
caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan))

val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
Expand Down Expand Up @@ -147,7 +149,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
name: String,
plan: LogicalPlan,
errorMessages: Seq[String],
caseSensitive: Boolean = true) = {
caseSensitive: Boolean = true): Unit = {
test(name) {
val error = intercept[AnalysisException] {
if(caseSensitive) {
Expand Down Expand Up @@ -202,7 +204,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {

case class UnresolvedTestPlan() extends LeafNode {
override lazy val resolved = false
override def output = Nil
override def output: Seq[Attribute] = Nil
}

errorTest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ class HiveTypeCoercionSuite extends PlanTest {
widenTest(StringType, TimestampType, None)

// ComplexType
widenTest(NullType, MapType(IntegerType, StringType, false), Some(MapType(IntegerType, StringType, false)))
widenTest(NullType,
MapType(IntegerType, StringType, false),
Some(MapType(IntegerType, StringType, false)))
widenTest(NullType, StructType(Seq()), Some(StructType(Seq())))
widenTest(StringType, MapType(IntegerType, StringType, true), None)
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
Expand All @@ -113,7 +115,9 @@ class HiveTypeCoercionSuite extends PlanTest {
// Remove superfluous boolean -> boolean casts.
ruleTest(Cast(Literal(true), BooleanType), Literal(true))
// Stringify boolean when casting to string.
ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false")))
ruleTest(
Cast(Literal(false), StringType),
If(Literal(false), Literal("true"), Literal("false")))
}

test("coalesce casts") {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -176,40 +176,39 @@ class ConstantFoldingSuite extends PlanTest {
}

test("Constant folding test: expressions have null literals") {
val originalQuery =
testRelation
.select(
IsNull(Literal(null)) as 'c1,
IsNotNull(Literal(null)) as 'c2,
val originalQuery = testRelation.select(
IsNull(Literal(null)) as 'c1,
IsNotNull(Literal(null)) as 'c2,

GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
GetItem(Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4,
UnresolvedGetField(
Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))),
"a") as 'c5,
GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
GetItem(
Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4,
UnresolvedGetField(
Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))),
"a") as 'c5,

UnaryMinus(Literal.create(null, IntegerType)) as 'c6,
Cast(Literal(null), IntegerType) as 'c7,
Not(Literal.create(null, BooleanType)) as 'c8,
UnaryMinus(Literal.create(null, IntegerType)) as 'c6,
Cast(Literal(null), IntegerType) as 'c7,
Not(Literal.create(null, BooleanType)) as 'c8,

Add(Literal.create(null, IntegerType), 1) as 'c9,
Add(1, Literal.create(null, IntegerType)) as 'c10,
Add(Literal.create(null, IntegerType), 1) as 'c9,
Add(1, Literal.create(null, IntegerType)) as 'c10,

EqualTo(Literal.create(null, IntegerType), 1) as 'c11,
EqualTo(1, Literal.create(null, IntegerType)) as 'c12,
EqualTo(Literal.create(null, IntegerType), 1) as 'c11,
EqualTo(1, Literal.create(null, IntegerType)) as 'c12,

Like(Literal.create(null, StringType), "abc") as 'c13,
Like("abc", Literal.create(null, StringType)) as 'c14,
Like(Literal.create(null, StringType), "abc") as 'c13,
Like("abc", Literal.create(null, StringType)) as 'c14,

Upper(Literal.create(null, StringType)) as 'c15,
Upper(Literal.create(null, StringType)) as 'c15,

Substring(Literal.create(null, StringType), 0, 1) as 'c16,
Substring("abc", Literal.create(null, IntegerType), 1) as 'c17,
Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18,
Substring(Literal.create(null, StringType), 0, 1) as 'c16,
Substring("abc", Literal.create(null, IntegerType), 1) as 'c17,
Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18,

Contains(Literal.create(null, StringType), "abc") as 'c19,
Contains("abc", Literal.create(null, StringType)) as 'c20
)
Contains(Literal.create(null, StringType), "abc") as 'c19,
Contains("abc", Literal.create(null, StringType)) as 'c20
)

val optimized = Optimize(originalQuery.analyze)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ class FilterPushdownSuite extends PlanTest {

val originalQuery = {
z.join(x.join(y))
.where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr))
.where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) &&
("z.a".attr >= 3) && ("z.a".attr === "x.b".attr))
}

val optimized = Optimize(originalQuery.analyze)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class OptimizeInSuite extends PlanTest {
val optimized = Optimize(originalQuery.analyze)
val correctAnswer =
testRelation
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2))
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
.analyze

comparePlans(optimized, correctAnswer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ class PlanTest extends FunSuite {
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
val normalized1 = normalizeExprIds(plan1)
val normalized2 = normalizeExprIds(plan2)
if (normalized1 != normalized2)
if (normalized1 != normalized2) {
fail(
s"""
|== FAIL: Plans do not match ===
|${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
""".stripMargin)
}
}

/** Fails the test if the two expressions do not match */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SameResultSuite extends FunSuite {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = {
def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
val aAnalyzed = a.analyze
val bAnalyzed = b.analyze

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StringType, NullType}

case class Dummy(optKey: Option[Expression]) extends Expression {
def children = optKey.toSeq
def nullable = true
def dataType = NullType
def children: Seq[Expression] = optKey.toSeq
def nullable: Boolean = true
def dataType: NullType = NullType
override lazy val resolved = true
type EvaluatedType = Any
def eval(input: Row) = null.asInstanceOf[Any]
def eval(input: Row): Any = null.asInstanceOf[Any]
}

class TreeNodeSuite extends FunSuite {
Expand Down

0 comments on commit 04ec7ac

Please sign in to comment.