From 04046e5432acb1132fa567f2230723bc1a92a482 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 4 Dec 2018 00:05:15 +0800 Subject: [PATCH] [SPARK-25498][SQL] InterpretedMutableProjection should handle UnsafeRow ## What changes were proposed in this pull request? Since `AggregationIterator` uses `MutableProjection` for `UnsafeRow`, `InterpretedMutableProjection` needs to handle `UnsafeRow` as buffer internally for fixed-length types only. ## How was this patch tested? Run 'SQLQueryTestSuite' with the interpreted mode. Closes #22512 from maropu/InterpreterTest. Authored-by: Takeshi Yamamuro Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/InternalRow.scala | 22 +++++ .../InterpretedMutableProjection.scala | 23 +++++- .../expressions/ExpressionEvalHelper.scala | 11 +++ .../expressions/MutableProjectionSuite.scala | 81 +++++++++++++++++++ .../expressions/UnsafeRowConverterSuite.scala | 15 +--- .../sql-tests/inputs/change-column.sql | 1 + .../test/resources/sql-tests/inputs/udaf.sql | 3 + .../sql-tests/results/change-column.sql.out | 10 ++- .../resources/sql-tests/results/udaf.sql.out | 18 ++++- .../apache/spark/sql/SQLQueryTestSuite.scala | 27 ++++++- 10 files changed, 192 insertions(+), 19 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index e49c10be6be4e..bdab407688a65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -157,4 +157,26 @@ object InternalRow { getValueNullSafe } } + + /** + * Returns a writer for an `InternalRow` with given data type. + */ + def getWriter(ordinal: Int, dt: DataType): (InternalRow, Any) => Unit = dt match { + case BooleanType => (input, v) => input.setBoolean(ordinal, v.asInstanceOf[Boolean]) + case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte]) + case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short]) + case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int]) + case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long]) + case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float]) + case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double]) + case DecimalType.Fixed(precision, _) => + (input, v) => input.setDecimal(ordinal, v.asInstanceOf[Decimal], precision) + case udt: UserDefinedType[_] => getWriter(ordinal, udt.sqlType) + case NullType => (input, _) => input.setNullAt(ordinal) + case StringType => (input, v) => input.update(ordinal, v.asInstanceOf[UTF8String].copy()) + case _: StructType => (input, v) => input.update(ordinal, v.asInstanceOf[InternalRow].copy()) + case _: ArrayType => (input, v) => input.update(ordinal, v.asInstanceOf[ArrayData].copy()) + case _: MapType => (input, v) => input.update(ordinal, v.asInstanceOf[MapData].copy()) + case _ => (input, v) => input.update(ordinal, v) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 0654108cea281..122a564da61be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -49,10 +49,31 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable def currentValue: InternalRow = mutableRow override def target(row: InternalRow): MutableProjection = { + // If `mutableRow` is `UnsafeRow`, `MutableProjection` accepts fixed-length types only + require(!row.isInstanceOf[UnsafeRow] || + validExprs.forall { case (e, _) => UnsafeRow.isFixedLength(e.dataType) }, + "MutableProjection cannot use UnsafeRow for output data types: " + + validExprs.map(_._1.dataType).filterNot(UnsafeRow.isFixedLength) + .map(_.catalogString).mkString(", ")) mutableRow = row this } + private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) => + val writer = InternalRow.getWriter(i, e.dataType) + if (!e.nullable) { + (v: Any) => writer(mutableRow, v) + } else { + (v: Any) => { + if (v == null) { + mutableRow.setNullAt(i) + } else { + writer(mutableRow, v) + } + } + } + }.toArray + override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < validExprs.length) { @@ -64,7 +85,7 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable i = 0 while (i < validExprs.length) { val (_, ordinal) = validExprs(i) - mutableRow(ordinal) = buffer(ordinal) + fieldWriters(i)(buffer(ordinal)) i += 1 } mutableRow diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index eb33325d0b31a..a7282e1b1cadc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -456,4 +456,15 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa diff < eps * math.min(absX, absY) } } + + def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) + for (fallbackMode <- modes) { + test(s"$name with $fallbackMode") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + f + } + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala new file mode 100644 index 0000000000000..2db1c3b98819c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { + + val fixedLengthTypes = Array[DataType]( + BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, + DateType, TimestampType) + + val variableLengthTypes = Array( + StringType, DecimalType.defaultConcreteType, CalendarIntervalType, BinaryType, + ArrayType(StringType), MapType(IntegerType, StringType), + StructType.fromDDL("a INT, b STRING"), ObjectType(classOf[java.lang.Integer])) + + def createMutableProjection(dataTypes: Array[DataType]): MutableProjection = { + MutableProjection.create(dataTypes.zipWithIndex.map(x => BoundReference(x._2, x._1, true))) + } + + testBothCodegenAndInterpreted("fixed-length types") { + val inputRow = InternalRow.fromSeq(Seq(true, 3.toByte, 15.toShort, -83, 129L, 1.0f, 5.0, 1, 2L)) + val proj = createMutableProjection(fixedLengthTypes) + assert(proj(inputRow) === inputRow) + } + + testBothCodegenAndInterpreted("unsafe buffer") { + val inputRow = InternalRow.fromSeq(Seq(false, 1.toByte, 9.toShort, -18, 53L, 3.2f, 7.8, 4, 9L)) + val numBytes = UnsafeRow.calculateBitSetWidthInBytes(fixedLengthTypes.length) + val unsafeBuffer = UnsafeRow.createFromByteArray(numBytes, fixedLengthTypes.length) + val proj = createMutableProjection(fixedLengthTypes) + val projUnsafeRow = proj.target(unsafeBuffer)(inputRow) + assert(FromUnsafeProjection.apply(fixedLengthTypes)(projUnsafeRow) === inputRow) + } + + testBothCodegenAndInterpreted("variable-length types") { + val proj = createMutableProjection(variableLengthTypes) + val scalaValues = Seq("abc", BigDecimal(10), CalendarInterval.fromString("interval 1 day"), + Array[Byte](1, 2), Array("123", "456"), Map(1 -> "a", 2 -> "b"), Row(1, "a"), + new java.lang.Integer(5)) + val inputRow = InternalRow.fromSeq(scalaValues.zip(variableLengthTypes).map { + case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) + }) + val projRow = proj(inputRow) + variableLengthTypes.zipWithIndex.foreach { case (dataType, index) => + val toScala = CatalystTypeConverters.createToScalaConverter(dataType) + assert(toScala(projRow.get(index, dataType)) === toScala(inputRow.get(index, dataType))) + } + } + + test("unsupported types for unsafe buffer") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString) { + val proj = createMutableProjection(Array(StringType)) + val errMsg = intercept[IllegalArgumentException] { + proj.target(new UnsafeRow(1)) + }.getMessage + assert(errMsg.contains("MutableProjection cannot use UnsafeRow for output data types:")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 5a646d9a850ac..268372b5d0504 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -26,26 +26,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, _} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String -class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestBase + with ExpressionEvalHelper { private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) - private def testBothCodegenAndInterpreted(name: String)(f: => Unit): Unit = { - val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) - for (fallbackMode <- modes) { - test(s"$name with $fallbackMode") { - withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { - f - } - } - } - } - testBothCodegenAndInterpreted("basic conversion with only primitive types") { val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql index 2909024e4c9f7..6f5ac221ce79c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -54,3 +54,4 @@ ALTER TABLE partition_table CHANGE COLUMN c c INT COMMENT 'this is column C'; -- DROP TEST TABLE DROP TABLE test_change; DROP TABLE partition_table; +DROP VIEW global_temp.global_temp_view; diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql index 2183ba23afc38..58613a1325dfa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -11,3 +11,6 @@ SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1; CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'; SELECT default.udaf1(int_col1) as udaf1 from t1; + +DROP FUNCTION myDoubleAvg; +DROP FUNCTION udaf1; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index ff1ecbcc44c23..114617873af47 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 33 +-- Number of queries: 34 -- !query 0 @@ -313,3 +313,11 @@ DROP TABLE partition_table struct<> -- !query 32 output + + +-- !query 33 +DROP VIEW global_temp.global_temp_view +-- !query 33 schema +struct<> +-- !query 33 output + diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index 87824ab81cdf7..f4455bb717578 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 8 -- !query 0 @@ -52,3 +52,19 @@ struct<> -- !query 5 output org.apache.spark.sql.AnalysisException Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7 + + +-- !query 6 +DROP FUNCTION myDoubleAvg +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +DROP FUNCTION udaf1 +-- !query 7 schema +struct<> +-- !query 7 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 6ca3ac596e5f4..fd180ce2380a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -22,11 +22,13 @@ import java.util.{Locale, TimeZone} import scala.util.control.NonFatal +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -140,6 +142,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val input = fileToString(new File(testCase.inputFile)) val (comments, code) = input.split("\n").partition(_.startsWith("--")) + + // Runs all the tests on both codegen-only and interpreter modes + val codegenConfigSets = Array(CODEGEN_ONLY, NO_CODEGEN).map { + case codegenFactoryMode => + Array(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode.toString) + } val configSets = { val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5)) val configs = configLines.map(_.split(",").map { confAndValue => @@ -148,12 +156,25 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { }) // When we are regenerating the golden files we don't need to run all the configs as they // all need to return the same result - if (regenerateGoldenFiles && configs.nonEmpty) { - configs.take(1) + if (regenerateGoldenFiles) { + if (configs.nonEmpty) { + configs.take(1) + } else { + Array.empty[Array[(String, String)]] + } } else { - configs + if (configs.nonEmpty) { + codegenConfigSets.flatMap { codegenConfig => + configs.map { config => + config ++ codegenConfig + } + } + } else { + codegenConfigSets + } } } + // List of SQL queries to run // note: this is not a robust way to split queries using semicolon, but works for now. val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq