Skip to content

Commit

Permalink
V2 Support
Browse files Browse the repository at this point in the history
  • Loading branch information
yaooqinn committed Sep 29, 2020
1 parent 2b164ed commit 086cfa8
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,8 @@ class Analyzer(

object ResolveInsertInto extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) if i.query.resolved =>
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
if i.query.resolved && i.columns.forall(_.resolved) =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw new AnalysisException(
Expand All @@ -1102,22 +1103,13 @@ class Analyzer(
val query = addStaticPartitionColumns(r, i.query, staticPartitions)

if (!i.overwrite) {
AppendData.byPosition(r, query)
AppendData.byPosition(r, query, i.columns)
} else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
OverwritePartitionsDynamic.byPosition(r, query)
OverwritePartitionsDynamic.byPosition(r, query, i.columns)
} else {
OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
OverwriteByExpression.byPosition(
r, query, staticDeleteExpression(r, staticPartitions), i.columns)
}

case i @ InsertIntoStatement(table, _, _, _, _, _) if table.resolved =>
val resolved = i.columns.map {
case u: UnresolvedAttribute => withPosition(u) {
table.resolve(u.nameParts, resolver).map(_.toAttribute)
.getOrElse(failAnalysis(s"Cannot resolve column name ${u.name}"))
}
case other => other
}
i.copy(columns = resolved)
}

