From 327bb1dead0e5cb5167ca82b09a46655227a0bf9 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 29 Apr 2015 23:02:43 +0800 Subject: [PATCH] Implements partitioning support for data sources API --- .../apache/spark/deploy/SparkHadoopUtil.scala | 47 ++- project/SparkBuild.scala | 1 + .../org/apache/spark/sql/DataFrame.scala | 21 +- .../scala/org/apache/spark/sql/SQLConf.scala | 3 + .../org/apache/spark/sql/SQLContext.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 10 +- .../apache/spark/sql/parquet/newParquet.scala | 6 +- .../sql/sources/DataSourceStrategy.scala | 46 +++ .../spark/sql/sources/PartitioningUtils.scala | 207 ++++++++++++ .../apache/spark/sql/sources/commands.scala | 317 +++++++++++++++++- .../org/apache/spark/sql/sources/ddl.scala | 53 ++- .../apache/spark/sql/sources/interfaces.scala | 73 +++- .../org/apache/spark/sql/sources/rules.scala | 2 +- .../ParquetPartitionDiscoverySuite.scala | 1 + .../sources/CreateTableAsSelectSuite.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 13 +- .../spark/sql/hive/HiveStrategies.scala | 6 +- .../spark/sql/hive/execution/commands.scala | 12 +- .../spark/sql/hive/hiveWriterContainers.scala | 2 +- .../sql/sources/FSBasedRelationSuite.scala | 244 ++++++++++---- 20 files changed, 936 insertions(+), 134 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala rename sql/{core => hive}/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala (60%) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index b563034457a91..5def549ac411a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -22,22 +22,22 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.util.{Arrays, Comparator} +import scala.collection.JavaConversions._ +import scala.concurrent.duration._ +import scala.language.postfixOps + import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.fs.FileSystem.Statistics +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ -import scala.concurrent.duration._ -import scala.language.postfixOps +import org.apache.spark.{Logging, SparkConf, SparkException} /** * :: DeveloperApi :: @@ -199,13 +199,36 @@ class SparkHadoopUtil extends Logging { * that file. */ def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { - def recurse(path: Path): Array[FileStatus] = { - val (directories, leaves) = fs.listStatus(path).partition(_.isDir) - leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) + listLeafStatuses(fs, fs.getFileStatus(basePath)) + } + + /** + * Get [[FileStatus]] objects for all leaf children (files) under the given base path. 
If the + * given path points to a file, return a single-element collection containing [[FileStatus]] of + * that file. + */ + def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { + def recurse(status: FileStatus): Seq[FileStatus] = { + val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDir) + leaves ++ directories.flatMap(f => listLeafStatuses(fs, f)) + } + + if (baseStatus.isDir) recurse(baseStatus) else Seq(baseStatus) + } + + def listLeafDirStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { + listLeafDirStatuses(fs, fs.getFileStatus(basePath)) + } + + def listLeafDirStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { + def recurse(status: FileStatus): Seq[FileStatus] = { + val (directories, files) = fs.listStatus(status.getPath).partition(_.isDir) + val leaves = if (directories.isEmpty) Seq(status) else Seq.empty[FileStatus] + leaves ++ directories.flatMap(dir => listLeafDirStatuses(fs, dir)) } - val baseStatus = fs.getFileStatus(basePath) - if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) + assert(baseStatus.isDir) + recurse(baseStatus) } /** @@ -275,7 +298,7 @@ class SparkHadoopUtil extends Logging { logDebug(text + " matched " + HADOOP_CONF_PATTERN) val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } val eval = Option[String](hadoopConf.get(key)) - .map { value => + .map { value => logDebug("Substituted " + matched + " with " + value) text.replace(matched, value) } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 186345af0e60e..146c9ead80233 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -548,6 +548,7 @@ object TestSettings { javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", + javaOptions in Test += "-agentlib:jdwp=transport=dt_socket,server=n,address=127.0.0.1:5005", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 27496227f99d4..32c867f2eeabd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -27,23 +27,23 @@ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import com.fasterxml.jackson.core.JsonFactory - import org.apache.commons.lang3.StringUtils + import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar} +import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import 
org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JacksonGenerator +import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -1372,6 +1372,7 @@ class DataFrame private[sql]( tableName, source, temporary = false, + Array.empty[String], mode, options, logicalPlan) @@ -1473,7 +1474,7 @@ class DataFrame private[sql]( mode: SaveMode, options: java.util.Map[String, String], partitionColumns: java.util.List[String]): Unit = { - ??? + save(source, mode, options.toMap, partitionColumns) } /** @@ -1488,7 +1489,7 @@ class DataFrame private[sql]( source: String, mode: SaveMode, options: Map[String, String]): Unit = { - ResolvedDataSource(sqlContext, source, mode, options, this) + ResolvedDataSource(sqlContext, source, Array.empty[String], mode, options, this) } /** @@ -1503,7 +1504,7 @@ class DataFrame private[sql]( mode: SaveMode, options: Map[String, String], partitionColumns: Seq[String]): Unit = { - ??? + ResolvedDataSource(sqlContext, source, partitionColumns.toArray, mode, options, this) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index dcac97beafb04..2f9784b823173 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -66,6 +66,9 @@ private[spark] object SQLConf { // to its length exceeds the threshold. val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold" + // Whether to perform partition discovery when loading external data sources. + val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled" + // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. 
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 28fc9d04436f7..2753c913b6a0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -762,7 +762,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def load(source: String, options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, None, source, options) + val resolved = ResolvedDataSource(this, None, Array.empty[String], source, options) DataFrame(this, LogicalRelation(resolved.relation)) } @@ -792,7 +792,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, Some(schema), source, options) + val resolved = ResolvedDataSource(this, Some(schema), Array.empty[String], source, options) DataFrame(this, LogicalRelation(resolved.relation)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 56a4689eb58f0..af0029cb84f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -343,9 +343,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) => - val cmd = - CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query) + case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query) + if partitionsCols.nonEmpty => + sys.error("Cannot create temporary partitioned table.") + + case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) => + val cmd = CreateTempTableUsingAsSelect( + tableName, provider, Array.empty[String], mode, opts, query) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 85e60733bc57a..96a39416925b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -136,10 +136,6 @@ private[sql] class DefaultSource } } -private[sql] case class Partition(values: Row, path: String) - -private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) - /** * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is * intended as a full replacement of the Parquet support in Spark SQL. 
The old implementation will @@ -805,7 +801,7 @@ private[sql] object ParquetRelation2 extends Logging { val ordinalMap = metastoreSchema.zipWithIndex.map { case (field, index) => field.name.toLowerCase -> index }.toMap - val reorderedParquetSchema = mergedParquetSchema.sortBy(f => + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) StructType(metastoreSchema.zip(reorderedParquetSchema).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index b3d71f687a60a..1659a7a1989dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -53,6 +53,25 @@ private[sql] object DataSourceStrategy extends Strategy { filters, (a, _) => t.buildScan(a)) :: Nil + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation)) => + val selectedPartitions = prunePartitions(filters, t.partitionSpec) + val inputPaths = selectedPartitions.map(_.path).toArray + + // Don't push down predicates that reference partition columns + val pushedFilters = { + val partitionColumnNames = t.partitionSpec.partitionColumns.map(_.name).toSet + filters.filter { f => + val referencedColumnNames = f.references.map(_.name).toSet + referencedColumnNames.intersect(partitionColumnNames).isEmpty + } + } + + pruneFilterProject( + l, + projectList, + pushedFilters, + (a, f) => t.buildScan(a, f, inputPaths)) :: Nil + case l @ LogicalRelation(t: TableScan) => createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil @@ -63,6 +82,33 @@ private[sql] object DataSourceStrategy extends Strategy { case _ => Nil } + protected def prunePartitions( + predicates: Seq[Expression], + partitionSpec: PartitionSpec): Seq[Partition] = { + val PartitionSpec(partitionColumns, partitions) = partitionSpec + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + if (partitionPruningPredicates.nonEmpty) { + val predicate = + partitionPruningPredicates + .reduceOption(expressions.And) + .getOrElse(Literal(true)) + + val boundPredicate = InterpretedPredicate(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + partitions.filter { case Partition(values, _) => boundPredicate(values) } + } else { + partitions + } + } + // Based on Public API. protected def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala new file mode 100644 index 0000000000000..d30f7f65e21c0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} + +import scala.collection.mutable.ArrayBuffer +import scala.util.Try + +import com.google.common.cache.{CacheBuilder, Cache} +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.types._ + +private[sql] case class Partition(values: Row, path: String) + +private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) + +private[sql] object PartitioningUtils { + private[sql] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { + require(columnNames.size == literals.size) + } + + /** + * Given a group of qualified paths, tries to parse them and returns a partition specification. + * For example, given: + * {{{ + * hdfs://:/path/to/partition/a=1/b=hello/c=3.14 + * hdfs://:/path/to/partition/a=2/b=world/c=6.28 + * }}} + * it returns: + * {{{ + * PartitionSpec( + * partitionColumns = StructType( + * StructField(name = "a", dataType = IntegerType, nullable = true), + * StructField(name = "b", dataType = StringType, nullable = true), + * StructField(name = "c", dataType = DoubleType, nullable = true)), + * partitions = Seq( + * Partition( + * values = Row(1, "hello", 3.14), + * path = "hdfs://:/path/to/partition/a=1/b=hello/c=3.14"), + * Partition( + * values = Row(2, "world", 6.28), + * path = "hdfs://:/path/to/partition/a=2/b=world/c=6.28"))) + * }}} + */ + private[sql] def parsePartitions( + paths: Seq[Path], + defaultPartitionName: String): PartitionSpec = { + val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName))) + val fields = { + val (PartitionValues(columnNames, literals)) = partitionValues.head + columnNames.zip(literals).map { case (name, Literal(_, dataType)) => + StructField(name, dataType, nullable = true) + } + } + + val partitions = partitionValues.zip(paths).map { + case (PartitionValues(_, literals), path) => + Partition(Row(literals.map(_.value): _*), path.toString) + } + + PartitionSpec(StructType(fields), partitions) + } + + /** + * Parses a single partition, returns column names and values of each partition column. 
For + * example, given: + * {{{ + * path = hdfs://:/path/to/partition/a=42/b=hello/c=3.14 + * }}} + * it returns: + * {{{ + * PartitionValues( + * Seq("a", "b", "c"), + * Seq( + * Literal.create(42, IntegerType), + * Literal.create("hello", StringType), + * Literal.create(3.14, FloatType))) + * }}} + */ + private[sql] def parsePartition( + path: Path, + defaultPartitionName: String): PartitionValues = { + val columns = ArrayBuffer.empty[(String, Literal)] + // Old Hadoop versions don't have `Path.isRoot` + var finished = path.getParent == null + var chopped = path + + while (!finished) { + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + maybeColumn.foreach(columns += _) + chopped = chopped.getParent + finished = maybeColumn.isEmpty || chopped.getParent == null + } + + val (columnNames, values) = columns.reverse.unzip + PartitionValues(columnNames, values) + } + + private def parsePartitionColumn( + columnSpec: String, + defaultPartitionName: String): Option[(String, Literal)] = { + val equalSignIndex = columnSpec.indexOf('=') + if (equalSignIndex == -1) { + None + } else { + val columnName = columnSpec.take(equalSignIndex) + assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") + + val rawColumnValue = columnSpec.drop(equalSignIndex + 1) + assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") + + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + Some(columnName -> literal) + } + } + + /** + * Resolves possible type conflicts between partitions by up-casting "lower" types. The up- + * casting order is: + * {{{ + * NullType -> + * IntegerType -> LongType -> + * FloatType -> DoubleType -> DecimalType.Unlimited -> + * StringType + * }}} + */ + private[sql] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { + // Column names of all partitions must match + val distinctPartitionsColNames = values.map(_.columnNames).distinct + assert(distinctPartitionsColNames.size == 1, { + val list = distinctPartitionsColNames.mkString("\t", "\n", "") + s"Conflicting partition column names detected:\n$list" + }) + + // Resolves possible type conflicts for each column + val columnCount = values.head.columnNames.size + val resolvedValues = (0 until columnCount).map { i => + resolveTypeConflicts(values.map(_.literals(i))) + } + + // Fills resolved literals back to each partition + values.zipWithIndex.map { case (d, index) => + d.copy(literals = resolvedValues.map(_(index))) + } + } + + /** + * Converts a string to a `Literal` with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * [[StringType]]. 
+ */ + private[sql] def inferPartitionColumnValue( + raw: String, + defaultPartitionName: String): Literal = { + // First tries integral types + Try(Literal.create(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal.create(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) Literal.create(null, NullType) + else Literal.create(raw, StringType) + } + } + + private val upCastingOrder: Seq[DataType] = + Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType) + + /** + * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" + * types. + */ + private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = { + val desiredType = { + val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) + // Falls back to string if all values of this column are null or empty string + if (topType == NullType) StringType else topType + } + + literals.map { case l @ Literal(_, dataType) => + Literal.create(Cast(l, desiredType).eval(), desiredType) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index dbdb0d39c26a1..3e97506ba2e25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -16,10 +16,21 @@ */ package org.apache.spark.sql.sources -import org.apache.spark.sql.{DataFrame, SQLContext} +import java.util +import java.util.Date + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred._ +import org.apache.hadoop.util.Shell + +import org.apache.spark._ +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -41,3 +52,307 @@ private[sql] case class InsertIntoDataSource( Seq.empty[Row] } } + +private[sql] case class InsertIntoFSBasedRelation( + @transient relation: FSBasedRelation, + @transient query: LogicalPlan, + partitionColumns: Array[String], + mode: SaveMode) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + require( + relation.paths.length == 1, + s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + + val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val outputPath = new Path(relation.paths.head) + + val fs = outputPath.getFileSystem(hadoopConf) + val doInsertion = (mode, fs.exists(outputPath)) match { + case (SaveMode.ErrorIfExists, true) => + sys.error(s"path $outputPath already exists.") + case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists + } + + if (doInsertion) { + val jobConf = new JobConf(hadoopConf) + jobConf.setOutputKeyClass(classOf[Void]) + jobConf.setOutputValueClass(classOf[Row]) + FileOutputFormat.setOutputPath(jobConf, outputPath) + + val df = 
sqlContext.createDataFrame( + DataFrame(sqlContext, query).queryExecution.toRdd, + relation.schema, + needsConversion = false) + + if (partitionColumns.isEmpty) { + insert(new WriterContainer(relation, jobConf), df) + } else { + val writerContainer = new DynamicPartitionWriterContainer( + relation, jobConf, partitionColumns, "__HIVE_DEFAULT_PARTITION__") + insertWithDynamicPartitions(writerContainer, df, partitionColumns) + } + } + + Seq.empty[Row] + } + + private def insert(writerContainer: WriterContainer, df: DataFrame): Unit = { + try { + writerContainer.driverSideSetup() + df.sqlContext.sparkContext.runJob(df.rdd, writeRows _) + writerContainer.commitJob() + } catch { case cause: Throwable => + writerContainer.abortJob() + throw new SparkException("Job aborted.", cause) + } + + def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { + writerContainer.executorSideSetup(taskContext) + + try { + while (iterator.hasNext) { + val row = iterator.next() + writerContainer.outputWriterForRow(row).write(row) + } + writerContainer.commitTask() + } catch { case cause: Throwable => + writerContainer.abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + } + } + + private def insertWithDynamicPartitions( + writerContainer: WriterContainer, + df: DataFrame, + partitionColumns: Array[String]): Unit = { + + val sqlContext = df.sqlContext + + val (partitionRDD, dataRDD) = { + val fieldNames = relation.schema.fieldNames + val (partitionCols, dataCols) = fieldNames.partition(partitionColumns.contains) + val df = sqlContext.createDataFrame( + DataFrame(sqlContext, query).queryExecution.toRdd, + relation.schema, + needsConversion = false) + + assert( + partitionCols.sameElements(partitionColumns), { + val insertionPartitionCols = partitionColumns.mkString(",") + val relationPartitionCols = + relation.partitionSpec.partitionColumns.fieldNames.mkString(",") + s"""Partition columns mismatch. + |Expected: $relationPartitionCols + |Actual: $insertionPartitionCols + """.stripMargin + }) + + val partitionDF = df.select(partitionCols.head, partitionCols.tail: _*) + val dataDF = df.select(dataCols.head, dataCols.tail: _*) + + (partitionDF.rdd, dataDF.rdd) + } + + try { + writerContainer.driverSideSetup() + sqlContext.sparkContext.runJob(partitionRDD.zip(dataRDD), writeRows _) + writerContainer.commitJob() + relation.refreshPartitions() + } catch { case cause: Throwable => + writerContainer.abortJob() + throw new SparkException("Job aborted.", cause) + } + + def writeRows(taskContext: TaskContext, iterator: Iterator[(Row, Row)]): Unit = { + writerContainer.executorSideSetup(taskContext) + + try { + while (iterator.hasNext) { + val (partitionValues, data) = iterator.next() + writerContainer.outputWriterForRow(partitionValues).write(data) + } + + writerContainer.commitTask() + } catch { case cause: Throwable => + writerContainer.abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } + } + } +} + +private[sql] class WriterContainer( + @transient val relation: FSBasedRelation, + @transient jobConf: JobConf) + extends SparkHadoopMapRedUtil + with Logging + with Serializable { + + protected val serializableJobConf = new SerializableWritable(jobConf) + + // This is only used on driver side. + @transient private var jobContext: JobContext = _ + + // This is only used on executor side. + @transient private var taskAttemptContext: TaskAttemptContext = _ + + // The following fields are initialized and used on both driver and executor side. 
+ @transient private var outputCommitter: OutputCommitter = _ + @transient private var jobId: JobID = _ + @transient private var taskId: TaskID = _ + @transient private var taskAttemptId: TaskAttemptID = _ + + protected val outputPath = { + assert( + relation.paths.length == 1, + s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") + relation.paths.head + } + + protected val dataSchema = relation.dataSchema + + protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass + + // All output writers are created on executor side. + @transient protected var outputWriters: mutable.Map[String, OutputWriter] = _ + + def driverSideSetup(): Unit = { + setupIDs(0, 0, 0) + setupJobConf() + jobContext = newJobContext(jobConf, jobId) + outputCommitter = jobConf.getOutputCommitter + outputCommitter.setupJob(jobContext) + } + + def executorSideSetup(taskContext: TaskContext): Unit = { + setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) + setupJobConf() + taskAttemptContext = newTaskAttemptContext(serializableJobConf.value, taskAttemptId) + outputCommitter = serializableJobConf.value.getOutputCommitter + outputCommitter.setupTask(taskAttemptContext) + outputWriters = initWriters() + } + + private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { + this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) + this.taskId = new TaskID(this.jobId, true, splitId) + this.taskAttemptId = new TaskAttemptID(taskId, attemptId) + } + + private def setupJobConf(): Unit = { + serializableJobConf.value.set("mapred.job.id", jobId.toString) + serializableJobConf.value.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + serializableJobConf.value.set("mapred.task.id", taskAttemptId.toString) + serializableJobConf.value.setBoolean("mapred.task.is.map", true) + serializableJobConf.value.setInt("mapred.task.partition", 0) + } + + // Called on executor side when writing rows + def outputWriterForRow(row: Row): OutputWriter = outputWriters.values.head + + protected def initWriters(): mutable.Map[String, OutputWriter] = { + val writer = outputWriterClass.newInstance() + writer.init(outputPath, dataSchema, serializableJobConf.value) + mutable.Map(outputPath -> writer) + } + + def commitTask(): Unit = { + outputWriters.values.foreach(_.close()) + SparkHadoopMapRedUtil.commitTask( + outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) + } + + def abortTask(): Unit = { + outputWriters.values.foreach(_.close()) + outputCommitter.abortTask(taskAttemptContext) + logError(s"Task attempt $taskAttemptId aborted.") + } + + def commitJob(): Unit = { + outputCommitter.commitJob(jobContext) + logInfo(s"Job $jobId committed.") + } + + def abortJob(): Unit = { + outputCommitter.abortJob(jobContext, JobStatus.FAILED) + logError(s"Job $jobId aborted.") + } +} + +private[sql] class DynamicPartitionWriterContainer( + @transient relation: FSBasedRelation, + @transient conf: JobConf, + partitionColumns: Array[String], + defaultPartitionName: String) + extends WriterContainer(relation, conf) { + + override protected def initWriters() = mutable.Map.empty[String, OutputWriter] + + override def outputWriterForRow(row: Row): OutputWriter = { + val partitionPath = partitionColumns.zip(row.toSeq).map { case (col, rawValue) => + val string = if (rawValue == null) null else String.valueOf(rawValue) + val valueString = if (string == null || string.isEmpty) { + defaultPartitionName + } else { + escapePathName(string) + } + 
s"/$col=$valueString" + }.mkString + + outputWriters.getOrElseUpdate(partitionPath, { + val path = new Path(outputPath, partitionPath.stripPrefix(Path.SEPARATOR)) + val writer = outputWriterClass.newInstance() + writer.init(path.toString, dataSchema, serializableJobConf.value) + writer + }) + } + + private def escapePathName(path: String): String = { + val builder = new StringBuilder() + path.foreach { c => + if (DynamicPartitionWriterContainer.needsEscaping(c)) { + builder.append('%') + builder.append(f"${c.asInstanceOf[Int]}%02x") + } else { + builder.append(c) + } + } + + builder.toString() + } +} + +private[sql] object DynamicPartitionWriterContainer { + val charToEscape = { + val bitSet = new util.BitSet(128) + + /** + * ASCII 01-1F are HTTP control characters that need to be escaped. + * \u000A and \u000D are \n and \r, respectively. + */ + val clist = Array( + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^') + + clist.foreach(bitSet.set(_)) + + if (Shell.WINDOWS) { + Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + } + + bitSet + } + + def needsEscaping(c: Char): Boolean = { + c >= 0 && c < charToEscape.size() && charToEscape.get(c); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 06c64f2bdd59e..39615e946a308 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -17,18 +17,17 @@ package org.apache.spark.sql.sources -import scala.language.existentials +import scala.language.{existentials, implicitConversions} import scala.util.matching.Regex -import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{AnalysisException, SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -111,6 +110,7 @@ private[sql] class DDLParser( CreateTableUsingAsSelect(tableName, provider, temp.isDefined, + Array.empty[String], mode, options, queryPlan) @@ -214,6 +214,7 @@ private[sql] object ResolvedDataSource { def apply( sqlContext: SQLContext, userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], provider: String, options: Map[String, String]): ResolvedDataSource = { val clazz: Class[_] = lookupDataSource(provider) @@ -222,6 +223,12 @@ private[sql] object ResolvedDataSource { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: FSBasedRelationProvider => + dataSource.createRelation( + sqlContext, + Some(schema), + Some(partitionColumnsSchema(schema, 
partitionColumns)), + new CaseInsensitiveMap(options)) case dataSource: org.apache.spark.sql.sources.RelationProvider => throw new AnalysisException(s"$className does not allow user-specified schemas.") case _ => @@ -231,20 +238,34 @@ private[sql] object ResolvedDataSource { case None => clazz.newInstance() match { case dataSource: RelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: FSBasedRelationProvider => + dataSource.createRelation(sqlContext, None, None, new CaseInsensitiveMap(options)) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => throw new AnalysisException( s"A schema needs to be specified when using $className.") case _ => - throw new AnalysisException(s"$className is not a RelationProvider.") + throw new AnalysisException( + s"$className is neither a RelationProvider nor a FSBasedRelationProvider.") } } new ResolvedDataSource(clazz, relation) } + private def partitionColumnsSchema( + schema: StructType, + partitionColumns: Array[String]): StructType = { + StructType(partitionColumns.map { col => + schema.find(_.name == col).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema $schema") + } + }) + } + /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */ def apply( sqlContext: SQLContext, provider: String, + partitionColumns: Array[String], mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { @@ -252,6 +273,19 @@ private[sql] object ResolvedDataSource { val relation = clazz.newInstance() match { case dataSource: CreatableRelationProvider => dataSource.createRelation(sqlContext, mode, options, data) + case dataSource: FSBasedRelationProvider => + val r = dataSource.createRelation( + sqlContext, + Some(data.schema), + Some(partitionColumnsSchema(data.schema, partitionColumns)), + options) + sqlContext.executePlan( + InsertIntoFSBasedRelation( + r.asInstanceOf[FSBasedRelation], + data.logicalPlan, + partitionColumns.toArray, + mode)).toRdd + r case _ => sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") } @@ -310,6 +344,7 @@ private[sql] case class CreateTableUsingAsSelect( tableName: String, provider: String, temporary: Boolean, + partitionColumns: Array[String], mode: SaveMode, options: Map[String, String], child: LogicalPlan) extends UnaryNode { @@ -324,8 +359,9 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[Row] = { - val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) + def run(sqlContext: SQLContext): Seq[Row] = { + val resolved = ResolvedDataSource( + sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) Seq.empty @@ -335,13 +371,14 @@ private[sql] case class CreateTempTableUsing( private[sql] case class CreateTempTableUsingAsSelect( tableName: String, provider: String, + partitionColumns: Array[String], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df) + val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) 
sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 094ccefd8c832..d25034d900ece 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql.sources import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** @@ -109,9 +111,9 @@ trait FSBasedRelationProvider { */ def createRelation( sqlContext: SQLContext, - parameters: Map[String, String], - schema: StructType, - partitionColumns: StructType): BaseRelation + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): BaseRelation } @DeveloperApi @@ -259,13 +261,11 @@ abstract class OutputWriter { * @param path The file path to which this [[OutputWriter]] is supposed to write. * @param dataSchema Schema of the rows to be written. Partition columns are not included in the * schema if the corresponding relation is partitioned. - * @param options Data source options inherited from driver side. * @param conf Hadoop configuration inherited from driver side. */ def init( path: String, dataSchema: StructType, - options: java.util.Map[String, String], conf: Configuration): Unit = () /** @@ -294,18 +294,65 @@ abstract class OutputWriter { * * For the write path, it provides the ability to write to both non-partitioned and partitioned * tables. Directory layout of the partitioned tables is compatible with Hive. + * + * @constructor This constructor is for internal uses only. The [[PartitionSpec]] argument is for + * implementing metastore table conversion. + * @param paths Base paths of this relation. For partitioned relations, it should be either root + * directories of all partition directories. + * @param maybePartitionSpec An [[FSBasedRelation]] can be created with an optional + * [[PartitionSpec]], so that partition discovery can be skipped. */ @Experimental -abstract class FSBasedRelation extends BaseRelation { - // Discovers partitioned columns, and merge them with `dataSchema`. All partition columns not - // existed in `dataSchema` should be appended to `dataSchema`. - override val schema: StructType = ??? +abstract class FSBasedRelation private[sql]( + val paths: Array[String], + maybePartitionSpec: Option[PartitionSpec]) + extends BaseRelation { /** - * Base path of this relation. For partitioned relations, `path` should be the root directory of - * all partition directories. + * Constructs an [[FSBasedRelation]]. + * + * @param paths Base paths of this relation. For partitioned relations, it should be either root + * directories of all partition directories. 
*/ - def path: String + def this(paths: Array[String]) = this(paths, None) + + private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + + private var _partitionSpec: PartitionSpec = _ + refreshPartitions() + + private[sql] def partitionSpec: PartitionSpec = _partitionSpec + + private[sql] def refreshPartitions(): Unit = { + _partitionSpec = maybePartitionSpec.getOrElse { + val basePaths = paths.map(new Path(_)) + val leafDirs = basePaths.flatMap { path => + val fs = path.getFileSystem(hadoopConf) + if (fs.exists(path)) { + SparkHadoopUtil.get.listLeafDirStatuses(fs, path) + } else { + Seq.empty[FileStatus] + } + }.map(_.getPath) + + if (leafDirs.nonEmpty) { + PartitioningUtils.parsePartitions(leafDirs, "__HIVE_DEFAULT_PARTITION__") + } else { + PartitionSpec(StructType(Array.empty[StructField]), Array.empty[Partition]) + } + } + } + + /** + * Schema of this relation. It consists of [[dataSchema]] and all partition columns not appeared + * in [[dataSchema]]. + */ + override val schema: StructType = { + val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet + StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column => + dataSchemaColumnNames.contains(column.name.toLowerCase) + }) + } /** * Specifies schema of actual data files. For partitioned relations, if one or more partitioned diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index 6ed68d179edc9..940e5c96e2b0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -107,7 +107,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") - case CreateTableUsingAsSelect(tableName, _, _, SaveMode.Overwrite, _, query) => + case CreateTableUsingAsSelect(tableName, _, _, _, SaveMode.Overwrite, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. 
if (catalog.tableExists(Seq(tableName))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index b7561ce7298cb..2a9876fa899b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.parquet.ParquetRelation2._ +import org.apache.spark.sql.sources.{Partition, PartitionSpec} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLContext} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 54f2f3cdec298..d17391b49031d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.sources import java.io.{IOException, File} -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{SQLContext, AnalysisException} import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.Utils diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index bbf48efb24440..d754c8e3a8aa1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,25 +19,24 @@ package org.apache.spark.sql.hive import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} - import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.metastore.Warehouse +import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.Logging -import org.apache.spark.sql.{SaveMode, AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, Catalog, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.{ParquetRelation2, Partition => ParquetPartition, PartitionSpec} -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} import org.apache.spark.util.Utils /* Implicit conversions */ @@ -98,6 +97,7 @@ 
private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive ResolvedDataSource( hive, userSpecifiedSchema, + Array.empty[String], table.properties("spark.sql.sources.provider"), options) @@ -438,6 +438,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive desc.name, hive.conf.defaultDataSourceName, temporary = false, + Array.empty[String], mode, options = Map.empty[String, String], child diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index be9249a8b1f44..d46a127d47d31 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -221,14 +221,14 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case CreateTableUsing( - tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => + tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => ExecutedCommand( CreateMetastoreDataSource( tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil - case CreateTableUsingAsSelect(tableName, provider, false, mode, opts, query) => + case CreateTableUsingAsSelect(tableName, provider, false, partitionCols, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, query) + CreateMetastoreDataSourceAsSelect(tableName, provider, partitionCols, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index abab1a223a43a..70f61c5656562 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -158,6 +158,7 @@ private[hive] case class CreateMetastoreDataSourceAsSelect( tableName: String, provider: String, + partitionColumns: Array[String], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { @@ -189,12 +190,12 @@ case class CreateMetastoreDataSourceAsSelect( return Seq.empty[Row] case SaveMode.Append => // Check if the specified data source match the data source of the existing table. - val resolved = - ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath) + val resolved = ResolvedDataSource( + sqlContext, Some(query.schema), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { - case l @ LogicalRelation(i: InsertableRelation) => - if (i != createdRelation.relation) { + case l @ LogicalRelation(_: InsertableRelation | _: FSBasedRelation) => + if (l.relation != createdRelation.relation) { val errorDescription = s"Cannot append to table $tableName because the resolved relation does not " + s"match the existing relation of $tableName. " + @@ -234,7 +235,8 @@ case class CreateMetastoreDataSourceAsSelect( } // Create the relation based on the data of df. 
- val resolved = ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df) + val resolved = + ResolvedDataSource(sqlContext, provider, partitionColumns, mode, optionsWithPath, df) if (createMetastoreTable) { // We will use the schema of resolved.relation as the schema of the table (instead of diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 8398da268174d..cbc381cc81b59 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -204,7 +204,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( if (string == null || string.isEmpty) { defaultPartName } else { - FileUtils.escapePathName(string) + FileUtils.escapePathName(string, defaultPartName) } s"/$col=$colString" }.mkString diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala similarity index 60% rename from sql/core/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala index 6c5bb8e425a74..59e01b2690781 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala @@ -19,41 +19,80 @@ package org.apache.spark.sql.sources import java.io.IOException -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable +import com.google.common.base.Objects +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.scalatest.BeforeAndAfter import org.apache.spark.rdd.RDD +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, SQLContext, SaveMode} +import org.apache.spark.sql._ import org.apache.spark.util.Utils -class SimpleFSBasedSource extends RelationProvider { +class SimpleFSBasedSource extends FSBasedRelationProvider { override def createRelation( sqlContext: SQLContext, + schema: Option[StructType], + partitionColumns: Option[StructType], parameters: Map[String, String]): BaseRelation = { - SimpleFSBasedRelation(parameters)(sqlContext) + val path = parameters("path") + + // Uses data sources options to simulate data schema + val dataSchema = DataType.fromJson(parameters("schema")).asInstanceOf[StructType] + + // Uses data sources options to mock partition discovery + val maybePartitionSpec = + parameters.get("partitionColumns").map { json => + new PartitionSpec( + DataType.fromJson(json).asInstanceOf[StructType], + Array.empty[Partition]) + } + + SimpleFSBasedRelation(path, dataSchema, maybePartitionSpec, parameters)(sqlContext) } } -case class SimpleFSBasedRelation - (parameter: Map[String, String]) - (val sqlContext: SQLContext) - extends FSBasedRelation { +class SimpleOutputWriter extends OutputWriter { + override def init(path: String, dataSchema: StructType, conf: Configuration): Unit = { + TestResult.synchronized { + TestResult.writerPaths += path + } + } - class SimpleOutputWriter extends OutputWriter { - override def write(row: Row): Unit = TestResult.writtenRows += row + override def write(row: Row): Unit = { + TestResult.synchronized { + TestResult.writtenRows += row + } } +} - override val path = 
parameter("path") +case class SimpleFSBasedRelation + (path: String, + dataSchema: StructType, + maybePartitionSpec: Option[PartitionSpec], + parameter: Map[String, String]) + (@transient val sqlContext: SQLContext) + extends FSBasedRelation(Array(path), maybePartitionSpec) { + + override def equals(obj: scala.Any): Boolean = obj match { + case that: SimpleFSBasedRelation => + this.path == that.path && + this.dataSchema == that.dataSchema && + this.maybePartitionSpec == that.maybePartitionSpec + case _ => false + } - override def dataSchema: StructType = - DataType.fromJson(parameter("schema")).asInstanceOf[StructType] + override def hashCode(): Int = + Objects.hashCode(path, dataSchema, maybePartitionSpec) override def buildScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String]): RDD[Row] = { + val sqlContext = this.sqlContext val basePath = new Path(path) val fs = basePath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) @@ -61,11 +100,12 @@ case class SimpleFSBasedRelation // shouldn't be removed before scanning the data. assert(inputPaths.map(new Path(_)).forall(fs.exists)) - TestResult.requiredColumns = requiredColumns - TestResult.filters = filters - TestResult.inputPaths = inputPaths - - Option(TestResult.rowsToRead).getOrElse(sqlContext.emptyResult) + TestResult.synchronized { + TestResult.requiredColumns = requiredColumns + TestResult.filters = filters + TestResult.inputPaths = inputPaths + Option(TestResult.rowsToRead).getOrElse(sqlContext.emptyResult) + } } override def outputWriterClass: Class[_ <: OutputWriter] = classOf[SimpleOutputWriter] @@ -76,20 +116,22 @@ object TestResult { var filters: Array[Filter] = _ var inputPaths: Array[String] = _ var rowsToRead: RDD[Row] = _ - var writtenRows: ArrayBuffer[Row] = ArrayBuffer.empty[Row] + var writerPaths: mutable.Set[String] = mutable.Set.empty[String] + var writtenRows: mutable.Set[Row] = mutable.Set.empty[Row] - def reset(): Unit = { + def reset(): Unit = this.synchronized { requiredColumns = null filters = null inputPaths = null rowsToRead = null + writerPaths.clear() writtenRows.clear() } } -class FSBasedRelationSuite extends DataSourceTest { - import caseInsensitiveContext._ - import caseInsensitiveContext.implicits._ +class FSBasedRelationSuite extends QueryTest with BeforeAndAfter { + import TestHive._ + import TestHive.implicits._ var basePath: Path = _ @@ -101,6 +143,8 @@ class FSBasedRelationSuite extends DataSourceTest { StructField("a", IntegerType, nullable = false), StructField("b", StringType, nullable = false))) + val partitionColumns = StructType(StructField("p1", IntegerType, nullable = true) :: Nil) + val testDF = (for { i <- 1 to 3 p <- 1 to 2 @@ -109,10 +153,11 @@ class FSBasedRelationSuite extends DataSourceTest { before { basePath = new Path(Utils.createTempDir().getCanonicalPath) fs = basePath.getFileSystem(sparkContext.hadoopConfiguration) + basePath = fs.makeQualified(basePath) TestResult.reset() } - ignore("load() - partitioned table - partition column not included in data files") { + test("load() - partitioned table - partition column not included in data files") { fs.mkdirs(new Path(basePath, "p1=1/p2=hello")) fs.mkdirs(new Path(basePath, "p1=2/p2=world")) @@ -137,15 +182,15 @@ class FSBasedRelationSuite extends DataSourceTest { assert(df.schema === expectedSchema) - df.select("b").where($"a" > 0 && $"p1" === 1).collect() + df.where($"a" > 0 && $"p1" === 1).select("b").collect() // Check for column pruning, filter push-down, and partition pruning 
assert(TestResult.requiredColumns.toSet === Set("a", "b")) assert(TestResult.filters === Seq(GreaterThan("a", 0))) - assert(TestResult.inputPaths === Seq(new Path(basePath, "p1=1").toString)) + assert(TestResult.inputPaths === Seq(new Path(basePath, "p1=1/p2=hello").toString)) } - ignore("load() - partitioned table - partition column included in data files") { + test("load() - partitioned table - partition column included in data files") { val data = sparkContext.parallelize(Seq.empty[String]) data.saveAsTextFile(new Path(basePath, "p1=1/p2=hello").toString) data.saveAsTextFile(new Path(basePath, "p1=2/p2=world").toString) @@ -180,26 +225,33 @@ class FSBasedRelationSuite extends DataSourceTest { assert(df.schema === expectedSchema) - df.select("b").where($"a" > 0 && $"p1" === 1).collect() + df.where($"a" > 0 && $"p1" === 1).select("b").collect() // Check for column pruning, filter push-down, and partition pruning assert(TestResult.requiredColumns.toSet === Set("a", "b")) assert(TestResult.filters === Seq(GreaterThan("a", 0))) - assert(TestResult.inputPaths === Seq(new Path(basePath, "p1=1").toString)) + assert(TestResult.inputPaths === Seq(new Path(basePath, "p1=1/p2=hello").toString)) } - ignore("save() - partitioned table - Overwrite") { + test("save() - partitioned table - Overwrite") { testDF.save( source = classOf[SimpleFSBasedSource].getCanonicalName, mode = SaveMode.Overwrite, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) + Thread.sleep(500) + // Written rows shouldn't contain dynamic partition column val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i") - assert(TestResult.writtenRows.sameElements(expectedRows)) + + TestResult.synchronized { + assert(TestResult.writerPaths.size === 2) + assert(TestResult.writtenRows === expectedRows.toSet) + } } ignore("save() - partitioned table - Overwrite - select and overwrite the same table") { @@ -216,21 +268,27 @@ class FSBasedRelationSuite extends DataSourceTest { mode = SaveMode.Overwrite, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) // Written rows shouldn't contain dynamic partition column val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i") - assert(TestResult.writtenRows.sameElements(expectedRows)) + + TestResult.synchronized { + assert(TestResult.writerPaths.size === 2) + assert(TestResult.writtenRows === expectedRows.toSet) + } } - ignore("save() - partitioned table - Append") { + test("save() - partitioned table - Append") { testDF.save( source = classOf[SimpleFSBasedSource].getCanonicalName, mode = SaveMode.Overwrite, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) testDF.save( @@ -238,15 +296,20 @@ class FSBasedRelationSuite extends DataSourceTest { mode = SaveMode.Append, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) // Written rows shouldn't contain dynamic partition column val expectedRows = for (i <- 1 to 3; _ <- 1 to 4) yield Row(i, s"val_$i") - assert(TestResult.writtenRows.sameElements(expectedRows)) + + TestResult.synchronized { + 
assert(TestResult.writerPaths.size === 2) + assert(TestResult.writtenRows === expectedRows.toSet) + } } - ignore("save() - partitioned table - ErrorIfExists") { + test("save() - partitioned table - ErrorIfExists") { fs.delete(basePath, true) testDF.save( @@ -254,23 +317,31 @@ class FSBasedRelationSuite extends DataSourceTest { mode = SaveMode.Overwrite, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) - assert(TestResult.writtenRows.sameElements(testDF.collect())) + // Written rows shouldn't contain dynamic partition column + val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i") + + TestResult.synchronized { + assert(TestResult.writerPaths.size === 2) + assert(TestResult.writtenRows === expectedRows.toSet) + } - intercept[IOException] { + intercept[RuntimeException] { testDF.save( source = classOf[SimpleFSBasedSource].getCanonicalName, mode = SaveMode.ErrorIfExists, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) } } - ignore("save() - partitioned table - Ignore") { + test("save() - partitioned table - Ignore") { testDF.save( source = classOf[SimpleFSBasedSource].getCanonicalName, mode = SaveMode.Ignore, @@ -282,19 +353,34 @@ class FSBasedRelationSuite extends DataSourceTest { assert(TestResult.writtenRows.isEmpty) } - ignore("saveAsTable() - partitioned table - Overwrite") { + test("save() - data sources other than FSBasedRelation") { + val cause = intercept[RuntimeException] { + testDF.save( + source = classOf[FilteredScanSource].getCanonicalName, + mode = SaveMode.Overwrite, + options = Map("path" -> basePath.toString), + partitionColumns = Seq("p1")) + } + } + + test("saveAsTable() - partitioned table - Overwrite") { testDF.saveAsTable( tableName = "t", source = classOf[SimpleFSBasedSource].getCanonicalName, mode = SaveMode.Overwrite, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) // Written rows shouldn't contain dynamic partition column val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i") - assert(TestResult.writtenRows.sameElements(expectedRows)) + + TestResult.synchronized { + assert(TestResult.writerPaths.size === 2) + assert(TestResult.writtenRows === expectedRows.toSet) + } assertResult(table("t").schema) { StructType( @@ -312,7 +398,8 @@ class FSBasedRelationSuite extends DataSourceTest { source = classOf[SimpleFSBasedSource].getCanonicalName, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json)) + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json)) df.saveAsTable( tableName = "t", @@ -320,12 +407,17 @@ class FSBasedRelationSuite extends DataSourceTest { mode = SaveMode.Overwrite, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) // Written rows shouldn't contain dynamic partition column val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i") - assert(TestResult.writtenRows.sameElements(expectedRows)) + + TestResult.synchronized { + assert(TestResult.writerPaths.size === 2) + assert(TestResult.writtenRows === expectedRows.toSet) + } 
assertResult(table("t").schema) { StructType( @@ -336,14 +428,15 @@ class FSBasedRelationSuite extends DataSourceTest { sql("DROP TABLE t") } - ignore("saveAsTable() - partitioned table - Append") { + test("saveAsTable() - partitioned table - Append") { testDF.saveAsTable( tableName = "t", source = classOf[SimpleFSBasedSource].getCanonicalName, mode = SaveMode.Overwrite, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) testDF.saveAsTable( @@ -352,12 +445,17 @@ class FSBasedRelationSuite extends DataSourceTest { mode = SaveMode.Append, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) // Written rows shouldn't contain dynamic partition column - val expectedRows = for (i <- 1 to 3; _ <- 1 to 4) yield Row(i, s"val_$i") - assert(TestResult.writtenRows.sameElements(expectedRows)) + val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i") + + TestResult.synchronized { + assert(TestResult.writerPaths.size === 2) + assert(TestResult.writtenRows === expectedRows.toSet) + } assertResult(table("t").schema) { StructType( @@ -368,7 +466,7 @@ class FSBasedRelationSuite extends DataSourceTest { sql("DROP TABLE t") } - ignore("saveAsTable() - partitioned table - ErrorIfExists") { + test("saveAsTable() - partitioned table - ErrorIfExists") { fs.delete(basePath, true) testDF.saveAsTable( @@ -377,10 +475,17 @@ class FSBasedRelationSuite extends DataSourceTest { mode = SaveMode.ErrorIfExists, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) - assert(TestResult.writtenRows.sameElements(testDF.collect())) + // Written rows shouldn't contain dynamic partition column + val expectedRows = for (i <- 1 to 3; _ <- 1 to 2) yield Row(i, s"val_$i") + + TestResult.synchronized { + assert(TestResult.writerPaths.size === 2) + assert(TestResult.writtenRows === expectedRows.toSet) + } assertResult(table("t").schema) { StructType( @@ -388,28 +493,30 @@ class FSBasedRelationSuite extends DataSourceTest { StructField("p1", IntegerType, nullable = true))) } - intercept[IOException] { + intercept[AnalysisException] { testDF.saveAsTable( tableName = "t", source = classOf[SimpleFSBasedSource].getCanonicalName, - mode = SaveMode.Overwrite, + mode = SaveMode.ErrorIfExists, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) } sql("DROP TABLE t") } - ignore("saveAsTable() - partitioned table - Ignore") { + test("saveAsTable() - partitioned table - Ignore") { testDF.saveAsTable( tableName = "t", source = classOf[SimpleFSBasedSource].getCanonicalName, mode = SaveMode.Ignore, options = Map( "path" -> basePath.toString, - "schema" -> dataSchema.json), + "schema" -> dataSchema.json, + "partitionColumns" -> partitionColumns.json), partitionColumns = Seq("p1")) assert(TestResult.writtenRows.isEmpty) @@ -422,4 +529,15 @@ class FSBasedRelationSuite extends DataSourceTest { sql("DROP TABLE t") } + + test("saveAsTable() - data sources other than FSBasedRelation") { + val cause = intercept[RuntimeException] { + testDF.saveAsTable( + tableName = "t", + source = 
classOf[FilteredScanSource].getCanonicalName, + mode = SaveMode.Overwrite, + options = Map("path" -> basePath.toString), + partitionColumns = Seq("p1")) + } + } }
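
For reference, the provider/relation/writer contract exercised by the SimpleFSBasedSource, SimpleFSBasedRelation, and SimpleOutputWriter fixtures above can be restated as a minimal end-to-end sketch. The sketch below is not part of the patch and uses only signatures that appear in it; the CsvSource/CsvRelation/CsvOutputWriter names, the example schema, the /tmp/csv-output path, and the CsvExample wrapper are hypothetical placeholders, and the sketch is declared in the org.apache.spark.sql.sources package (like the test fixture) so that package-visible types such as PartitionSpec and Partition stay in scope.

// Minimal sketch of the FSBasedRelationProvider / FSBasedRelation / OutputWriter contract,
// modelled directly on the SimpleFSBasedSource test fixture above. All names here
// (CsvSource, CsvRelation, CsvOutputWriter, CsvExample, /tmp/csv-output) are hypothetical.
package org.apache.spark.sql.sources

import org.apache.hadoop.conf.Configuration

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}

// Resolved by class name through save()/saveAsTable()/load(); receives the optional
// user-supplied data schema and partition columns plus the data source options map.
class CsvSource extends FSBasedRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      schema: Option[StructType],
      partitionColumns: Option[StructType],
      parameters: Map[String, String]): BaseRelation = {
    val path = parameters("path")
    val dataSchema = schema.getOrElse(
      StructType(Seq(
        StructField("a", IntegerType, nullable = false),
        StructField("b", StringType, nullable = false))))
    // A real source would discover partitions under `path`; an empty spec is used here.
    val maybePartitionSpec =
      partitionColumns.map(cols => new PartitionSpec(cols, Array.empty[Partition]))
    CsvRelation(path, dataSchema, maybePartitionSpec)(sqlContext)
  }
}

case class CsvRelation(
    path: String,
    dataSchema: StructType,
    maybePartitionSpec: Option[PartitionSpec])
    (@transient val sqlContext: SQLContext)
  extends FSBasedRelation(Array(path), maybePartitionSpec) {

  // Column pruning, filter push-down, and partition pruning all surface here: only
  // `requiredColumns` need to be produced, `filters` may be evaluated early, and
  // `inputPaths` lists just the partition directories that survived pruning.
  override def buildScan(
      requiredColumns: Array[String],
      filters: Array[Filter],
      inputPaths: Array[String]): RDD[Row] = {
    // A real implementation would parse the files under `inputPaths`.
    sqlContext.sparkContext.parallelize(Seq.empty[Row])
  }

  override def outputWriterClass: Class[_ <: OutputWriter] = classOf[CsvOutputWriter]
}

// Instantiated on the write path; rows passed to write() no longer contain the dynamic
// partition columns, which are encoded in the output directory names instead.
class CsvOutputWriter extends OutputWriter {
  override def init(path: String, dataSchema: StructType, conf: Configuration): Unit = ()

  override def write(row: Row): Unit = ()
}

object CsvExample {
  // Dynamic partitioning on "p1": rows are routed to <path>/p1=<value>/ directories.
  def writeExample(df: DataFrame): Unit = {
    df.save(
      source = classOf[CsvSource].getCanonicalName,
      mode = SaveMode.Overwrite,
      options = Map("path" -> "/tmp/csv-output"),
      partitionColumns = Seq("p1"))
  }
}

As the assertions in the partitioned-table tests above indicate, buildScan only receives the pruned partition directories in inputPaths, and the rows handed to OutputWriter.write never carry the dynamic partition columns; their values are recovered from the p1=<value> directory names on read.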