diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fed100b6069b9..7eb1f5e6a73c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1088,7 +1088,8 @@ class Analyzer(
 
   object ResolveInsertInto extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) if i.query.resolved =>
+      case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
+          if i.query.resolved && i.columns.forall(_.resolved) =>
         // ifPartitionNotExists is append with validation, but validation is not supported
         if (i.ifPartitionNotExists) {
           throw new AnalysisException(
@@ -1102,22 +1103,13 @@ class Analyzer(
         val query = addStaticPartitionColumns(r, i.query, staticPartitions)
 
         if (!i.overwrite) {
-          AppendData.byPosition(r, query)
+          AppendData.byPosition(r, query, i.columns)
         } else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
-          OverwritePartitionsDynamic.byPosition(r, query)
+          OverwritePartitionsDynamic.byPosition(r, query, i.columns)
         } else {
-          OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
+          OverwriteByExpression.byPosition(
+            r, query, staticDeleteExpression(r, staticPartitions), i.columns)
         }
-
-      case i @ InsertIntoStatement(table, _, _, _, _, _) if table.resolved =>
-        val resolved = i.columns.map {
-          case u: UnresolvedAttribute => withPosition(u) {
-            table.resolve(u.nameParts, resolver).map(_.toAttribute)
-              .getOrElse(failAnalysis(s"Cannot resolve column name ${u.name}"))
-          }
-          case other => other
-        }
-        i.copy(columns = resolved)
     }
 
     private def partitionColumnNames(table: Table): Seq[String] = {
@@ -3038,11 +3030,21 @@ class Analyzer(
    */
   object ResolveOutputRelation extends Rule[LogicalPlan] {
     override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
-      case append @ AppendData(table, query, _, isByName)
+      case i @ InsertIntoStatement(table, _, _, _, _, _) if table.resolved =>
+        val resolved = i.columns.map {
+          case u: UnresolvedAttribute => withPosition(u) {
+            table.resolve(u.nameParts, resolver).map(_.toAttribute)
+              .getOrElse(failAnalysis(s"Cannot resolve column name ${u.name}"))
+          }
+          case other => other
+        }
+        i.copy(columns = resolved)
+
+      case append @ AppendData(table, query, cols, _, isByName)
           if table.resolved && query.resolved && !append.outputResolved =>
         validateStoreAssignmentPolicy()
-        val projection =
-          TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
+        val projection = TableOutputResolver.resolveOutputColumns(table.name,
+          getOutput(table, cols), query, isByName, conf)
 
         if (projection != query) {
           append.copy(query = projection)
@@ -3050,11 +3052,11 @@ class Analyzer(
           append
         }
 
-      case overwrite @ OverwriteByExpression(table, _, query, _, isByName)
+      case overwrite @ OverwriteByExpression(table, _, query, cols, _, isByName)
           if table.resolved && query.resolved && !overwrite.outputResolved =>
         validateStoreAssignmentPolicy()
-        val projection =
-          TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
+        val projection = TableOutputResolver.resolveOutputColumns(
+          table.name, getOutput(table, cols), query, isByName, conf)
 
         if (projection != query) {
           overwrite.copy(query = projection)
@@ -3062,11 +3064,11 @@ class Analyzer(
           overwrite
         }
 
-      case overwrite @ OverwritePartitionsDynamic(table, query, _, isByName)
+      case overwrite @ OverwritePartitionsDynamic(table, query, cols, _, isByName)
           if table.resolved && query.resolved && !overwrite.outputResolved =>
         validateStoreAssignmentPolicy()
-        val projection =
-          TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
+        val projection = TableOutputResolver.resolveOutputColumns(
+          table.name, getOutput(table, cols), query, isByName, conf)
 
         if (projection != query) {
           overwrite.copy(query = projection)
@@ -3076,6 +3078,19 @@ class Analyzer(
       }
     }
 
+    private def getOutput(table: NamedRelation, expectedCols: Seq[Attribute]): Seq[Attribute] = {
+      if (expectedCols.isEmpty) {
+        table.output
+      } else {
+        if (table.output.size != expectedCols.size) {
+          failAnalysis(s"${table.name} requires that the data to be inserted have the same number" +
+            s" of columns as the target table that has ${table.output.size} column(s) but the" +
+            s" specified part has only ${expectedCols.length} column(s)")
+        }
+        expectedCols
+      }
+    }
+
     private def validateStoreAssignmentPolicy(): Unit = {
       // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
       if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 475eb7d74773d..55f9377378806 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -58,6 +58,7 @@ trait V2WriteCommand extends Command {
 case class AppendData(
     table: NamedRelation,
     query: LogicalPlan,
+    columns: Seq[Attribute],
     writeOptions: Map[String, String],
     isByName: Boolean) extends V2WriteCommand
 
@@ -66,14 +67,15 @@ object AppendData {
       table: NamedRelation,
       df: LogicalPlan,
       writeOptions: Map[String, String] = Map.empty): AppendData = {
-    new AppendData(table, df, writeOptions, isByName = true)
+    new AppendData(table, df, Nil, writeOptions, isByName = true)
   }
 
   def byPosition(
       table: NamedRelation,
       query: LogicalPlan,
+      columns: Seq[Attribute] = Nil,
      writeOptions: Map[String, String] = Map.empty): AppendData = {
-    new AppendData(table, query, writeOptions, isByName = false)
+    new AppendData(table, query, columns, writeOptions, isByName = false)
   }
 }
 
@@ -84,6 +86,7 @@ case class OverwriteByExpression(
     table: NamedRelation,
     deleteExpr: Expression,
     query: LogicalPlan,
+    columns: Seq[Attribute],
     writeOptions: Map[String, String],
     isByName: Boolean) extends V2WriteCommand {
   override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved
@@ -95,15 +98,16 @@ object OverwriteByExpression {
       df: LogicalPlan,
       deleteExpr: Expression,
       writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
-    OverwriteByExpression(table, deleteExpr, df, writeOptions, isByName = true)
+    OverwriteByExpression(table, deleteExpr, df, Nil, writeOptions, isByName = true)
   }
 
   def byPosition(
       table: NamedRelation,
       query: LogicalPlan,
       deleteExpr: Expression,
+      columns: Seq[Attribute] = Nil,
       writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
-    OverwriteByExpression(table, deleteExpr, query, writeOptions, isByName = false)
+    OverwriteByExpression(table, deleteExpr, query, columns, writeOptions, isByName = false)
   }
 }
 
@@ -113,6 +117,7 @@ object OverwriteByExpression {
 case class OverwritePartitionsDynamic(
     table: NamedRelation,
     query: LogicalPlan,
+    columns: Seq[Attribute] = Nil,
     writeOptions: Map[String, String],
     isByName: Boolean) extends V2WriteCommand
 
@@ -121,14 +126,15 @@ object OverwritePartitionsDynamic {
       table: NamedRelation,
       df: LogicalPlan,
       writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
-    OverwritePartitionsDynamic(table, df, writeOptions, isByName = true)
+    OverwritePartitionsDynamic(table, df, Nil, writeOptions, isByName = true)
   }
 
   def byPosition(
       table: NamedRelation,
       query: LogicalPlan,
+      columns: Seq[Attribute] = Nil,
       writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
-    OverwritePartitionsDynamic(table, query, writeOptions, isByName = false)
+    OverwritePartitionsDynamic(table, query, columns, writeOptions, isByName = false)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index c29eb621cdef3..cc9f3b9a123aa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first}
 import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -843,6 +842,21 @@ class DDLParserSuite extends AnalysisTest {
     }
   }
 
+  test("insert table: basic append with a column list") {
+    Seq(
+      "INSERT INTO TABLE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
+      "INSERT INTO testcat.ns1.ns2.tbl (a, b) SELECT * FROM source"
+    ).foreach { sql =>
+      parseCompare(sql,
+        InsertIntoStatement(
+          UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
+          Map.empty,
+          Seq("a", "b").map(UnresolvedAttribute(_)),
+          Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
+          overwrite = false, ifPartitionNotExists = false))
+    }
+  }
+
   test("insert table: append from another catalog") {
     parseCompare("INSERT INTO TABLE testcat.ns1.ns2.tbl SELECT * FROM testcat2.db.tbl",
       InsertIntoStatement(
@@ -868,6 +882,21 @@ class DDLParserSuite extends AnalysisTest {
         overwrite = false, ifPartitionNotExists = false))
   }
 
+  test("insert table: append with partition and a column list") {
+    parseCompare(
+      """
+        |INSERT INTO testcat.ns1.ns2.tbl (a, b)
+        |PARTITION (p1 = 3, p2)
+        |SELECT * FROM source
+      """.stripMargin,
+      InsertIntoStatement(
+        UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
+        Map("p1" -> Some("3"), "p2" -> None),
+        Seq("a", "b").map(UnresolvedAttribute(_)),
+        Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
+        overwrite = false, ifPartitionNotExists = false))
+  }
+
   test("insert table: overwrite") {
     Seq(
       "INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl SELECT * FROM source",
@@ -883,6 +912,21 @@ class DDLParserSuite extends AnalysisTest {
     }
   }
 
+  test("insert table: overwrite with column list") {
+    Seq(
+      "INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
+      "INSERT OVERWRITE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source"
+    ).foreach { sql =>
+      parseCompare(sql,
+        InsertIntoStatement(
+          UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
+          Map.empty,
+          Seq("a", "b").map(UnresolvedAttribute(_)),
+          Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
+          overwrite = true, ifPartitionNotExists = false))
+    }
+  }
+
   test("insert table: overwrite with partition") {
     parseCompare(
       """
@@ -898,6 +942,20 @@ class DDLParserSuite extends AnalysisTest {
         overwrite = true, ifPartitionNotExists = false))
   }
 
+  test("insert table: overwrite with partition and column list") {
+    parseCompare(
+      """
+        |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl (a, b)
+        |PARTITION (p1 = 3, p2)
+        |SELECT * FROM source
+      """.stripMargin,
+      InsertIntoStatement(
+        UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
+        Map("p1" -> Some("3"), "p2" -> None),
+        Seq("a", "b").map(UnresolvedAttribute(_)),
+        Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
+        overwrite = true, ifPartitionNotExists = false))
+  }
   test("insert table: overwrite with partition if not exists") {
     parseCompare(
       """
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 8b6e5fb8ceabb..03fc2c60a2b6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -512,7 +512,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 
     val command = mode match {
       case SaveMode.Append | SaveMode.ErrorIfExists | SaveMode.Ignore =>
-        AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap)
+        AppendData.byPosition(table, df.logicalPlan, Nil, extraOptions.toMap)
 
       case SaveMode.Overwrite =>
         val conf = df.sparkSession.sessionState.conf
@@ -520,9 +520,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
 
         if (dynamicPartitionOverwrite) {
-          OverwritePartitionsDynamic.byPosition(table, df.logicalPlan, extraOptions.toMap)
+          OverwritePartitionsDynamic.byPosition(table, df.logicalPlan, Nil, extraOptions.toMap)
         } else {
-          OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true), extraOptions.toMap)
+          OverwriteByExpression.byPosition(
+            table, df.logicalPlan, Literal(true), Nil, extraOptions.toMap)
         }
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index c5ddba43a56aa..66ac768a554f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -167,7 +167,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         orCreate = orCreate) :: Nil
     }
 
-    case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
+    case AppendData(r: DataSourceV2Relation, query, _, writeOptions, _) =>
       r.table.asWritable match {
         case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
          AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil
@@ -175,7 +175,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
           AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil
       }
 
-    case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) =>
+    case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _, writeOptions, _) =>
      // fail if any filter cannot be converted. correctness depends on removing all matching data.
      val filters = splitConjunctivePredicates(deleteExpr).map { filter =>
        DataSourceStrategy.translateFilter(deleteExpr,
@@ -189,7 +189,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
           OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil
       }
 
-    case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) =>
+    case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _, writeOptions, _) =>
       OverwritePartitionsDynamicExec(
         r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
index 5dfd2e52706d0..e1bf013789b2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
@@ -49,14 +49,14 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {
       // TODO: check STREAMING_WRITE capability. It's not doable now because we don't have a
       // a logical plan for streaming write.
 
-      case AppendData(r: DataSourceV2Relation, _, _, _) if !supportsBatchWrite(r.table) =>
+      case AppendData(r: DataSourceV2Relation, _, _, _, _) if !supportsBatchWrite(r.table) =>
         failAnalysis(s"Table ${r.table.name()} does not support append in batch mode.")
 
-      case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _)
+      case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _, _)
         if !r.table.supports(BATCH_WRITE) || !r.table.supports(OVERWRITE_DYNAMIC) =>
         failAnalysis(s"Table ${r.table.name()} does not support dynamic overwrite in batch mode.")
 
-      case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _) =>
+      case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _, _) =>
         expr match {
           case Literal(true, BooleanType) =>
             if (!supportsBatchWrite(r.table) ||
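
For reference, a minimal usage sketch of what this patch enables (not part of the diff): `testcat`, the `foo` provider, and the `source` relation below are hypothetical stand-ins borrowed from the parser tests above.

```scala
// Hypothetical sketch, assuming a SparkSession `spark` with a V2 catalog
// registered as "testcat" and an existing two-column relation named `source`.
spark.sql("CREATE TABLE testcat.ns1.ns2.tbl (a INT, b STRING) USING foo")

// The parser records the (a, b) column list in InsertIntoStatement.columns;
// ResolveOutputRelation resolves the names against the target table, and
// getOutput substitutes them for table.output when matching the query's
// output by position.
spark.sql("INSERT INTO testcat.ns1.ns2.tbl (a, b) SELECT * FROM source")

// A column list whose size differs from the target table's schema fails
// analysis with the "requires that the data to be inserted have the same
// number of columns" message added in Analyzer.getOutput.
spark.sql("INSERT INTO testcat.ns1.ns2.tbl (a) SELECT 1")  // AnalysisException
```

Note that the by-name constructors (`AppendData.byName` and friends, used by the `DataFrameWriter` paths) always pass `Nil` for `columns`, so the column list only affects by-position INSERT statements.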