private def partitionColumnNames(table: Table): Seq[String] = {
Expand Down Expand Up @@ -3038,35 +3030,45 @@ class Analyzer(
*/
object ResolveOutputRelation extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case append @ AppendData(table, query, _, isByName)
case i @ InsertIntoStatement(table, _, _, _, _, _) if table.resolved =>
val resolved = i.columns.map {
case u: UnresolvedAttribute => withPosition(u) {
table.resolve(u.nameParts, resolver).map(_.toAttribute)
.getOrElse(failAnalysis(s"Cannot resolve column name ${u.name}"))
}
case other => other
}
i.copy(columns = resolved)

case append @ AppendData(table, query, cols, _, isByName)
if table.resolved && query.resolved && !append.outputResolved =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
val projection = TableOutputResolver.resolveOutputColumns(table.name,
getOutput(table, cols), query, isByName, conf)

if (projection != query) {
append.copy(query = projection)
} else {
append
}

case overwrite @ OverwriteByExpression(table, _, query, _, isByName)
case overwrite @ OverwriteByExpression(table, _, query, cols, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
val projection = TableOutputResolver.resolveOutputColumns(
table.name, getOutput(table, cols), query, isByName, conf)

if (projection != query) {
overwrite.copy(query = projection)
} else {
overwrite
}

case overwrite @ OverwritePartitionsDynamic(table, query, _, isByName)
case overwrite @ OverwritePartitionsDynamic(table, query, cols, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
validateStoreAssignmentPolicy()
val projection =
TableOutputResolver.resolveOutputColumns(table.name, table.output, query, isByName, conf)
val projection = TableOutputResolver.resolveOutputColumns(
table.name, getOutput(table, cols), query, isByName, conf)

if (projection != query) {
overwrite.copy(query = projection)
Expand All @@ -3076,6 +3078,19 @@ class Analyzer(
}
}

private def getOutput(table: NamedRelation, expectedCols: Seq[Attribute]): Seq[Attribute] = {
if (expectedCols.isEmpty) {
table.output
} else {
if (table.output.size != expectedCols.size) {
failAnalysis(s"${table.name} requires that the data to be inserted have the same number" +
s" of columns as the target table that has ${table.output.size} column(s) but the" +
s" specified part has only ${expectedCols.length} column(s)")
}
expectedCols
}
}

private def validateStoreAssignmentPolicy(): Unit = {
// SPARK-28730: LEGACY store assignment policy is disallowed in data source v2.
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ trait V2WriteCommand extends Command {
case class AppendData(
table: NamedRelation,
query: LogicalPlan,
columns: Seq[Attribute],
writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand

Expand All @@ -66,14 +67,15 @@ object AppendData {
table: NamedRelation,
df: LogicalPlan,
writeOptions: Map[String, String] = Map.empty): AppendData = {
new AppendData(table, df, writeOptions, isByName = true)
new AppendData(table, df, Nil, writeOptions, isByName = true)
}

def byPosition(
table: NamedRelation,
query: LogicalPlan,
columns: Seq[Attribute] = Nil,
writeOptions: Map[String, String] = Map.empty): AppendData = {
new AppendData(table, query, writeOptions, isByName = false)
new AppendData(table, query, columns, writeOptions, isByName = false)
}
}

Expand All @@ -84,6 +86,7 @@ case class OverwriteByExpression(
table: NamedRelation,
deleteExpr: Expression,
query: LogicalPlan,
columns: Seq[Attribute],
writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand {
override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved
Expand All @@ -95,15 +98,16 @@ object OverwriteByExpression {
df: LogicalPlan,
deleteExpr: Expression,
writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
OverwriteByExpression(table, deleteExpr, df, writeOptions, isByName = true)
OverwriteByExpression(table, deleteExpr, df, Nil, writeOptions, isByName = true)
}

def byPosition(
table: NamedRelation,
query: LogicalPlan,
deleteExpr: Expression,
columns: Seq[Attribute] = Nil,
writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
OverwriteByExpression(table, deleteExpr, query, writeOptions, isByName = false)
OverwriteByExpression(table, deleteExpr, query, columns, writeOptions, isByName = false)
}
}

Expand All @@ -113,6 +117,7 @@ object OverwriteByExpression {
case class OverwritePartitionsDynamic(
table: NamedRelation,
query: LogicalPlan,
columns: Seq[Attribute] = Nil,
writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand

Expand All @@ -121,14 +126,15 @@ object OverwritePartitionsDynamic {
table: NamedRelation,
df: LogicalPlan,
writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
OverwritePartitionsDynamic(table, df, writeOptions, isByName = true)
OverwritePartitionsDynamic(table, df, Nil, writeOptions, isByName = true)
}

def byPosition(
table: NamedRelation,
query: LogicalPlan,
columns: Seq[Attribute] = Nil,
writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
OverwritePartitionsDynamic(table, query, writeOptions, isByName = false)
OverwritePartitionsDynamic(table, query, columns, writeOptions, isByName = false)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first}
import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -843,6 +842,21 @@ class DDLParserSuite extends AnalysisTest {
}
}

test("insert table: basic append with a column list") {
Seq(
"INSERT INTO TABLE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
"INSERT INTO testcat.ns1.ns2.tbl (a, b) SELECT * FROM source"
).foreach { sql =>
parseCompare(sql,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Seq("a", "b").map(UnresolvedAttribute(_)),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}
}

test("insert table: append from another catalog") {
parseCompare("INSERT INTO TABLE testcat.ns1.ns2.tbl SELECT * FROM testcat2.db.tbl",
InsertIntoStatement(
Expand All @@ -868,6 +882,21 @@ class DDLParserSuite extends AnalysisTest {
overwrite = false, ifPartitionNotExists = false))
}

test("insert table: append with partition and a column list") {
parseCompare(
"""
|INSERT INTO testcat.ns1.ns2.tbl (a, b)
|PARTITION (p1 = 3, p2)
|SELECT * FROM source
""".stripMargin,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Seq("a", "b").map(UnresolvedAttribute(_)),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = false, ifPartitionNotExists = false))
}

test("insert table: overwrite") {
Seq(
"INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl SELECT * FROM source",
Expand All @@ -883,6 +912,21 @@ class DDLParserSuite extends AnalysisTest {
}
}

test("insert table: overwrite with column list") {
Seq(
"INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source",
"INSERT OVERWRITE testcat.ns1.ns2.tbl (a, b) SELECT * FROM source"
).foreach { sql =>
parseCompare(sql,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map.empty,
Seq("a", "b").map(UnresolvedAttribute(_)),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
}

test("insert table: overwrite with partition") {
parseCompare(
"""
Expand All @@ -898,6 +942,20 @@ class DDLParserSuite extends AnalysisTest {
overwrite = true, ifPartitionNotExists = false))
}

test("insert table: overwrite with partition and column list") {
parseCompare(
"""
|INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl (a, b)
|PARTITION (p1 = 3, p2)
|SELECT * FROM source
""".stripMargin,
InsertIntoStatement(
UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")),
Map("p1" -> Some("3"), "p2" -> None),
Seq("a", "b").map(UnresolvedAttribute(_)),
Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))),
overwrite = true, ifPartitionNotExists = false))
}
test("insert table: overwrite with partition if not exists") {
parseCompare(
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,17 +512,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

val command = mode match {
case SaveMode.Append | SaveMode.ErrorIfExists | SaveMode.Ignore =>
AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap)
AppendData.byPosition(table, df.logicalPlan, Nil, extraOptions.toMap)

case SaveMode.Overwrite =>
val conf = df.sparkSession.sessionState.conf
val dynamicPartitionOverwrite = table.table.partitioning.size > 0 &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC

if (dynamicPartitionOverwrite) {
OverwritePartitionsDynamic.byPosition(table, df.logicalPlan, extraOptions.toMap)
OverwritePartitionsDynamic.byPosition(table, df.logicalPlan, Nil, extraOptions.toMap)
} else {
OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true), extraOptions.toMap)
OverwriteByExpression.byPosition(
table, df.logicalPlan, Literal(true), Nil, extraOptions.toMap)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,15 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
orCreate = orCreate) :: Nil
}

case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
case AppendData(r: DataSourceV2Relation, query, _, writeOptions, _) =>
r.table.asWritable match {
case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil
case v2 =>
AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil
}

case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) =>
case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _, writeOptions, _) =>
// fail if any filter cannot be converted. correctness depends on removing all matching data.
val filters = splitConjunctivePredicates(deleteExpr).map {
filter => DataSourceStrategy.translateFilter(deleteExpr,
Expand All @@ -189,7 +189,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil
}

case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) =>
case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _, writeOptions, _) =>
OverwritePartitionsDynamicExec(
r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {

// TODO: check STREAMING_WRITE capability. It's not doable now because we don't have a
// a logical plan for streaming write.
case AppendData(r: DataSourceV2Relation, _, _, _) if !supportsBatchWrite(r.table) =>
case AppendData(r: DataSourceV2Relation, _, _, _, _) if !supportsBatchWrite(r.table) =>
failAnalysis(s"Table ${r.table.name()} does not support append in batch mode.")

case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _)
case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _, _)
if !r.table.supports(BATCH_WRITE) || !r.table.supports(OVERWRITE_DYNAMIC) =>
failAnalysis(s"Table ${r.table.name()} does not support dynamic overwrite in batch mode.")

case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _) =>
case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _, _) =>
expr match {
case Literal(true, BooleanType) =>
if (!supportsBatchWrite(r.table) ||
Expand Down

0 comments on commit 086cfa8

Please sign in to comment.