diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala
index 38e25131df867..38bb246d02184 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala
@@ -30,6 +30,7 @@ class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   // Single column of images named "image"
   private lazy val imagePath = "../data/mllib/images/partitioned"
+  private lazy val recursiveImagePath = "../data/mllib/images"
 
   test("image datasource count test") {
     val df1 = spark.read.format("image").load(imagePath)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 3c932555179fd..3adec2f790730 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -62,6 +62,10 @@ abstract class PartitioningAwareFileIndex(
     pathGlobFilter.forall(_.accept(file.getPath))
   }
 
+  protected lazy val recursiveFileLookup = {
+    parameters.getOrElse("recursiveFileLookup", "false").toBoolean
+  }
+
   override def listFiles(
       partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {
     def isNonEmptyFile(f: FileStatus): Boolean = {
@@ -70,6 +74,10 @@ abstract class PartitioningAwareFileIndex(
     val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) {
       PartitionDirectory(InternalRow.empty, allFiles().filter(isNonEmptyFile)) :: Nil
     } else {
+      if (recursiveFileLookup) {
+        throw new IllegalArgumentException(
+          "Datasource with partition do not allow recursive file loading.")
+      }
       prunePartitions(partitionFilters, partitionSpec()).map {
         case PartitionPath(values, path) =>
           val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match {
@@ -95,7 +103,7 @@ abstract class PartitioningAwareFileIndex(
   override def sizeInBytes: Long = allFiles().map(_.getLen).sum
 
   def allFiles(): Seq[FileStatus] = {
-    val files = if (partitionSpec().partitionColumns.isEmpty) {
+    val files = if (partitionSpec().partitionColumns.isEmpty && !recursiveFileLookup) {
       // For each of the root input paths, get the list of files inside them
       rootPaths.flatMap { path =>
         // Make the path qualified (consistent with listLeafFiles and bulkListLeafFiles).
@@ -128,23 +136,27 @@ abstract class PartitioningAwareFileIndex(
   }
 
   protected def inferPartitioning(): PartitionSpec = {
-    // We use leaf dirs containing data files to discover the schema.
-    val leafDirs = leafDirToChildrenFiles.filter { case (_, files) =>
-      files.exists(f => isDataPath(f.getPath))
-    }.keys.toSeq
-
-    val caseInsensitiveOptions = CaseInsensitiveMap(parameters)
-    val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
-      .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
-
-    PartitioningUtils.parsePartitions(
-      leafDirs,
-      typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
-      basePaths = basePaths,
-      userSpecifiedSchema = userSpecifiedSchema,
-      caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis,
-      validatePartitionColumns = sparkSession.sqlContext.conf.validatePartitionColumns,
-      timeZoneId = timeZoneId)
+    if (recursiveFileLookup) {
+      PartitionSpec.emptySpec
+    } else {
+      // We use leaf dirs containing data files to discover the schema.
+      val leafDirs = leafDirToChildrenFiles.filter { case (_, files) =>
+        files.exists(f => isDataPath(f.getPath))
+      }.keys.toSeq
+
+      val caseInsensitiveOptions = CaseInsensitiveMap(parameters)
+      val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
+        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
+
+      PartitioningUtils.parsePartitions(
+        leafDirs,
+        typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
+        basePaths = basePaths,
+        userSpecifiedSchema = userSpecifiedSchema,
+        caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis,
+        validatePartitionColumns = sparkSession.sqlContext.conf.validatePartitionColumns,
+        timeZoneId = timeZoneId)
+    }
   }
 
   private def prunePartitions(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index ea4794e7c8090..b2d6f017ee04e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import java.io.{File, FilenameFilter, FileNotFoundException}
+import java.nio.file.{Files, StandardOpenOption}
 import java.util.Locale
 
 import scala.collection.mutable
@@ -572,6 +573,75 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
     }
   }
 
+  test("Option recursiveFileLookup: recursive loading correctly") {
+
+    val expectedFileList = mutable.ListBuffer[String]()
+
+    def createFile(dir: File, fileName: String, format: String): Unit = {
+      val path = new File(dir, s"${fileName}.${format}")
+      Files.write(
+        path.toPath,
+        s"content of ${path.toString}".getBytes,
+        StandardOpenOption.CREATE, StandardOpenOption.WRITE
+      )
+      val fsPath = new Path(path.getAbsoluteFile.toURI).toString
+      expectedFileList.append(fsPath)
+    }
+
+    def createDir(path: File, dirName: String, level: Int): Unit = {
+      val dir = new File(path, s"dir${dirName}-${level}")
+      dir.mkdir()
+      createFile(dir, s"file${level}", "bin")
+      createFile(dir, s"file${level}", "text")
+
+      if (level < 4) {
+        // create sub-dir
+        createDir(dir, "sub0", level + 1)
+        createDir(dir, "sub1", level + 1)
+      }
+    }
+
+    withTempPath { path =>
+      path.mkdir()
+      createDir(path, "root", 0)
+
+      val dataPath = new File(path, "dirroot-0").getAbsolutePath
+      val fileList = spark.read.format("binaryFile")
+        .option("recursiveFileLookup", true)
+        .load(dataPath)
+        .select("path").collect().map(_.getString(0))
+
+      assert(fileList.toSet === expectedFileList.toSet)
+
+      val fileList2 = spark.read.format("binaryFile")
+        .option("recursiveFileLookup", true)
+        .option("pathGlobFilter", "*.bin")
+        .load(dataPath)
+        .select("path").collect().map(_.getString(0))
+
+      assert(fileList2.toSet === expectedFileList.filter(_.endsWith(".bin")).toSet)
+    }
+  }
+
+  test("Option recursiveFileLookup: disable partition inferring") {
+    val dataPath = Thread.currentThread().getContextClassLoader
+      .getResource("test-data/text-partitioned").toString
+
+    val df = spark.read.format("binaryFile")
+      .option("recursiveFileLookup", true)
+      .load(dataPath)
+
+    assert(!df.columns.contains("year"), "Expect partition inferring disabled")
+    val fileList = df.select("path").collect().map(_.getString(0))
+
+    val expectedFileList = Array(
+      dataPath + "/year=2014/data.txt",
+      dataPath + "/year=2015/data.txt"
+    ).map(path => new Path(path).toString)
+
+    assert(fileList.toSet === expectedFileList.toSet)
+  }
+
   test("Return correct results when data columns overlap with partition columns") {
     Seq("parquet", "orc", "json").foreach { format =>
       withTempPath { path =>
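Usage sketch (not part of the patch): the option is set per read and can be combined with the existing pathGlobFilter option, mirroring what the tests above exercise; the directory path and glob pattern below are illustrative.

    // Recursively list every file under the given directory, ignoring any
    // partition-style subdirectory layout (no partition columns are inferred),
    // and restrict the listing to *.bin files via pathGlobFilter.
    val df = spark.read.format("binaryFile")
      .option("recursiveFileLookup", true)
      .option("pathGlobFilter", "*.bin")
      .load("/tmp/nested-data")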