[SPARK-5213] [SQL] Pluggable SQL Parser Support
This PR aims to make the SQL parser pluggable, so that users can register their own parsers via the Spark SQL CLI.

```
# add the jar into the classpath
$hchengmydesktop:spark>bin/spark-sql --jars sql99.jar

-- switch to "hiveql" dialect
   spark-sql>SET spark.sql.dialect=hiveql;
   spark-sql>SELECT * FROM src LIMIT 1;

-- switch to "sql" dialect
   spark-sql>SET spark.sql.dialect=sql;
   spark-sql>SELECT * FROM src LIMIT 1;

-- switch to a custom dialect
   spark-sql>SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect;
   spark-sql>SELECT * FROM src LIMIT 1;

-- set a non-existent SQL dialect
   spark-sql> SET spark.sql.dialect=NotExistedClass;
   spark-sql> SELECT * FROM src LIMIT 1;
-- An exception will be thrown and the dialect is reset to the default ("sql" for SQLContext, "hiveql" for HiveContext)
```
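
For reference, a custom dialect such as the `com.xxx.xxx.SQL99Dialect` used above is simply a class that extends the new `Dialect` abstract class and implements `parse`. Below is a minimal sketch; the `com.example.sql.SQL99Dialect` name and the placeholder parser body are hypothetical and not part of this change:

```scala
package com.example.sql  // hypothetical package, for illustration only

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.Dialect
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Sketch of a user-defined dialect; a real implementation would plug in an SQL:1999 parser.
class SQL99Dialect extends Dialect {
  override def parse(sqlText: String): LogicalPlan = {
    // Replace this placeholder with logic that turns `sqlText` into a Catalyst LogicalPlan.
    sys.error(s"SQL99 parsing not implemented for: $sqlText")
  }
}

object SQL99DialectExample {
  // Selecting the dialect programmatically is equivalent to running
  // `SET spark.sql.dialect=com.example.sql.SQL99Dialect;` in the CLI.
  def use(sqlContext: SQLContext): Unit = {
    sqlContext.setConf("spark.sql.dialect", classOf[SQL99Dialect].getCanonicalName)
    sqlContext.sql("SELECT * FROM src LIMIT 1").show()
  }
}
```

The dialect class only needs a no-arg constructor so that `getSQLDialect()` can instantiate it via reflection; if instantiation fails, a `DialectException` is thrown and the configuration falls back to the default dialect.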

Author: Cheng Hao <[email protected]>

Closes apache#4015 from chenghao-intel/sqlparser and squashes the following commits:

493775c [Cheng Hao] update the code as feedback
81a731f [Cheng Hao] remove the unecessary comment
aab0b0b [Cheng Hao] polish the code a little bit
49b9d81 [Cheng Hao] shrink the comment for rebasing
chenghao-intel authored and marmbrus committed May 1, 2015
1 parent e991255 commit 3ba5aaa
Showing 9 changed files with 199 additions and 42 deletions.
@@ -25,10 +25,6 @@ import scala.util.parsing.input.CharArrayReader.EofCh

import org.apache.spark.sql.catalyst.plans.logical._

private[sql] object KeywordNormalizer {
def apply(str: String): String = str.toLowerCase()
}

private[sql] abstract class AbstractSparkSQLParser
extends StandardTokenParsers with PackratParsers {

@@ -42,7 +38,7 @@ private[sql] abstract class AbstractSparkSQLParser
}

protected case class Keyword(str: String) {
def normalize: String = KeywordNormalizer(str)
def normalize: String = lexical.normalizeKeyword(str)
def parser: Parser[String] = normalize
}

@@ -90,13 +86,16 @@ class SqlLexical extends StdLexical {
reserved ++= keywords
}

/* Normalize the keyword string */
def normalizeKeyword(str: String): String = str.toLowerCase

delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
)

protected override def processIdent(name: String) = {
val token = KeywordNormalizer(name)
val token = normalizeKeyword(name)
if (reserved contains token) Keyword(token) else Identifier(name)
}

@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
* Root class of SQL parser dialects. We do not guarantee binary
* compatibility across future releases, so this should be treated as an
* internal interface for advanced users.
*
*/
@DeveloperApi
abstract class Dialect {
// The main method, to be implemented by the SQL parser.
def parse(sqlText: String): LogicalPlan
}
@@ -38,6 +38,8 @@ package object errors {
}
}

class DialectException(msg: String, cause: Throwable) extends Exception(msg, cause)

/**
* Wraps any exceptions that are thrown while executing `f` in a
* [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`.
82 changes: 68 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConversions._
import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import com.google.common.reflect.TypeToken

@@ -32,9 +33,11 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.Dialect
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions}
import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
@@ -44,6 +47,45 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.{Partition, SparkContext}

/**
* Currently we support the default dialect "sql", backed by the class
* [[DefaultDialect]].
*
* A custom SQL dialect can also be provided, for example in the Spark SQL CLI:
* {{{
*-- switch to "hiveql" dialect
* spark-sql>SET spark.sql.dialect=hiveql;
* spark-sql>SELECT * FROM src LIMIT 1;
*
*-- switch to "sql" dialect
* spark-sql>SET spark.sql.dialect=sql;
* spark-sql>SELECT * FROM src LIMIT 1;
*
*-- switch to a custom SQL dialect
* spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect;
* spark-sql> SELECT * FROM src LIMIT 1;
*
*-- set a non-existent SQL dialect
* spark-sql> SET spark.sql.dialect=NotExistedClass;
* spark-sql> SELECT * FROM src LIMIT 1;
*
*-- An exception will be thrown and the dialect will be reset to the default:
*-- "sql" (for SQLContext) or
*-- "hiveql" (for HiveContext)
* }}}
*/
private[spark] class DefaultDialect extends Dialect {
@transient
protected val sqlParser = {
val catalystSqlParser = new catalyst.SqlParser
new SparkSQLParser(catalystSqlParser.parse)
}

override def parse(sqlText: String): LogicalPlan = {
sqlParser.parse(sqlText)
}
}

/**
* The entry point for working with structured data (rows and columns) in Spark. Allows the
* creation of [[DataFrame]] objects as well as the execution of SQL queries.
@@ -132,17 +174,27 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer

@transient
protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))

@transient
protected[sql] val sqlParser = {
val fallback = new catalyst.SqlParser
new SparkSQLParser(fallback.parse(_))
protected[sql] val ddlParser = new DDLParser((sql: String) => { getSQLDialect().parse(sql) })

protected[sql] def getSQLDialect(): Dialect = {
try {
val clazz = Utils.classForName(dialectClassName)
clazz.newInstance().asInstanceOf[Dialect]
} catch {
case NonFatal(e) =>
// The configured SQL dialect could not be instantiated, which would make even the
// SET command fail (e.g. SET spark.sql.dialect=sql), so reset to the default dialect.
val dialect = conf.dialect
// reset the sql dialect
conf.unsetConf(SQLConf.DIALECT)
// Rethrow the exception; the default SQL dialect will take effect for the next query.
throw new DialectException(
s"""Instantiating dialect '$dialect' failed.
|Reverting to default dialect '${conf.dialect}'""".stripMargin, e)
}
}

protected[sql] def parseSql(sql: String): LogicalPlan = {
ddlParser.parse(sql, false).getOrElse(sqlParser.parse(sql))
}
protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false)

protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))

@@ -156,6 +208,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val defaultSession = createSession()

protected[sql] def dialectClassName = if (conf.dialect == "sql") {
classOf[DefaultDialect].getCanonicalName
} else {
conf.dialect
}

sparkContext.getConf.getAll.foreach {
case (key, value) if key.startsWith("spark.sql") => setConf(key, value)
case _ =>
@@ -945,11 +1003,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group basic
*/
def sql(sqlText: String): DataFrame = {
if (conf.dialect == "sql") {
DataFrame(this, parseSql(sqlText))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}")
}
DataFrame(this, parseSql(sqlText))
}

/**
@@ -38,12 +38,12 @@ private[sql] class DDLParser(
parseQuery: String => LogicalPlan)
extends AbstractSparkSQLParser with DataTypeParser with Logging {

def parse(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
def parse(input: String, exceptionOnError: Boolean): LogicalPlan = {
try {
Some(parse(input))
parse(input)
} catch {
case ddlException: DDLException => throw ddlException
case _ if !exceptionOnError => None
case _ if !exceptionOnError => parseQuery(input)
case x: Throwable => throw x
}
}
22 changes: 22 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,13 +19,18 @@ package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}

import org.apache.spark.sql.types._

/** A SQL dialect for testing purposes; it cannot be a nested type. */
class MyDialect extends DefaultDialect

class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
TestData
@@ -64,6 +69,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
}

test("SQL Dialect Switching to a new SQL parser") {
val newContext = new SQLContext(TestSQLContext.sparkContext)
newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName())
assert(newContext.getSQLDialect().getClass === classOf[MyDialect])
assert(newContext.sql("SELECT 1").collect() === Array(Row(1)))
}

test("SQL Dialect Switch to an invalid parser with alias") {
val newContext = new SQLContext(TestSQLContext.sparkContext)
newContext.sql("SET spark.sql.dialect=MyTestClass")
intercept[DialectException] {
newContext.sql("SELECT 1")
}
// verify that the dialect is reset back to DefaultDialect
assert(newContext.getSQLDialect().getClass === classOf[DefaultDialect])
}

test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
sql("SELECT a FROM testData2 SORT BY a"),
41 changes: 25 additions & 16 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -20,6 +20,9 @@ package org.apache.spark.sql.hive
import java.io.{BufferedReader, InputStreamReader, PrintStream}
import java.sql.Timestamp

import org.apache.hadoop.hive.ql.parse.VariableSubstitution
import org.apache.spark.sql.catalyst.Dialect

import scala.collection.JavaConversions._
import scala.language.implicitConversions

@@ -42,6 +45,15 @@ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNative
import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy}
import org.apache.spark.sql.types._

/**
* The HiveQL dialect. This dialect is strongly bound to HiveContext.
*/
private[hive] class HiveQLDialect extends Dialect {
override def parse(sqlText: String): LogicalPlan = {
HiveQl.parseSql(sqlText)
}
}

/**
* An instance of the Spark SQL execution engine that integrates with data stored in Hive.
* Configuration for Hive is read from hive-site.xml on the classpath.
@@ -81,25 +93,16 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
protected[sql] def convertCTAS: Boolean =
getConf("spark.sql.hive.convertCTAS", "false").toBoolean

override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution(plan)

@transient
protected[sql] val ddlParserWithHiveQL = new DDLParser(HiveQl.parseSql(_))

override def sql(sqlText: String): DataFrame = {
val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
// TODO: Create a framework for registering parsers instead of just hardcoding if statements.
if (conf.dialect == "sql") {
super.sql(substituted)
} else if (conf.dialect == "hiveql") {
val ddlPlan = ddlParserWithHiveQL.parse(sqlText, exceptionOnError = false)
DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
}
protected[sql] lazy val substitutor = new VariableSubstitution()

protected[sql] override def parseSql(sql: String): LogicalPlan = {
super.parseSql(substitutor.substitute(hiveconf, sql))
}

override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution(plan)

/**
* Invalidate and refresh all the cached the metadata of the given table. For performance reasons,
* Spark SQL or the external data source library it uses might cache certain metadata about a
@@ -356,6 +359,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}
}

override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") {
classOf[HiveQLDialect].getCanonicalName
} else {
super.dialectClassName
}

@transient
private val hivePlanner = new SparkPlanner with HiveStrategies {
val hiveContext = self
@@ -107,7 +107,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
/** Fewer partitions to speed up testing. */
protected[sql] override lazy val conf: SQLConf = new SQLConf {
override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
override def dialect: String = getConf(SQLConf.DIALECT, "hiveql")

// TODO: In unit tests conf.clear() may be called, which clears all of the values.
// super.getConf(SQLConf.DIALECT) is "sql" by default, so we need to set it to "hiveql" here.
override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql")
}
}
