Commit: Simplify tests.

viirya committed Aug 4, 2020
1 parent 7175e7c commit 7800474
Showing 7 changed files with 559 additions and 646 deletions.
FileBasedDataSourceTest.scala
@@ -22,8 +22,10 @@ import java.io.File
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
+import org.apache.spark.sql.functions.struct
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.StructType
 
 /**
  * A helper trait that provides convenient facilities for file-based data source testing.
@@ -103,4 +105,40 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils
       df: DataFrame, path: File): Unit = {
     df.write.mode(SaveMode.Overwrite).format(dataSourceName).save(path.getCanonicalPath)
   }
+
+  /**
+   * Takes a single-level `inputDF` dataframe and generates multi-level nested
+   * dataframes as new test data. It tests both non-nested and nested dataframes,
+   * which are written and read back with the specified datasource.
+   */
+  protected def withNestedDataFrame(inputDF: DataFrame): Seq[(DataFrame, String, Any => Any)] = {
+    assert(inputDF.schema.fields.length == 1)
+    assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
+    val df = inputDF.toDF("temp")
+    Seq(
+      (
+        df.withColumnRenamed("temp", "a"),
+        "a", // zero nesting
+        (x: Any) => x),
+      (
+        df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
+        "a.b", // one level nesting
+        (x: Any) => Row(x)),
+      (
+        df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
+        "a.b.c", // two level nesting
+        (x: Any) => Row(Row(x))
+      ),
+      (
+        df.withColumnRenamed("temp", "a.b"),
+        "`a.b`", // zero nesting with column name containing `dots`
+        (x: Any) => x
+      ),
+      (
+        df.withColumn("a.b", struct(df("temp") as "c.d")).drop("temp"),
+        "`a.b`.`c.d`", // one level nesting with column names containing `dots`
+        (x: Any) => Row(x)
+      )
+    )
+  }
 }
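For readers of the diff, here is a minimal usage sketch of the new helper. The input data, the predicate, and the `spark` session (provided by `SQLTestUtils` in a concrete suite) are illustrative assumptions, not part of the commit:

// Sketch only: iterate the five generated schemas and verify a value filter
// on the (possibly nested) column in each one. Assumes a suite mixing in
// FileBasedDataSourceTest, so `spark` and ScalaTest assertions are in scope.
val input = spark.range(1, 4).toDF("value") // single non-struct column, as asserted
withNestedDataFrame(input).foreach { case (df, colName, resultFun) =>
  // colName is already quoted where needed, e.g. "a", "a.b.c", or "`a.b`.`c.d`"
  val result = df.where(s"$colName = 2").collect()
  // resultFun wraps the expected leaf value in the matching Row nesting,
  // e.g. 2 for "a" but Row(Row(2)) for "a.b.c"
  assert(result.map(_.get(0)).toSeq === Seq(resultFun(2L)))
}

Returning `resultFun` alongside each schema lets a single expected leaf value be compared against whatever `Row` nesting that variant produces.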

NestedColumnPredicateTest.scala: This file was deleted.

OrcTest.scala
@@ -78,16 +78,12 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor
       (f: String => Unit): Unit = withDataSourceFile(data)(f)
 
   /**
-   * Writes `df` dataframe to an Orc file and reads it back as a `DataFrame`,
+   * Writes `data` to an Orc file and reads it back as a `DataFrame`,
    * which is then passed to `f`. The Orc file will be deleted after `f` returns.
    */
-  protected def withOrcDataFrame(df: DataFrame, testVectorized: Boolean = true)
-      (f: DataFrame => Unit): Unit = {
-    withTempPath { file =>
-      df.write.format(dataSourceName).save(file.getCanonicalPath)
-      readFile(file.getCanonicalPath, testVectorized)(f)
-    }
-  }
+  protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
+      (data: Seq[T], testVectorized: Boolean = true)
+      (f: DataFrame => Unit): Unit = withDataSourceDataFrame(data, testVectorized)(f)
 
   /**
    * Writes `data` to an Orc file, reads it back as a `DataFrame` and registers it as a
@@ -147,4 +143,26 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor
     FileUtils.copyURLToFile(url, file)
     spark.read.orc(file.getAbsolutePath)
   }
+
+  /**
+   * Takes a sequence of products `data` and generates multi-level nested
+   * dataframes as new test data. It tests both non-nested and nested dataframes,
+   * which are written and read back with the Orc datasource.
+   *
+   * This is different from [[OrcTest.withOrcDataFrame]], which does not
+   * test nested cases.
+   */
+  protected def withNestedOrcDataFrame[T <: Product: ClassTag: TypeTag](data: Seq[T])
+      (runTest: (DataFrame, String, Any => Any) => Unit): Unit =
+    withNestedOrcDataFrame(spark.createDataFrame(data))(runTest)
+
+  protected def withNestedOrcDataFrame(inputDF: DataFrame)
+      (runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
+    withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) =>
+      withTempPath { file =>
+        newDF.write.format(dataSourceName).save(file.getCanonicalPath)
+        readFile(file.getCanonicalPath, true) { df => runTest(df, colName, resultFun) }
+      }
+    }
+  }
 }
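A hedged sketch of how a suite might exercise both reworked helpers. The test name, data, and predicates are illustrative assumptions; only the helper signatures come from this commit:

// Sketch only. withOrcDataFrame now builds the DataFrame from a Seq of
// products instead of taking a prebuilt DataFrame; withNestedOrcDataFrame
// replays the same values across each generated nested schema.
test("filter on flat and nested Orc columns (sketch)") {
  withOrcDataFrame(Seq((1, "a"), (2, "b"))) { df =>
    // Tuples become columns _1, _2 via createDataFrame.
    assert(df.where("_1 = 2").count() === 1)
  }
  withNestedOrcDataFrame((1 to 4).map(Tuple1(_))) { (df, colName, resultFun) =>
    // Compare as a set, since row order after a round trip is not guaranteed.
    val hits = df.where(s"$colName > 2").collect().map(_.get(0)).toSet
    assert(hits === Set(resultFun(3), resultFun(4)))
  }
}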
ParquetFilterSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
-import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation, NestedColumnPredicateTest, PushableColumnAndNestedColumn}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation, PushableColumnAndNestedColumn}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.functions._
@@ -66,7 +66,7 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
  * within the test.
  */
 abstract class ParquetFilterSuite
-  extends QueryTest with ParquetTest with NestedColumnPredicateTest with SharedSparkSession {
+  extends QueryTest with ParquetTest with SharedSparkSession {
 
   protected def createParquetFilters(
       schema: MessageType,
@@ -123,34 +123,7 @@ abstract class ParquetFilterSuite
 
   private def withNestedParquetDataFrame(inputDF: DataFrame)
       (runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
-    assert(inputDF.schema.fields.length == 1)
-    assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
-    val df = inputDF.toDF("temp")
-    Seq(
-      (
-        df.withColumnRenamed("temp", "a"),
-        "a", // zero nesting
-        (x: Any) => x),
-      (
-        df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
-        "a.b", // one level nesting
-        (x: Any) => Row(x)),
-      (
-        df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
-        "a.b.c", // two level nesting
-        (x: Any) => Row(Row(x))
-      ),
-      (
-        df.withColumnRenamed("temp", "a.b"),
-        "`a.b`", // zero nesting with column name containing `dots`
-        (x: Any) => x
-      ),
-      (
-        df.withColumn("a.b", struct(df("temp") as "c.d")).drop("temp"),
-        "`a.b`.`c.d`", // one level nesting with column names containing `dots`
-        (x: Any) => Row(x)
-      )
-    ).foreach { case (newDF, colName, resultFun) =>
+    withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) =>
      withTempPath { file =>
        newDF.write.format(dataSourceName).save(file.getCanonicalPath)
        readParquetFile(file.getCanonicalPath) { df => runTest(df, colName, resultFun) }
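The refactor keeps the helper's call shape, so existing Parquet tests are untouched. A hedged sketch of the pattern; the range data and predicate are illustrative assumptions, and `checkAnswer` comes from `QueryTest`, which this suite extends:

// Sketch only: each generated schema is written to Parquet, read back, and
// probed through the (colName, resultFun) pair from withNestedDataFrame.
withNestedParquetDataFrame(spark.range(1, 5).toDF("v")) { (df, colName, resultFun) =>
  // For the "a.b" variant, resultFun is (x: Any) => Row(x), so the expected
  // top-level rows are Row(Row(2)), Row(Row(3)), Row(Row(4)).
  val expected = (2L to 4L).map(i => Row(resultFun(i)))
  checkAnswer(df.where(s"$colName >= 2"), expected) // order-insensitive check
}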