Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
imback82 committed Mar 14, 2021
1 parent 6813754 commit f1c86b8
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.errors.QueryCompilationErrors

Expand All @@ -40,12 +39,12 @@ class GlobalTempViewManager(val database: String) {

/** List of view definitions, mapping from view name to logical plan. */
@GuardedBy("this")
private val viewDefinitions = new mutable.HashMap[String, LogicalPlan]
private val viewDefinitions = new mutable.HashMap[String, TemporaryViewRelation]

/**
* Returns the global view definition which matches the given name, or None if not found.
*/
def get(name: String): Option[LogicalPlan] = synchronized {
def get(name: String): Option[TemporaryViewRelation] = synchronized {
viewDefinitions.get(name)
}

Expand All @@ -55,7 +54,7 @@ class GlobalTempViewManager(val database: String) {
*/
def create(
name: String,
viewDefinition: LogicalPlan,
viewDefinition: TemporaryViewRelation,
overrideIfExists: Boolean): Unit = synchronized {
if (!overrideIfExists && viewDefinitions.contains(name)) {
throw new TempTableAlreadyExistsException(name)
Expand All @@ -68,7 +67,7 @@ class GlobalTempViewManager(val database: String) {
*/
def update(
name: String,
viewDefinition: LogicalPlan): Boolean = synchronized {
viewDefinition: TemporaryViewRelation): Boolean = synchronized {
if (viewDefinitions.contains(name)) {
viewDefinitions.put(name, viewDefinition)
true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class SessionCatalog(

/** List of temporary views, mapping from table name to their logical plan. */
@GuardedBy("this")
protected val tempViews = new mutable.HashMap[String, LogicalPlan]
protected val tempViews = new mutable.HashMap[String, TemporaryViewRelation]

// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
Expand Down Expand Up @@ -573,21 +573,21 @@ class SessionCatalog(
*/
def createTempView(
name: String,
tableDefinition: LogicalPlan,
viewDefinition: TemporaryViewRelation,
overrideIfExists: Boolean): Unit = synchronized {
val table = formatTableName(name)
if (tempViews.contains(table) && !overrideIfExists) {
throw new TempTableAlreadyExistsException(name)
}
tempViews.put(table, tableDefinition)
tempViews.put(table, viewDefinition)
}

/**
* Create a global temporary view.
*/
def createGlobalTempView(
name: String,
viewDefinition: LogicalPlan,
viewDefinition: TemporaryViewRelation,
overrideIfExists: Boolean): Unit = {
globalTempViewManager.create(formatTableName(name), viewDefinition, overrideIfExists)
}
Expand All @@ -598,7 +598,7 @@ class SessionCatalog(
*/
def alterTempViewDefinition(
name: TableIdentifier,
viewDefinition: LogicalPlan): Boolean = synchronized {
viewDefinition: TemporaryViewRelation): Boolean = synchronized {
val viewName = formatTableName(name.table)
if (name.database.isEmpty) {
if (tempViews.contains(viewName)) {
Expand All @@ -617,14 +617,14 @@ class SessionCatalog(
/**
* Return a local temporary view exactly as it was stored.
*/
def getRawTempView(name: String): Option[LogicalPlan] = synchronized {
def getRawTempView(name: String): Option[TemporaryViewRelation] = synchronized {
tempViews.get(formatTableName(name))
}

/**
* Generate a [[View]] operator from the temporary view stored.
*/
def getTempView(name: String): Option[LogicalPlan] = synchronized {
def getTempView(name: String): Option[View] = synchronized {
getRawTempView(name).map(getTempViewPlan)
}

Expand All @@ -635,14 +635,14 @@ class SessionCatalog(
/**
* Return a global temporary view exactly as it was stored.
*/
def getRawGlobalTempView(name: String): Option[LogicalPlan] = {
def getRawGlobalTempView(name: String): Option[TemporaryViewRelation] = {
globalTempViewManager.get(formatTableName(name))
}

/**
* Generate a [[View]] operator from the global temporary view stored.
*/
def getGlobalTempView(name: String): Option[LogicalPlan] = {
def getGlobalTempView(name: String): Option[View] = {
getRawGlobalTempView(name).map(getTempViewPlan)
}

Expand Down Expand Up @@ -680,25 +680,10 @@ class SessionCatalog(
def getTempViewOrPermanentTableMetadata(name: TableIdentifier): CatalogTable = synchronized {
val table = formatTableName(name.table)
if (name.database.isEmpty) {
tempViews.get(table).map {
case TemporaryViewRelation(metadata, _) => metadata
case plan =>
CatalogTable(
identifier = TableIdentifier(table),
tableType = CatalogTableType.VIEW,
storage = CatalogStorageFormat.empty,
schema = plan.output.toStructType)
}.getOrElse(getTableMetadata(name))
tempViews.get(table).map(_.tableMeta).getOrElse(getTableMetadata(name))
} else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) {
globalTempViewManager.get(table).map {
case TemporaryViewRelation(metadata, _) => metadata
case plan =>
CatalogTable(
identifier = TableIdentifier(table, Some(globalTempViewManager.database)),
tableType = CatalogTableType.VIEW,
storage = CatalogStorageFormat.empty,
schema = plan.output.toStructType)
}.getOrElse(throw new NoSuchTableException(globalTempViewManager.database, table))
globalTempViewManager.get(table).map(_.tableMeta)
.getOrElse(throw new NoSuchTableException(globalTempViewManager.database, table))
} else {
getTableMetadata(name)
}
Expand Down Expand Up @@ -834,20 +819,11 @@ class SessionCatalog(
}
}

private def getTempViewPlan(plan: LogicalPlan): LogicalPlan = {
plan match {
case TemporaryViewRelation(tableMeta, None) =>
fromCatalogTable(tableMeta, isTempView = true)
case TemporaryViewRelation(tableMeta, Some(plan)) =>
View(desc = tableMeta, isTempView = true, child = plan)
case other => other
}
}

def getTempViewSchema(plan: LogicalPlan): StructType = {
plan match {
case viewInfo: TemporaryViewRelation => viewInfo.tableMeta.schema
case v => v.schema
private def getTempViewPlan(viewInfo: TemporaryViewRelation): View = {
if (viewInfo.plan.isEmpty) {
fromCatalogTable(viewInfo.tableMeta, isTempView = true)
} else {
View(desc = viewInfo.tableMeta, isTempView = true, child = viewInfo.plan.get)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,62 @@ import java.net.URI
import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.{QueryPlanningTracker, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog, TemporaryViewRelation}
import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}

trait AnalysisTest extends PlanTest {

protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil

protected def createTempView(
catalog: SessionCatalog,
name: String,
plan: LogicalPlan,
overrideIfExists: Boolean): Unit = {
val identifier = TableIdentifier(name)
val metadata = CatalogTable(
identifier = identifier,
tableType = CatalogTableType.VIEW,
storage = CatalogStorageFormat.empty,
schema = plan.schema,
properties = Map((VIEW_STORING_ANALYZED_PLAN, "true")))
val viewDefinition = TemporaryViewRelation(metadata, Some(plan))
catalog.createTempView(name, viewDefinition, overrideIfExists)
}

protected def createGlobalTempView(
catalog: SessionCatalog,
name: String,
plan: LogicalPlan,
overrideIfExists: Boolean): Unit = {
val globalDb = Some(SQLConf.get.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE))
val identifier = TableIdentifier(name, globalDb)
val metadata = CatalogTable(
identifier = identifier,
tableType = CatalogTableType.VIEW,
storage = CatalogStorageFormat.empty,
schema = plan.schema,
properties = Map((VIEW_STORING_ANALYZED_PLAN, "true")))
val viewDefinition = TemporaryViewRelation(metadata, Some(plan))
catalog.createGlobalTempView(name, viewDefinition, overrideIfExists)
}

protected def getAnalyzer: Analyzer = {
val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin)
catalog.createDatabase(
CatalogDatabase("default", "", new URI("loc"), Map.empty),
ignoreIfExists = false)
catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true)
catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4, overrideIfExists = true)
catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5, overrideIfExists = true)
createTempView(catalog, "TaBlE", TestRelations.testRelation, overrideIfExists = true)
createTempView(catalog, "TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
createTempView(catalog, "TaBlE3", TestRelations.testRelation3, overrideIfExists = true)
createGlobalTempView(catalog, "TaBlE4", TestRelations.testRelation4, overrideIfExists = true)
createGlobalTempView(catalog, "TaBlE5", TestRelations.testRelation5, overrideIfExists = true)
new Analyzer(catalog) {
override val extendedResolutionRules = EliminateSubqueryAliases +: extendedAnalysisRules
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Unio
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry)
private val analyzer = new Analyzer(catalog)
Expand All @@ -49,10 +48,6 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
private val f: Expression = UnresolvedAttribute("f")
private val b: Expression = UnresolvedAttribute("b")

before {
catalog.createTempView("table", relation, overrideIfExists = true)
}

private def checkType(expression: Expression, expectedType: DataType): Unit = {
val plan = Project(Seq(Alias(expression, "c")()), relation)
assert(analyzer.execute(plan).schema.fields(0).dataType === expectedType)
Expand Down
Loading

0 comments on commit f1c86b8

Please sign in to comment.