Fix tests round 2
There were some issues with case-sensitivity analysis and with error
messages not being exactly as the tests expected. The latter checks are now
relaxed where possible.
Andrew Or committed Mar 18, 2016
1 parent 78cbcbd · commit 5e16480
Showing 12 changed files with 53 additions and 36 deletions.
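
A note on the relaxed message checks: where a test previously asserted the exact wording of an AnalysisException, it now asserts only the exception type, or at most a stable keyword. A minimal sketch of the before/after pattern, assuming a `sqlContext` like the one available in the suites below; the query is hypothetical and not part of this commit:

    import org.apache.spark.sql.AnalysisException
    import org.scalatest.Assertions.intercept

    // Before: brittle; fails whenever the message wording changes.
    val message = intercept[AnalysisException] {
      sqlContext.sql("SELECT * FROM no_such_table")
    }.getMessage
    assert(message == "Table 'no_such_table' does not exist in database 'default'")

    // After: relaxed; only the exception type is checked.
    intercept[AnalysisException] {
      sqlContext.sql("SELECT * FROM no_such_table")
    }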
@@ -40,16 +40,14 @@ import org.apache.spark.sql.types._
  * Used for testing when all relations are already filled in and the analyzer needs only
  * to resolve attribute references.
  */
-object SimpleAnalyzer
-  extends Analyzer(
-    new SessionCatalog(new InMemoryCatalog),
-    EmptyFunctionRegistry,
-    new SimpleCatalystConf(true))
+object SimpleAnalyzer extends SimpleAnalyzer(new SimpleCatalystConf(true))
+class SimpleAnalyzer(conf: CatalystConf)
+  extends Analyzer(new SessionCatalog(new InMemoryCatalog, conf), EmptyFunctionRegistry, conf)

 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
- * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and
- * a [[FunctionRegistry]].
+ * [[UnresolvedRelation]]s into fully typed objects using information in a
+ * [[SessionCatalog]] and a [[FunctionRegistry]].
  */
 class Analyzer(
     catalog: SessionCatalog,
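The hunk above turns the SimpleAnalyzer object into an instance of a new SimpleAnalyzer class, so tests can construct analyzers with their own conf instead of being stuck with the shared, case-sensitive singleton. A sketch of what this enables, assuming the catalyst test classpath; not code from this commit:

    import org.apache.spark.sql.catalyst.SimpleCatalystConf
    import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer

    // The default, case-sensitive analyzer is still the singleton object...
    val default = SimpleAnalyzer

    // ...but a case-insensitive one can now be built on demand.
    val caseInsensitive =
      new SimpleAnalyzer(new SimpleCatalystConf(false)) // caseSensitiveAnalysis = false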
@@ -64,20 +64,22 @@ class InMemoryCatalog extends ExternalCatalog {

   private def requireFunctionExists(db: String, funcName: String): Unit = {
     if (!functionExists(db, funcName)) {
-      throw new AnalysisException(s"Function '$funcName' does not exist in database '$db'")
+      throw new AnalysisException(
+        s"Function not found: '$funcName' does not exist in database '$db'")
     }
   }

   private def requireTableExists(db: String, table: String): Unit = {
     if (!tableExists(db, table)) {
-      throw new AnalysisException(s"Table '$table' does not exist in database '$db'")
+      throw new AnalysisException(
+        s"Table not found: '$table' does not exist in database '$db'")
     }
   }

   private def requirePartitionExists(db: String, table: String, spec: TablePartitionSpec): Unit = {
     if (!partitionExists(db, table, spec)) {
       throw new AnalysisException(
-        s"Partition does not exist in database '$db' table '$table': '$spec'")
+        s"Partition not found: database '$db' table '$table' does not contain: '$spec'")
     }
   }

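These guards are private, so the new "not found" prefixes surface through whichever catalog operation calls them. A rough sketch of how a caller observes the change, assuming an operation such as getTable routes through requireTableExists; the database setup is elided and hypothetical:

    // `catalog` is an InMemoryCatalog in which database "some_db" exists
    // but table "no_such_table" does not (setup elided).
    val catalog = new InMemoryCatalog
    val e = intercept[AnalysisException] {
      catalog.getTable("some_db", "no_such_table")
    }
    // The message now starts with the stable "Table not found:" prefix,
    // which is what the relaxed tests can safely match on.
    assert(e.getMessage.contains("Table not found"))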
@@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap
 import scala.collection.JavaConverters._

 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}

@@ -31,9 +32,13 @@
  * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary
  * tables and functions of the Spark Session that it belongs to.
  */
-class SessionCatalog(externalCatalog: ExternalCatalog, caseSensitiveAnalysis: Boolean = true) {
+class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
   import ExternalCatalog._

+  def this(externalCatalog: ExternalCatalog) {
+    this(externalCatalog, new SimpleCatalystConf(true))
+  }
+
   protected[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan]
   protected[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction]

@@ -53,7 +58,7 @@
    * Format table name, taking into account case sensitivity.
    */
   protected[this] def formatTableName(name: String): String = {
-    if (caseSensitiveAnalysis) name else name.toLowerCase
+    if (conf.caseSensitiveAnalysis) name else name.toLowerCase
   }

   // ----------------------------------------------------------------------------
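Taken together, these hunks replace the standalone caseSensitiveAnalysis flag with a full CatalystConf, while the new auxiliary constructor keeps the old one-argument form compiling. A sketch of both construction paths; not code from this commit:

    import org.apache.spark.sql.catalyst.SimpleCatalystConf
    import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}

    // Old one-argument form, now routed through the auxiliary constructor:
    // it defaults to case-sensitive analysis via SimpleCatalystConf(true).
    val sensitive = new SessionCatalog(new InMemoryCatalog)

    // New form: case sensitivity comes from the conf, so formatTableName
    // lower-cases "TaBlE" to "table" when caseSensitiveAnalysis is false.
    val insensitive =
      new SessionCatalog(new InMemoryCatalog, new SimpleCatalystConf(false))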
@@ -30,7 +30,7 @@ trait AnalysisTest extends PlanTest {

   private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
     val conf = new SimpleCatalystConf(caseSensitive)
-    val catalog = new SessionCatalog(new InMemoryCatalog, caseSensitive)
+    val catalog = new SessionCatalog(new InMemoryCatalog, conf)
     catalog.createTempTable("TaBlE", TestRelations.testRelation, ignoreIfExists = true)
     new Analyzer(catalog, EmptyFunctionRegistry, conf) {
       override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._

 class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
   private val conf = new SimpleCatalystConf(true)
-  private val catalog = new SessionCatalog(new InMemoryCatalog, caseSensitiveAnalysis = true)
+  private val catalog = new SessionCatalog(new InMemoryCatalog, conf)
   private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)

   private val relation = LocalRelation(
@@ -427,9 +427,8 @@ class SessionCatalogSuite extends SparkFunSuite {
       TableIdentifier("tbl4"),
       TableIdentifier("tbl1", Some("db2")),
       TableIdentifier("tbl2", Some("db2"))))
-    intercept[AnalysisException] {
-      catalog.listTables("unknown_db")
-    }
+    assert(catalog.listTables("unknown_db").toSet ==
+      Set(TableIdentifier("tbl1"), TableIdentifier("tbl4")))
   }

   test("list tables with pattern") {
@@ -446,9 +445,8 @@
       TableIdentifier("tbl2", Some("db2"))))
     assert(catalog.listTables("db2", "*1").toSet ==
       Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2"))))
-    intercept[AnalysisException] {
-      catalog.listTables("unknown_db")
-    }
+    assert(catalog.listTables("unknown_db", "*").toSet ==
+      Set(TableIdentifier("tbl1"), TableIdentifier("tbl4")))
   }

   // --------------------------------------------------------------------------
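The rewritten assertions reflect a behavior change rather than a looser check: listing tables in a nonexistent database no longer throws, because temporary tables are session-scoped and are returned regardless of the database. A standalone sketch, using OneRowRelation as a placeholder plan; hypothetical, not from this commit:

    import org.apache.spark.sql.catalyst.TableIdentifier
    import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
    import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation

    val catalog = new SessionCatalog(new InMemoryCatalog)
    catalog.createTempTable("tmp1", OneRowRelation, ignoreIfExists = false)

    // No AnalysisException: the temp table is visible under any database name.
    assert(catalog.listTables("unknown_db").toSet == Set(TableIdentifier("tmp1")))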
@@ -93,6 +93,11 @@ case class CreateTempTableUsing(
     provider: String,
     options: Map[String, String]) extends RunnableCommand {

+  if (tableIdent.database.isDefined) {
+    throw new AnalysisException(
+      s"Temporary table '$tableIdent' should not have specified a database")
+  }
+
   def run(sqlContext: SQLContext): Seq[Row] = {
     val dataSource = DataSource(
       sqlContext,
@@ -116,6 +121,11 @@ case class CreateTempTableUsingAsSelect(
     options: Map[String, String],
     query: LogicalPlan) extends RunnableCommand {

+  if (tableIdent.database.isDefined) {
+    throw new AnalysisException(
+      s"Temporary table '$tableIdent' should not have specified a database")
+  }
+
   override def run(sqlContext: SQLContext): Seq[Row] = {
     val df = Dataset.newDataFrame(sqlContext, query)
     val dataSource = DataSource(
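The identical guard in both commands rejects database-qualified temp-table names as soon as the command object is constructed, which is what the relaxed tests below intercept without matching the exact message. A usage sketch at the SQL level, mirroring those tests; the path is hypothetical:

    // Fails with: Temporary table 'db.t' should not have specified a database
    intercept[AnalysisException] {
      sqlContext.sql("CREATE TEMPORARY TABLE db.t USING parquet OPTIONS (path '/tmp/t')")
    }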
@@ -46,7 +46,7 @@ private[sql] class SessionState(ctx: SQLContext) {
   /**
    * Internal catalog for managing table and database states.
    */
-  lazy val sessionCatalog = new SessionCatalog(ctx.externalCatalog)
+  lazy val sessionCatalog = new SessionCatalog(ctx.externalCatalog, conf)

   /**
    * Internal catalog for managing functions registered by the user.
22 changes: 13 additions & 9 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1395,12 +1395,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-4699 case sensitivity SQL query") {
-    sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
-    val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
-    val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
-    rdd.toDF().registerTempTable("testTable1")
-    checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
-    sqlContext.setConf(SQLConf.CASE_SENSITIVE, true)
+    val orig = sqlContext.getConf(SQLConf.CASE_SENSITIVE)
+    try {
+      sqlContext.setConf(SQLConf.CASE_SENSITIVE, false)
+      val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil
+      val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
+      rdd.toDF().registerTempTable("testTable1")
+      checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1"))
+    } finally {
+      sqlContext.setConf(SQLConf.CASE_SENSITIVE, orig)
+    }
   }

   test("SPARK-6145: ORDER BY test for nested fields") {
@@ -1674,7 +1678,8 @@
       .format("parquet")
       .save(path)

-    val message = intercept[AnalysisException] {
+    // We don't support creating a temporary table while specifying a database
+    intercept[AnalysisException] {
       sqlContext.sql(
         s"""
           |CREATE TEMPORARY TABLE db.t
@@ -1684,9 +1689,8 @@
           |)
         """.stripMargin)
     }.getMessage
-    assert(message.contains("Specifying database name or other qualifiers are not allowed"))

-    // If you use backticks to quote the name of a temporary table having dot in it.
+    // If you use backticks to quote the name then it's OK.
     sqlContext.sql(
       s"""
         |CREATE TEMPORARY TABLE `db.t`
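The backtick case above works because quoting turns the dotted name into a single unqualified identifier: `db.t` is then a temp table literally named "db.t", not table t in database db. A sketch of the distinction, with hypothetical values:

    import org.apache.spark.sql.catalyst.TableIdentifier

    // CREATE TEMPORARY TABLE db.t   -> qualified identifier: rejected above
    val qualified = TableIdentifier("t", Some("db"))

    // CREATE TEMPORARY TABLE `db.t` -> one identifier containing a dot: allowed
    val quoted = TableIdentifier("db.t", None)

    assert(qualified.database.isDefined && quoted.database.isEmpty)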
@@ -23,15 +23,16 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.BucketSpec
 import org.apache.spark.sql.hive.client.HiveClient
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType


 class HiveSessionCatalog(
     externalCatalog: HiveCatalog,
     client: HiveClient,
     context: HiveContext,
-    caseSensitiveAnalysis: Boolean)
-  extends SessionCatalog(externalCatalog, caseSensitiveAnalysis) {
+    conf: SQLConf)
+  extends SessionCatalog(externalCatalog, conf) {

   override def setCurrentDatabase(db: String): Unit = {
     super.setCurrentDatabase(db)
@@ -38,8 +38,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
   /**
    * Internal catalog for managing table and database states.
    */
   override lazy val sessionCatalog = {
-    new HiveSessionCatalog(
-      ctx.hiveCatalog, ctx.metadataHive, ctx, caseSensitiveAnalysis = false)
+    new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, conf)
   }

   /**
@@ -1321,6 +1321,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       .format("parquet")
       .save(path)

+    // We don't support creating a temporary table while specifying a database
     val message = intercept[AnalysisException] {
       sqlContext.sql(
         s"""
@@ -1331,9 +1332,8 @@
           |)
         """.stripMargin)
     }.getMessage
-    assert(message.contains("Specifying database name or other qualifiers are not allowed"))

-    // If you use backticks to quote the name of a temporary table having dot in it.
+    // If you use backticks to quote the name then it's OK.
     sqlContext.sql(
       s"""
         |CREATE TEMPORARY TABLE `db.t`
