[WIP][SPARK-32976][SQL] Support column list in INSERT statement
yaooqinn committed Sep 28, 2020
1 parent d7aa3b5 commit 99056c4
Showing 15 changed files with 109 additions and 31 deletions.
@@ -337,8 +337,8 @@ query
;

insertInto
: INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable
| INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? #insertIntoTable
: INSERT OVERWRITE TABLE? multipartIdentifier identifierList? (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable
| INSERT INTO TABLE? multipartIdentifier identifierList? partitionSpec? (IF NOT EXISTS)? #insertIntoTable
| INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir
| INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir
;
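
A minimal sketch of the syntax this grammar change is meant to accept (names are illustrative, an active SparkSession `spark` is assumed, and whether the column list may be reordered or partial is not established by this diff):

// Illustrative only: exercises the new optional identifierList after the table name.
spark.sql("CREATE TABLE t1 (c1 INT, c2 STRING, c3 DOUBLE) USING parquet")
// INSERT INTO with an explicit column list.
spark.sql("INSERT INTO t1 (c1, c2, c3) VALUES (1, 'a', 1.5)")
// INSERT OVERWRITE with a column list, matching the first alternative of the rule.
spark.sql("INSERT OVERWRITE TABLE t1 (c1, c2, c3) SELECT 2, 'b', 2.5")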
@@ -46,7 +46,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}

/**
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
@@ -848,7 +848,7 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case u @ UnresolvedRelation(ident, _) =>
lookupTempView(ident).getOrElse(u)
case i @ InsertIntoStatement(UnresolvedRelation(ident, _), _, _, _, _) =>
case i @ InsertIntoStatement(UnresolvedRelation(ident, _), _, _, _, _, _) =>
lookupTempView(ident)
.map(view => i.copy(table = view))
.getOrElse(i)
@@ -911,7 +911,7 @@ class Analyzer(
.map(ResolvedTable(catalog.asTableCatalog, ident, _))
.getOrElse(u)

case i @ InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) if i.query.resolved =>
lookupV2Relation(u.multipartIdentifier, u.options)
.map(v2Relation => i.copy(table = v2Relation))
.getOrElse(i)
@@ -974,7 +974,7 @@ class Analyzer(
}

def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
case i @ InsertIntoStatement(table, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(table, _, _, _, _, _) if i.query.resolved =>
val relation = table match {
case u: UnresolvedRelation =>
lookupRelation(u.multipartIdentifier, u.options).getOrElse(u)
@@ -1045,7 +1045,7 @@ class Analyzer(

object ResolveInsertInto extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) if i.query.resolved =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw new AnalysisException(
@@ -1065,6 +1065,21 @@ class Analyzer(
} else {
OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
}

case i @ InsertIntoStatement(table, _, _, _, _, _) if table.resolved =>
val resolved = i.columns.map {
case u: UnresolvedAttribute => withPosition(u) {
table.resolve(u.nameParts, resolver).map(_.toAttribute)
.getOrElse(failAnalysis(s"Cannot resolve column name ${u.name}"))
}
case other => other
}
val newOutput = if (resolved.isEmpty) {
table.output
} else {
resolved
}
i.copy(columns = newOutput)
}
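
A rough standalone approximation of what the added case does (a sketch only: it swaps the session resolver and withPosition error reporting for a plain case-insensitive map; names are illustrative):

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Attribute

// Sketch: map each user-written column name onto the target table's output;
// unknown names fail, and an empty column list falls back to the full output.
def resolveInsertColumns(
    tableOutput: Seq[Attribute],
    columns: Seq[UnresolvedAttribute]): Seq[Attribute] = {
  val byName = tableOutput.map(a => a.name.toLowerCase -> a).toMap
  val resolved = columns.map { u =>
    byName.getOrElse(u.name.toLowerCase, sys.error(s"Cannot resolve column name ${u.name}"))
  }
  if (resolved.isEmpty) tableOutput else resolved
}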

private def partitionColumnNames(table: Table): Seq[String] = {
@@ -105,7 +105,7 @@ trait CheckAnalysis extends PredicateHelper {
case u: UnresolvedRelation =>
u.failAnalysis(s"Table or view not found: ${u.multipartIdentifier.quoted}")

case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) =>
case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) =>
failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}")

case u: UnresolvedV2Relation if isView(u.originalNameParts) =>
@@ -426,7 +426,7 @@ package object dsl {
partition: Map[String, Option[String]] = Map.empty,
overwrite: Boolean = false,
ifPartitionNotExists: Boolean = false): LogicalPlan =
InsertIntoStatement(table, partition, logicalPlan, overwrite, ifPartitionNotExists)
InsertIntoStatement(table, partition, Nil, logicalPlan, overwrite, ifPartitionNotExists)

def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)

@@ -247,7 +247,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
* Parameters used for writing query to a table:
* (multipartIdentifier, partitionKeys, ifPartitionNotExists).
*/
type InsertTableParams = (Seq[String], Map[String, Option[String]], Boolean)
type InsertTableParams = (Seq[String], Seq[Attribute], Map[String, Option[String]], Boolean)

/**
* Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
@@ -269,18 +269,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
ctx match {
case table: InsertIntoTableContext =>
val (tableIdent, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
InsertIntoStatement(
UnresolvedRelation(tableIdent),
partition,
cols,
query,
overwrite = false,
ifPartitionNotExists)
case table: InsertOverwriteTableContext =>
val (tableIdent, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
InsertIntoStatement(
UnresolvedRelation(tableIdent),
partition,
cols,
query,
overwrite = true,
ifPartitionNotExists)
@@ -301,13 +303,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
override def visitInsertIntoTable(
ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
val cols = Option(ctx.identifierList())
.map(visitIdentifierList)
.getOrElse(Nil).map(UnresolvedAttribute(_))
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)

if (ctx.EXISTS != null) {
operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx)
}

(tableIdent, partitionKeys, false)
(tableIdent, cols, partitionKeys, false)
}

/**
@@ -317,6 +322,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
assert(ctx.OVERWRITE() != null)
val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
val cols = Option(ctx.identifierList())
.map(visitIdentifierList).getOrElse(Nil).map(UnresolvedAttribute(_))
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)

val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
@@ -325,7 +332,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
dynamicPartitionKeys.keys.mkString(", "), ctx)
}

(tableIdent, partitionKeys, ctx.EXISTS() != null)
(tableIdent, cols, partitionKeys, ctx.EXISTS() != null)
}
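
A hedged sketch of the statement shape these visitor changes should produce; the plan shown in the comment is an approximation read off the code above, not output copied from a run:

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// The identifier list should surface as UnresolvedAttribute columns on InsertIntoStatement.
val plan = CatalystSqlParser.parsePlan("INSERT INTO t1 (c1, c2) SELECT * FROM src")
// Expected shape, roughly:
//   InsertIntoStatement(
//     UnresolvedRelation(Seq("t1")),
//     Map.empty,
//     Seq(UnresolvedAttribute("c1"), UnresolvedAttribute("c2")),
//     Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("src"))),
//     overwrite = false, ifPartitionNotExists = false)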

/**
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.analysis.ViewType
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, FunctionResource}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.{DataType, StructType}
@@ -324,6 +324,7 @@ case class DescribeColumnStatement(
case class InsertIntoStatement(
table: LogicalPlan,
partitionSpec: Map[String, Option[String]],
columns: Seq[Attribute],
query: LogicalPlan,
overwrite: Boolean,
ifPartitionNotExists: Boolean) extends ParsedStatement {
@@ -837,6 +837,7 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
@@ -847,6 +848,7 @@
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("testcat2", "db", "tbl"))),
overwrite = false, ifPartitionNotExists = false))
}
@@ -861,6 +863,7 @@
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
@@ -874,6 +877,7 @@
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
@@ -889,6 +893,7 @@
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
@@ -903,6 +908,7 @@
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3")),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = true))
}
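
A hypothetical extra test of the kind this change enables, written in the same style as the cases above (not part of this diff; the parseCompare helper, required imports, and column names are assumptions):

test("insert table with a column list") {
  parseCompare("INSERT INTO testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
    InsertIntoStatement(
      UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
      Map.empty,
      Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")),
      Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
      overwrite = false, ifPartitionNotExists = false))
}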
@@ -295,7 +295,7 @@ class PlanParserSuite extends AnalysisTest {
partition: Map[String, Option[String]],
overwrite: Boolean = false,
ifPartitionNotExists: Boolean = false): LogicalPlan =
InsertIntoStatement(table("s"), partition, plan, overwrite, ifPartitionNotExists)
InsertIntoStatement(table("s"), partition, Nil, plan, overwrite, ifPartitionNotExists)

// Single inserts
assertEqual(s"insert overwrite table s $sql",
@@ -713,7 +713,7 @@ class PlanParserSuite extends AnalysisTest {
comparePlans(
parsePlan(
"INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"),
InsertIntoStatement(table("s"), Map.empty,
InsertIntoStatement(table("s"), Map.empty, Nil,
UnresolvedHint("REPARTITION", Seq(Literal(100)),
UnresolvedHint("COALESCE", Seq(Literal(500)),
UnresolvedHint("COALESCE", Seq(Literal(10)),
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateTableAsSelectStatement, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelectStatement}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, V1Table}
@@ -536,6 +536,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
InsertIntoStatement(
table = UnresolvedRelation(tableIdent),
partitionSpec = Map.empty[String, Option[String]],
Nil,
query = df.logicalPlan,
overwrite = mode == SaveMode.Overwrite,
ifPartitionNotExists = false)
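
For context, a small usage sketch of the API path above: DataFrameWriter.insertInto stays position-based, which is why the statement is built with an empty column list (assumes an active SparkSession `spark` and an existing table t1):

import spark.implicits._
// This v1 write path constructs the InsertIntoStatement above with Nil columns.
Seq((1, "a"), (2, "b")).toDF("c1", "c2").write.insertInto("t1")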
@@ -154,7 +154,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output.map(_.name))

case InsertIntoStatement(l @ LogicalRelation(_: InsertableRelation, _, _, _),
parts, query, overwrite, false) if parts.isEmpty =>
parts, _, query, overwrite, false) if parts.isEmpty =>
InsertIntoDataSourceCommand(l, query, overwrite)

case InsertIntoDir(_, storage, provider, query, overwrite)
@@ -166,7 +166,7 @@
InsertIntoDataSourceDirCommand(storage, provider.get, query, overwrite)

case i @ InsertIntoStatement(
l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, query, overwrite, _) =>
l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, _, query, overwrite, _) =>
// If the InsertIntoTable command is for a partitioned HadoopFsRelation and
// the user has specified static partitions, we add a Project operator on top of the query
// to include those constant column values in the query result.
@@ -261,11 +261,11 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
}

override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(UnresolvedCatalogRelation(tableMeta, options), _, _, _, _)
case i @ InsertIntoStatement(UnresolvedCatalogRelation(tableMeta, options), _, _, _, _, _)
if DDLUtils.isDatasourceTable(tableMeta) =>
i.copy(table = readDataSourceTable(tableMeta, options))

case i @ InsertIntoStatement(UnresolvedCatalogRelation(tableMeta, _), _, _, _, _) =>
case i @ InsertIntoStatement(UnresolvedCatalogRelation(tableMeta, _), _, _, _, _, _) =>
i.copy(table = DDLUtils.readHiveTable(tableMeta))

case UnresolvedCatalogRelation(tableMeta, options) if DDLUtils.isDatasourceTable(tableMeta) =>
@@ -33,8 +33,8 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, File
*/
class FallBackFileSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @
InsertIntoStatement(d @ DataSourceV2Relation(table: FileTable, _, _, _, _), _, _, _, _) =>
case i @ InsertIntoStatement(
d @ DataSourceV2Relation(table: FileTable, _, _, _, _), _, _, _, _, _) =>
val v1FileFormat = table.fallbackFileFormat.newInstance()
val relation = HadoopFsRelation(
table.fileIndex,
@@ -393,8 +393,17 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
insert.partitionSpec, partColNames, tblName, conf.resolver)

val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet
val expectedColumns = insert.table.output.filterNot(a => staticPartCols.contains(a.name))

val expectedColumns = insert.columns.filterNot(a => staticPartCols.contains(a.name))

if (expectedColumns.length !=
insert.table.output.filterNot(a => staticPartCols.contains(a.name)).length) {
throw new AnalysisException(
s"$tblName requires that the data to be inserted have the same number of columns as the " +
s"target table: target table has ${insert.table.output.size} column(s) but the " +
s"specified part has only ${expectedColumns.length} column(s), " +
s"and ${staticPartCols.size} partition column(s) having constant value(s).")
}
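
// Illustrative scenario for this new check (hypothetical names; per the grammar change
// above, the column list precedes the PARTITION clause in this revision):
//   CREATE TABLE pt (c1 INT, c2 STRING, p1 STRING) USING parquet PARTITIONED BY (p1)
//   INSERT INTO pt (c1) PARTITION (p1 = 'a') SELECT 1
// The static partition spec covers p1 while the column list covers only c1 of the two
// remaining data columns, so analysis should fail with the message constructed above.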
if (expectedColumns.length != insert.query.schema.length) {
throw new AnalysisException(
s"$tblName requires that the data to be inserted have the same number of columns as the " +
@@ -436,7 +445,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(table, _, query, _, _) if table.resolved && query.resolved =>
case i @ InsertIntoStatement(table, _, _, query, _, _) if table.resolved && query.resolved =>
table match {
case relation: HiveTableRelation =>
val metadata = relation.tableMeta
@@ -514,7 +523,7 @@ object PreWriteCheck extends (LogicalPlan => Unit) {

def apply(plan: LogicalPlan): Unit = {
plan.foreach {
case InsertIntoStatement(l @ LogicalRelation(relation, _, _, _), partition, query, _, _) =>
case InsertIntoStatement(l @ LogicalRelation(relation, _, _, _), partition, _, query, _, _) =>
// Get all input data source relations of the query.
val srcRelations = query.collect {
case LogicalRelation(src, _, _, _) => src
@@ -536,7 +545,7 @@ object PreWriteCheck extends (LogicalPlan => Unit) {
case _ => failAnalysis(s"$relation does not allow insertion.")
}

case InsertIntoStatement(t, _, _, _, _)
case InsertIntoStatement(t, _, _, _, _, _)
if !t.isInstanceOf[LeafNode] ||
t.isInstanceOf[Range] ||
t.isInstanceOf[OneRowRelation] ||