Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-32976][SQL]Support column list in INSERT statement #29893

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ query
;

insertInto
: INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable
| INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? #insertIntoTable
: INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList? #insertOverwriteTable
| INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList? #insertIntoTable
| INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir
| INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -218,6 +218,7 @@ class Analyzer(override val catalogManager: CatalogManager)
ResolveTableValuedFunctions ::
ResolveNamespace(catalogManager) ::
new ResolveCatalogs(catalogManager) ::
ResolveUserSpecifiedColumns ::
ResolveInsertInto ::
ResolveRelations ::
ResolveTables ::
Expand Down Expand Up @@ -846,7 +847,7 @@ class Analyzer(override val catalogManager: CatalogManager)
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case u @ UnresolvedRelation(ident, _, isStreaming) =>
lookupTempView(ident, isStreaming).getOrElse(u)
case i @ InsertIntoStatement(UnresolvedRelation(ident, _, false), _, _, _, _) =>
case i @ InsertIntoStatement(UnresolvedRelation(ident, _, false), _, _, _, _, _) =>
lookupTempView(ident)
.map(view => i.copy(table = view))
.getOrElse(i)
Expand Down Expand Up @@ -960,7 +961,7 @@ class Analyzer(override val catalogManager: CatalogManager)
.map(ResolvedTable(catalog.asTableCatalog, ident, _))
.getOrElse(u)

case i @ InsertIntoStatement(u @ UnresolvedRelation(_, _, false), _, _, _, _)
case i @ InsertIntoStatement(u @ UnresolvedRelation(_, _, false), _, _, _, _, _)
if i.query.resolved =>
lookupV2Relation(u.multipartIdentifier, u.options, false)
.map(v2Relation => i.copy(table = v2Relation))
Expand Down Expand Up @@ -1042,7 +1043,7 @@ class Analyzer(override val catalogManager: CatalogManager)
}

def apply(plan: LogicalPlan): LogicalPlan = ResolveTempViews(plan).resolveOperatorsUp {
case i @ InsertIntoStatement(table, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(table, _, _, _, _, _) if i.query.resolved =>
val relation = table match {
case u @ UnresolvedRelation(_, _, false) =>
lookupRelation(u.multipartIdentifier, u.options, false).getOrElse(u)
Expand Down Expand Up @@ -1157,7 +1158,8 @@ class Analyzer(override val catalogManager: CatalogManager)

object ResolveInsertInto extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
if i.query.resolved && i.userSpecifiedCols.isEmpty =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
Expand Down Expand Up @@ -3104,6 +3106,62 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}

object ResolveUserSpecifiedColumns extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case i: InsertIntoStatement if i.table.resolved && i.query.resolved &&
i.userSpecifiedCols.nonEmpty =>
val resolved = resolveUserSpecifiedColumns(i)
val projection = addColumnListOnQuery(i.table.output, resolved, i.query)
i.copy(userSpecifiedCols = Nil, query = projection)
}

private def resolveUserSpecifiedColumns(i: InsertIntoStatement): Seq[NamedExpression] = {
SchemaUtils.checkColumnNameDuplication(
i.userSpecifiedCols, "in the column list", resolver)

i.userSpecifiedCols.map { col =>
i.table.resolve(Seq(col), resolver)
.getOrElse(i.table.failAnalysis(s"Cannot resolve column name $col"))
}
}

private def addColumnListOnQuery(
tableOutput: Seq[Attribute],
cols: Seq[NamedExpression],
query: LogicalPlan): LogicalPlan = {
val errors = new mutable.ArrayBuffer[String]()

def failAdd(): Unit = {
val errMsg = if (errors.nonEmpty) errors.mkString("\n- ", "\n- ", "") else ""
query.failAnalysis(
s"""Cannot write to table due to mismatched user specified columns and data columns:
|Specified columns: ${cols.map(c => s"'${c.name}'").mkString(", ")}
|Data columns: ${query.output.map(c => s"'${c.name}'").mkString(", ")}$errMsg"""
.stripMargin)
}

if (cols.size != query.output.size) failAdd()

val nameToQueryExpr = cols.zip(query.output).toMap
val resolved = tableOutput.flatMap { tableAttr =>
if (nameToQueryExpr.contains(tableAttr)) {
TableOutputResolver.checkField(
tableAttr, nameToQueryExpr(tableAttr), byName = false, conf, err => errors += err)
} else {
None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when can we go to this branch?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is for static partition columns

}
}

if (errors.nonEmpty) failAdd()

if (resolved == query.output) {
query
} else {
Project(resolved, query)
}
}
}

private def validateStoreAssignmentPolicy(): Unit = {
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ trait CheckAnalysis extends PredicateHelper {
case u: UnresolvedRelation =>
u.failAnalysis(s"Table or view not found: ${u.multipartIdentifier.quoted}")

case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) =>
case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) =>
failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}")

// TODO (SPARK-27484): handle streaming write commands when we have them.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ object TableOutputResolver {
}
}

private def checkField(
def checkField(
tableAttr: Attribute,
queryExpr: NamedExpression,
byName: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ package object dsl {
partition: Map[String, Option[String]] = Map.empty,
overwrite: Boolean = false,
ifPartitionNotExists: Boolean = false): LogicalPlan =
InsertIntoStatement(table, partition, logicalPlan, overwrite, ifPartitionNotExists)
InsertIntoStatement(table, partition, Nil, logicalPlan, overwrite, ifPartitionNotExists)

def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg

/**
* Parameters used for writing query to a table:
* (multipartIdentifier, partitionKeys, ifPartitionNotExists).
* (multipartIdentifier, tableColumnList, partitionKeys, ifPartitionNotExists).
*/
type InsertTableParams = (Seq[String], Map[String, Option[String]], Boolean)
type InsertTableParams = (Seq[String], Seq[String], Map[String, Option[String]], Boolean)

/**
* Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
Expand All @@ -255,8 +255,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
/**
* Add an
* {{{
* INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]?
* INSERT INTO [TABLE] tableIdentifier [partitionSpec]
* INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList]
* INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList]
* INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
* INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList]
* }}}
Expand All @@ -267,18 +267,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
ctx match {
case table: InsertIntoTableContext =>
val (tableIdent, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table)
InsertIntoStatement(
UnresolvedRelation(tableIdent),
partition,
cols,
query,
overwrite = false,
ifPartitionNotExists)
case table: InsertOverwriteTableContext =>
val (tableIdent, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table)
InsertIntoStatement(
UnresolvedRelation(tableIdent),
partition,
cols,
query,
overwrite = true,
ifPartitionNotExists)
Expand All @@ -299,13 +301,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
override def visitInsertIntoTable(
ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)

if (ctx.EXISTS != null) {
operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx)
}

(tableIdent, partitionKeys, false)
(tableIdent, cols, partitionKeys, false)
}

/**
Expand All @@ -315,6 +318,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
assert(ctx.OVERWRITE() != null)
val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier)
val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil)
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)

val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
Expand All @@ -323,7 +327,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
dynamicPartitionKeys.keys.mkString(", "), ctx)
}

(tableIdent, partitionKeys, ctx.EXISTS() != null)
(tableIdent, cols, partitionKeys, ctx.EXISTS() != null)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ case class DropViewStatement(
* An INSERT INTO statement, as parsed from SQL.
*
* @param table the logical plan representing the table.
* @param userSpecifiedCols the user specified list of columns that belong to the table.
* @param query the logical plan representing data to write to.
* @param overwrite overwrite existing table or partitions.
* @param partitionSpec a map from the partition key to the partition value (optional).
Expand All @@ -290,6 +291,7 @@ case class DropViewStatement(
case class InsertIntoStatement(
table: LogicalPlan,
partitionSpec: Map[String, Option[String]],
userSpecifiedCols: Seq[String],
query: LogicalPlan,
overwrite: Boolean,
ifPartitionNotExists: Boolean) extends ParsedStatement {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,22 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
}

test("insert table: basic append with a column list") {
Seq(
"INSERT INTO TABLE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
"INSERT INTO testcat.ns1.ns2.tbl (a, b) SELECT * FROM source"
).foreach { sql =>
parseCompare(sql,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Seq("a", "b"),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
Expand All @@ -862,6 +878,7 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("testcat2", "db", "tbl"))),
overwrite = false, ifPartitionNotExists = false))
}
Expand All @@ -876,6 +893,22 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}

test("insert table: append with partition and a column list") {
parseCompare(
"""
|INSERT INTO testcat.ns1.ns2.tbl
|PARTITION (p1 = 3, p2) (a, b)
|SELECT * FROM source
""".stripMargin,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Seq("a", "b"),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
Expand All @@ -889,6 +922,22 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
}

test("insert table: overwrite with column list") {
Seq(
"INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
"INSERT OVERWRITE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source"
).foreach { sql =>
parseCompare(sql,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Seq("a", "b"),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
Expand All @@ -904,6 +953,22 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}

test("insert table: overwrite with partition and column list") {
parseCompare(
"""
|INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl
|PARTITION (p1 = 3, p2) (a, b)
|SELECT * FROM source
""".stripMargin,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Seq("a", "b"),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
Expand All @@ -918,6 +983,7 @@ class DDLParserSuite extends AnalysisTest {
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3")),
Nil,
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = true))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class PlanParserSuite extends AnalysisTest {
partition: Map[String, Option[String]],
overwrite: Boolean = false,
ifPartitionNotExists: Boolean = false): LogicalPlan =
InsertIntoStatement(table("s"), partition, plan, overwrite, ifPartitionNotExists)
InsertIntoStatement(table("s"), partition, Nil, plan, overwrite, ifPartitionNotExists)

// Single inserts
assertEqual(s"insert overwrite table s $sql",
Expand Down Expand Up @@ -713,7 +713,7 @@ class PlanParserSuite extends AnalysisTest {
comparePlans(
parsePlan(
"INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"),
InsertIntoStatement(table("s"), Map.empty,
InsertIntoStatement(table("s"), Map.empty, Nil,
UnresolvedHint("REPARTITION", Seq(Literal(100)),
UnresolvedHint("COALESCE", Seq(Literal(500)),
UnresolvedHint("COALESCE", Seq(Literal(10)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
InsertIntoStatement(
table = UnresolvedRelation(tableIdent),
partitionSpec = Map.empty[String, Option[String]],
Nil,
query = df.logicalPlan,
overwrite = mode == SaveMode.Overwrite,
ifPartitionNotExists = false)
Expand Down
Loading