From dfcec3c21bff4d893de249cf867b2dd23f49eb51 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 13 Jul 2020 18:26:51 +0800 Subject: [PATCH 01/42] [SPARK-32106][SQL]Implement SparkScriptTransformationExec in sql/core --- .../spark/sql/execution/SparkPlanner.scala | 1 + .../SparkScriptTransformationExec.scala | 187 ++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 15 ++ ...ala => HiveScriptTransformationExec.scala} | 2 +- .../BaseScriptTransformationSuite.scala | 200 ++++++++++++++++++ .../HiveScriptTransformationSuite.scala | 167 ++------------- .../SparkScriptTransformationSuite.scala | 37 ++++ 7 files changed, 458 insertions(+), 151 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala rename sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/{ScriptTransformationExec.scala => HiveScriptTransformationExec.scala} (99%) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 895eeedd86b8b..b96a861196897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -46,6 +46,7 @@ class SparkPlanner( Window :: JoinSelection :: InMemoryScans :: + SparkScripts:: BasicOperators :: Nil) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala new file mode 100644 index 0000000000000..c6bbbd140c4ea --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io._ +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.{CircularBuffer, RedirectThread} + +/** + * Transforms the input by forking and running the specified script. 
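+ *
+ * For example (an illustrative query only, not part of this patch; `src` is a
+ * hypothetical table), a TRANSFORM with no SerDe clause is planned into this node:
+ * {{{
+ *   SELECT TRANSFORM(key, value) USING 'cat' AS (key, value) FROM src
+ * }}}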
+ * + * @param input the set of expression that should be passed to the script. + * @param script the command that should be executed. + * @param output the attributes that are produced by the script. + */ +case class SparkScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: SparkScriptIOSchema) + extends BaseScriptTransformationExec { + + override def processIterator(inputIterator: Iterator[InternalRow], hadoopConf: Configuration) + : Iterator[InternalRow] = { + val cmd = List("/bin/bash", "-c", script) + val builder = new ProcessBuilder(cmd.asJava) + + val proc = builder.start() + val inputStream = proc.getInputStream + val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream + + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + "Thread-ScriptTransformation-STDERR-Consumer").start() + + val outputProjection = new InterpretedProjection(input, child.output) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. + val writerThread = new ScriptTransformationWriterThread( + inputIterator.map(outputProjection), + input.map(_.dataType), + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get(), + hadoopConf + ) + + val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] { + var curLine: String = null + val mutableRow = new SpecificInternalRow(output.map(_.dataType)) + + override def hasNext: Boolean = { + try { + if (curLine == null) { + curLine = reader.readLine() + if (curLine == null) { + checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) + return false + } + } + true + } catch { + case NonFatal(e) => + // If this exception is due to abrupt / unclean termination of `proc`, + // then detect it and propagate a better exception message for end users + checkFailureAndPropagate(writerThread, e, proc, stderrBuffer) + + throw e + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException + } + val prevLine = curLine + curLine = reader.readLine() + if (!ioschema.schemaLess) { + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .map(CatalystTypeConverters.convertToCatalyst)) + } else { + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .map(CatalystTypeConverters.convertToCatalyst)) + } + } + } + + writerThread.start() + + outputIterator + } +} + +private class ScriptTransformationWriterThread( + iter: Iterator[InternalRow], + inputSchema: Seq[DataType], + ioSchema: SparkScriptIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext, + conf: Configuration) + extends BaseScriptTransformationWriterThread( + iter, + inputSchema, + ioSchema, + outputStream, + proc, + stderrBuffer, + taskContext, + conf) { + + setDaemon(true) + + override def processRows(): Unit = { + processRowsWithoutSerde() + } +} + 
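// Editor's sketch (illustrative, not part of the patch): how one line of script output is
// turned into a row in the no-SerDe path above. The "\t" below is the usual
// TOK_TABLEROWFORMATFIELD default, and the values are hypothetical.
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

val line = "1\thello"                 // one line read from the script's stdout
val fields = line.split("\t")         // -> Array("1", "hello")
val row = new GenericInternalRow(fields.map(CatalystTypeConverters.convertToCatalyst))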
+object SparkScriptIOSchema { + def apply(input: ScriptInputOutputSchema): SparkScriptIOSchema = { + SparkScriptIOSchema( + input.inputRowFormat, + input.outputRowFormat, + input.inputSerdeClass, + input.outputSerdeClass, + input.inputSerdeProps, + input.outputSerdeProps, + input.recordReaderClass, + input.recordWriterClass, + input.schemaLess) + } +} + +/** + * The wrapper class of Spark script transformation input and output schema properties + */ +case class SparkScriptIOSchema ( + inputRowFormat: Seq[(String, String)], + outputRowFormat: Seq[(String, String)], + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], + inputSerdeProps: Seq[(String, String)], + outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], + schemaLess: Boolean) extends BaseScriptTransformIOSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 583e5a2c5c57e..d5366de2ea704 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -532,6 +532,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object SparkScripts extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.ScriptTransformation(input, script, output, child, ioschema) + if ioschema.inputSerdeClass.isEmpty && ioschema.outputSerdeClass.isEmpty => + SparkScriptTransformationExec( + input, + script, + output, + planLater(child), + SparkScriptIOSchema(ioschema) + ) :: Nil + case _ => Nil + } + } + /** * This strategy is just for explaining `Dataset/DataFrame` created by `spark.readStream`. * It won't affect the execution, because `StreamingRelation` will be replaced with diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala similarity index 99% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 96fe646d39fde..098ffd3b7d75a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -275,7 +275,7 @@ object HiveScriptIOSchema { } /** - * The wrapper class of Hive input and output schema properties + * The wrapper class of Hive script transformation input and output schema properties */ case class HiveScriptIOSchema ( inputRowFormat: Seq[(String, String)], diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala new file mode 100644 index 0000000000000..4869872273832 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.sql.Timestamp +import java.util.Locale + +import org.scalatest.Assertions._ +import org.scalatest.BeforeAndAfterEach +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.{SparkException, TaskContext, TestUtils} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.StringType + +abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils + with TestHiveSingleton with BeforeAndAfterEach { + + def scriptType: String + + import spark.implicits._ + + var noSerdeIOSchema: BaseScriptTransformIOSchema = _ + + private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _ + + protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler + + protected override def beforeAll(): Unit = { + super.beforeAll() + defaultUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) + } + + protected override def afterAll(): Unit = { + super.afterAll() + Thread.setDefaultUncaughtExceptionHandler(defaultUncaughtExceptionHandler) + } + + override protected def afterEach(): Unit = { + super.afterEach() + uncaughtExceptionHandler.cleanStatus() + } + + def createScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: BaseScriptTransformIOSchema): BaseScriptTransformationExec = { + scriptType.toUpperCase(Locale.ROOT) match { + case "SPARK" => new SparkScriptTransformationExec( + input = input, + script = script, + output = output, + child = child, + ioschema = ioschema.asInstanceOf[SparkScriptIOSchema] + ) + case "HIVE" => new HiveScriptTransformationExec( + input = input, + script = script, + output = output, + child = child, + ioschema = ioschema.asInstanceOf[HiveScriptIOSchema] + ) + case _ => throw new TestFailedException( + "Test class implement from BaseScriptTransformationSuite" + + " should override method `scriptType` to Spark or Hive", 0) + } + } + + test("cat without SerDe") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = noSerdeIOSchema + ), + rowsDf.collect()) + 
assert(uncaughtExceptionHandler.exception.isEmpty) + } + + test("script transformation should not swallow errors from upstream operators (no serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = noSerdeIOSchema + ), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + // Before SPARK-25158, uncaughtExceptionHandler will catch IllegalArgumentException + assert(uncaughtExceptionHandler.exception.isEmpty) + } + + test("SPARK-25990: TRANSFORM should handle different data types correctly") { + assume(TestUtils.testCommandAvailable("python")) + val scriptFilePath = getTestResourcePath("test_script.py") + + withTempView("v") { + val df = Seq( + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) + ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) + df.createTempView("v") + + val query = sql( + s""" + |SELECT + |TRANSFORM(a, b, c, d, e) + |USING 'python $scriptFilePath' AS (a, b, c, d, e) + |FROM v + """.stripMargin) + + // In Hive 1.2, the string representation of a decimal omits trailing zeroes. + // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary. + val decimalToString: Column => Column = if (HiveUtils.isHive23) { + c => c.cast("string") + } else { + c => c.cast("decimal(1, 0)").cast("string") + } + checkAnswer(query, identity, df.select( + 'a.cast("string"), + 'b.cast("string"), + 'c.cast("string"), + decimalToString('d), + 'e.cast("string")).collect()) + } + } + + test("SPARK-30973: TRANSFORM should wait for the termination of the script (no serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[SparkException] { + val plan = + createScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "some_non_existent_command", + output = Seq(AttributeReference("a", StringType)()), + child = rowsDf.queryExecution.sparkPlan, + ioschema = noSerdeIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + } + assert(e.getMessage.contains("Subprocess exited with status")) + assert(uncaughtExceptionHandler.exception.isEmpty) + } +} + +case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map { x => + assert(TaskContext.get() != null) // Make sure that TaskContext is defined. + Thread.sleep(1000) // This sleep gives the external process time to start. 
+ throw new IllegalArgumentException("intentional exception") + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 35252fc47f49f..3caeca4a0eb30 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -17,30 +17,20 @@ package org.apache.spark.sql.hive.execution -import java.sql.Timestamp - import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.scalatest.Assertions._ -import org.scalatest.BeforeAndAfterEach import org.scalatest.exceptions.TestFailedException -import org.apache.spark.{SparkException, TaskContext, TestUtils} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} -import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.{SparkException, TestUtils} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.{BaseScriptTransformIOSchema, SparkPlan, SparkPlanTest} import org.apache.spark.sql.types.StringType -class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton - with BeforeAndAfterEach { +class HiveScriptTransformationSuite extends BaseScriptTransformationSuite { + override def scriptType: String = "HIVE" + import spark.implicits._ - private val noSerdeIOSchema = HiveScriptIOSchema( + noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, outputRowFormat = Seq.empty, inputSerdeClass = None, @@ -52,46 +42,11 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with schemaLess = false ) - private val serdeIOSchema = noSerdeIOSchema.copy( - inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), - outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) - ) - - private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _ - - private val uncaughtExceptionHandler = new TestUncaughtExceptionHandler - - protected override def beforeAll(): Unit = { - super.beforeAll() - defaultUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler - Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) - } - - protected override def afterAll(): Unit = { - super.afterAll() - Thread.setDefaultUncaughtExceptionHandler(defaultUncaughtExceptionHandler) - } - - override protected def afterEach(): Unit = { - super.afterEach() - uncaughtExceptionHandler.cleanStatus() - } - - test("cat without SerDe") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") - checkAnswer( - rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( - input = Seq(rowsDf.col("a").expr), - script = "cat", - output = Seq(AttributeReference("a", StringType)()), - child = child, - ioschema = 
noSerdeIOSchema - ), - rowsDf.collect()) - assert(uncaughtExceptionHandler.exception.isEmpty) + private val serdeIOSchema: BaseScriptTransformIOSchema = { + noSerdeIOSchema.asInstanceOf[HiveScriptIOSchema].copy( + inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), + outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) + ) } test("cat with LazySimpleSerDe") { @@ -100,7 +55,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") checkAnswer( rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( + (child: SparkPlan) => createScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -111,27 +66,6 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with assert(uncaughtExceptionHandler.exception.isEmpty) } - test("script transformation should not swallow errors from upstream operators (no serde)") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") - val e = intercept[TestFailedException] { - checkAnswer( - rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( - input = Seq(rowsDf.col("a").expr), - script = "cat", - output = Seq(AttributeReference("a", StringType)()), - child = ExceptionInjectingOperator(child), - ioschema = noSerdeIOSchema - ), - rowsDf.collect()) - } - assert(e.getMessage().contains("intentional exception")) - // Before SPARK-25158, uncaughtExceptionHandler will catch IllegalArgumentException - assert(uncaughtExceptionHandler.exception.isEmpty) - } - test("script transformation should not swallow errors from upstream operators (with serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) @@ -139,7 +73,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with val e = intercept[TestFailedException] { checkAnswer( rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( + (child: SparkPlan) => createScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), @@ -160,7 +94,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with val e = intercept[SparkException] { val plan = - new HiveScriptTransformationExec( + createScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), @@ -181,7 +115,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with checkAnswer( rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( + (child: SparkPlan) => createScriptTransformationExec( input = Seq(rowsDf.col("name").expr), script = "cat", output = Seq(AttributeReference("name", StringType)()), @@ -192,67 +126,13 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-25990: TRANSFORM should handle different data types correctly") { - assume(TestUtils.testCommandAvailable("python")) - val scriptFilePath = getTestResourcePath("test_script.py") - - withTempView("v") { - val df = Seq( - (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), - (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), - (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) - ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) - 
df.createTempView("v") - - val query = sql( - s""" - |SELECT - |TRANSFORM(a, b, c, d, e) - |USING 'python $scriptFilePath' AS (a, b, c, d, e) - |FROM v - """.stripMargin) - - // In Hive 1.2, the string representation of a decimal omits trailing zeroes. - // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary. - val decimalToString: Column => Column = if (HiveUtils.isHive23) { - c => c.cast("string") - } else { - c => c.cast("decimal(1, 0)").cast("string") - } - checkAnswer(query, identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - decimalToString('d), - 'e.cast("string")).collect()) - } - } - - test("SPARK-30973: TRANSFORM should wait for the termination of the script (no serde)") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") - val e = intercept[SparkException] { - val plan = - new HiveScriptTransformationExec( - input = Seq(rowsDf.col("a").expr), - script = "some_non_existent_command", - output = Seq(AttributeReference("a", StringType)()), - child = rowsDf.queryExecution.sparkPlan, - ioschema = noSerdeIOSchema) - SparkPlanTest.executePlan(plan, hiveContext) - } - assert(e.getMessage.contains("Subprocess exited with status")) - assert(uncaughtExceptionHandler.exception.isEmpty) - } - test("SPARK-30973: TRANSFORM should wait for the termination of the script (with serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") val e = intercept[SparkException] { val plan = - new HiveScriptTransformationExec( + createScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), @@ -265,16 +145,3 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with } } -private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = { - child.execute().map { x => - assert(TaskContext.get() != null) // Make sure that TaskContext is defined. - Thread.sleep(1000) // This sleep gives the external process time to start. - throw new IllegalArgumentException("intentional exception") - } - } - - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = child.outputPartitioning -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala new file mode 100644 index 0000000000000..372e2a8054cda --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.execution.SparkScriptIOSchema + +class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { + + override def scriptType: String = "SPARK" + + noSerdeIOSchema = SparkScriptIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, + schemaLess = false + ) +} From e53744b8863d822e0fb6bafe6aa803bb04f2c5cd Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 13 Jul 2020 21:51:37 +0800 Subject: [PATCH 02/42] save --- .../BaseScriptTransformationExec.scala | 26 ++++++++++++------- .../SparkScriptTransformationExec.scala | 14 +++------- .../HiveScriptTransformationExec.scala | 14 +++------- 3 files changed, 23 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 22bf6df58b040..aa54d93d94b7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -89,15 +89,23 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } -abstract class BaseScriptTransformationWriterThread( - iter: Iterator[InternalRow], - inputSchema: Seq[DataType], - ioSchema: BaseScriptTransformIOSchema, - outputStream: OutputStream, - proc: Process, - stderrBuffer: CircularBuffer, - taskContext: TaskContext, - conf: Configuration) extends Thread with Logging { +abstract class BaseScriptTransformationWriterThread extends Thread with Logging { + + def iter: Iterator[InternalRow] + + def inputSchema: Seq[DataType] + + def ioSchema: BaseScriptTransformIOSchema + + def outputStream: OutputStream + + def proc: Process + + def stderrBuffer: CircularBuffer + + def taskContext: TaskContext + + def conf: Configuration setDaemon(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala index c6bbbd140c4ea..a43ae4401da3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -71,7 +71,7 @@ case class SparkScriptTransformationExec( // This new thread will consume the ScriptTransformation's input rows and write them to the // external process. That process's output will be read by this current thread. 
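// Editor's sketch (illustrative, not part of the patch): the producer/consumer layout the
// comment above describes, reduced to a self-contained example. A daemon thread feeds the
// child process's stdin while the current thread reads its stdout; the names and the 'cat'
// command are hypothetical.
import java.io.{BufferedReader, InputStreamReader}
import java.nio.charset.StandardCharsets
import scala.collection.JavaConverters._

object WriterThreadSketch {
  def main(args: Array[String]): Unit = {
    val proc = new ProcessBuilder(List("/bin/bash", "-c", "cat").asJava).start()
    val writer = new Thread("sketch-writer") {
      setDaemon(true)
      override def run(): Unit = {
        val out = proc.getOutputStream
        Seq("a", "b", "c").foreach(r => out.write((r + "\n").getBytes(StandardCharsets.UTF_8)))
        out.close() // close stdin so 'cat' sees EOF and terminates
      }
    }
    writer.start()
    val reader = new BufferedReader(
      new InputStreamReader(proc.getInputStream, StandardCharsets.UTF_8))
    Iterator.continually(reader.readLine()).takeWhile(_ != null).foreach(println)
  }
}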
- val writerThread = new ScriptTransformationWriterThread( + val writerThread = ScriptTransformationWriterThread( inputIterator.map(outputProjection), input.map(_.dataType), ioschema, @@ -131,7 +131,7 @@ case class SparkScriptTransformationExec( } } -private class ScriptTransformationWriterThread( +case class ScriptTransformationWriterThread( iter: Iterator[InternalRow], inputSchema: Seq[DataType], ioSchema: SparkScriptIOSchema, @@ -140,15 +140,7 @@ private class ScriptTransformationWriterThread( stderrBuffer: CircularBuffer, taskContext: TaskContext, conf: Configuration) - extends BaseScriptTransformationWriterThread( - iter, - inputSchema, - ioSchema, - outputStream, - proc, - stderrBuffer, - taskContext, - conf) { + extends BaseScriptTransformationWriterThread { setDaemon(true) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 098ffd3b7d75a..602f2cf4a6527 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -86,7 +86,7 @@ case class HiveScriptTransformationExec( // This new thread will consume the ScriptTransformation's input rows and write them to the // external process. That process's output will be read by this current thread. - val writerThread = new HiveScriptTransformationWriterThread( + val writerThread = HiveScriptTransformationWriterThread( inputIterator.map(outputProjection), input.map(_.dataType), inputSerde, @@ -208,7 +208,7 @@ case class HiveScriptTransformationExec( } } -private class HiveScriptTransformationWriterThread( +case class HiveScriptTransformationWriterThread( iter: Iterator[InternalRow], inputSchema: Seq[DataType], @Nullable inputSerde: AbstractSerDe, @@ -219,15 +219,7 @@ private class HiveScriptTransformationWriterThread( stderrBuffer: CircularBuffer, taskContext: TaskContext, conf: Configuration) - extends BaseScriptTransformationWriterThread( - iter, - inputSchema, - ioSchema, - outputStream, - proc, - stderrBuffer, - taskContext, - conf) with HiveInspectors { + extends BaseScriptTransformationWriterThread with HiveInspectors { override def processRows(): Unit = { val dataOutputStream = new DataOutputStream(outputStream) From a693722ce9509da3118118a8e773f01e89a79950 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 13 Jul 2020 22:41:03 +0800 Subject: [PATCH 03/42] save --- .../BaseScriptTransformationExec.scala | 18 ++++++++++++++++-- .../spark/sql/execution/SparkSqlParser.scala | 6 +++++- .../apache/spark/sql/hive/HiveStrategies.scala | 3 ++- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index aa54d93d94b7d..8217fa16148bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.io.OutputStream import java.nio.charset.StandardCharsets +import java.time.ZoneId import java.util.concurrent.TimeUnit import scala.util.control.NonFatal @@ -31,8 +32,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, DateType, TimestampType} import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { @@ -127,7 +129,19 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging var i = 1 while (i < len) { sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - sb.append(row.get(i, inputSchema(i))) + val columnType = inputSchema(i) + val fieldValue = row.get(i, columnType) + val fieldStringValue = columnType match { + case _: DateType => + val dateFormatter = DateFormatter(ZoneId.systemDefault()) + dateFormatter.format(fieldValue.asInstanceOf[Int]) + case _: TimestampType => + TimestampFormatter.getFractionFormatter(ZoneId.systemDefault()) + .format(fieldValue.asInstanceOf[Long]) + case _ => + fieldValue.toString + } + sb.append(fieldStringValue) i += 1 } sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3a2c673229c20..492750c60c7c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType /** @@ -713,13 +714,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } (Seq.empty, Option(name), props.toSeq, recordHandler) - case null => + case null if conf.getConf(CATALOG_IMPLEMENTATION).equals("hive") => // Use default (serde) format. 
val name = conf.getConfString("hive.script.serde", "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") val props = Seq("field.delim" -> "\t") val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) (Nil, Option(name), props, recordHandler) + + case null => + (Nil, None, Seq.empty, None) } val (inFormat, inSerdeClass, inSerdeProps, reader) = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index dae68df08f32e..6bac2a2203713 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -243,7 +243,8 @@ private[hive] trait HiveStrategies { object HiveScripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ScriptTransformation(input, script, output, child, ioschema) => + case ScriptTransformation(input, script, output, child, ioschema) + if ioschema.inputSerdeClass.nonEmpty || ioschema.outputSerdeClass.nonEmpty => val hiveIoSchema = HiveScriptIOSchema(ioschema) HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil case _ => Nil From 5bfa669265fac79c4387f6078ee1c62926c3a600 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 14 Jul 2020 18:31:03 +0800 Subject: [PATCH 04/42] follow comment --- .../BaseScriptTransformationExec.scala | 18 ++---------------- .../spark/sql/execution/SparkSqlParser.scala | 8 -------- .../spark/sql/execution/SparkStrategies.scala | 6 ++++-- 3 files changed, 6 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 8217fa16148bc..aa54d93d94b7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import java.io.OutputStream import java.nio.charset.StandardCharsets -import java.time.ZoneId import java.util.concurrent.TimeUnit import scala.util.control.NonFatal @@ -32,9 +31,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, DateType, TimestampType} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { @@ -129,19 +127,7 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging var i = 1 while (i < len) { sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - val columnType = inputSchema(i) - val fieldValue = row.get(i, columnType) - val fieldStringValue = columnType match { - case _: DateType => - val dateFormatter = DateFormatter(ZoneId.systemDefault()) - dateFormatter.format(fieldValue.asInstanceOf[Int]) - case _: TimestampType => - TimestampFormatter.getFractionFormatter(ZoneId.systemDefault()) - .format(fieldValue.asInstanceOf[Long]) - case _ => - fieldValue.toString - } 
- sb.append(fieldStringValue) + sb.append(row.get(i, inputSchema(i))) i += 1 } sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 492750c60c7c0..3ef67994fb7c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -714,14 +714,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } (Seq.empty, Option(name), props.toSeq, recordHandler) - case null if conf.getConf(CATALOG_IMPLEMENTATION).equals("hive") => - // Use default (serde) format. - val name = conf.getConfString("hive.script.serde", - "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") - val props = Seq("field.delim" -> "\t") - val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) - (Nil, Option(name), props, recordHandler) - case null => (Nil, None, Seq.empty, None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d5366de2ea704..e595927ed0c55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.time.ZoneId + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow @@ -38,7 +40,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} /** * Converts a logical plan into zero or more SparkPlans. 
This API is exposed for experimenting @@ -537,7 +539,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.ScriptTransformation(input, script, output, child, ioschema) if ioschema.inputSerdeClass.isEmpty && ioschema.outputSerdeClass.isEmpty => SparkScriptTransformationExec( - input, + input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)), script, output, planLater(child), From ec754e270c29e992e7c8094a786fb88df1e5d850 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 14 Jul 2020 22:21:05 +0800 Subject: [PATCH 05/42] fix input and out put format --- .../BaseScriptTransformationExec.scala | 40 ++++++++++++++++++- .../SparkScriptTransformationExec.scala | 16 ++++++-- .../spark/sql/execution/SparkSqlParser.scala | 3 +- .../spark/sql/execution/SparkStrategies.scala | 6 +-- .../HiveScriptTransformationExec.scala | 26 +++++++++--- 5 files changed, 75 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index aa54d93d94b7d..035b1a276134f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -27,12 +27,15 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { @@ -87,6 +90,41 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } } + + def wrapper(data: String, dt: DataType): Any = { + dt match { + case StringType => data + case ByteType => JavaUtils.stringToBytes(data) + case IntegerType => data.toInt + case ShortType => data.toShort + case LongType => data.toLong + case FloatType => data.toFloat + case DoubleType => data.toDouble + case dt: DecimalType => BigDecimal(data) + case DateType if conf.datetimeJava8ApiEnabled => + DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull + case DateType => + DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull + case TimestampType if conf.datetimeJava8ApiEnabled => + DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull + case TimestampType => + DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull + case CalendarIntervalType => IntervalUtils.stringToInterval(UTF8String.fromString(data)) + case dataType: DataType => data + } + } } 
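// Editor's sketch (illustrative, not part of the patch): what wrapper() above does with the
// delimited strings a script writes back, shown for two simple types. The input line and the
// attribute types are hypothetical; the real code zips each field with the matching output
// attribute before converting.
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType}

val fields = "1\t2.5".split("\t")        // one script output line, default "\t" field delimiter
val types: Seq[DataType] = Seq(IntegerType, DoubleType)
val typed: Array[Any] = fields.zip(types).map {
  case (s, IntegerType) => s.toInt       // mirrors the IntegerType branch of wrapper()
  case (s, DoubleType)  => s.toDouble    // mirrors the DoubleType branch
  case (s, _)           => s             // fallback: keep the raw string
}
// -> Array(1, 2.5)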
abstract class BaseScriptTransformationWriterThread extends Thread with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala index a43ae4401da3e..a44c13fd3747d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -29,7 +29,7 @@ import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ import org.apache.spark.util.{CircularBuffer, RedirectThread} /** @@ -67,7 +67,9 @@ case class SparkScriptTransformationExec( stderrBuffer, "Thread-ScriptTransformation-STDERR-Consumer").start() - val outputProjection = new InterpretedProjection(input, child.output) + val finalInput = input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + + val outputProjection = new InterpretedProjection(finalInput, child.output) // This new thread will consume the ScriptTransformation's input rows and write them to the // external process. That process's output will be read by this current thread. @@ -116,11 +118,17 @@ case class SparkScriptTransformationExec( if (!ioschema.schemaLess) { new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .map(CatalystTypeConverters.convertToCatalyst)) + .zip(output) + .map { case (data, dataType) => + CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) + }) } else { new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .map(CatalystTypeConverters.convertToCatalyst)) + .zip(output) + .map { case (data, dataType) => + CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) + }) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3ef67994fb7c9..5724610744aba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -713,7 +713,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { None } (Seq.empty, Option(name), props.toSeq, recordHandler) - + // SPARK-32106: When there is no definition about format, we return empty result + // then we finally execute with SparkScriptTransformationExec case null => (Nil, None, Seq.empty, None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e595927ed0c55..d5366de2ea704 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import java.time.ZoneId - import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow @@ -40,7 +38,7 @@ import org.apache.spark.sql.execution.streaming._ import 
org.apache.spark.sql.execution.streaming.sources.MemoryPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.StructType /** * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting @@ -539,7 +537,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.ScriptTransformation(input, script, output, child, ioschema) if ioschema.inputSerdeClass.isEmpty && ioschema.outputSerdeClass.isEmpty => SparkScriptTransformationExec( - input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)), + input, script, output, planLater(child), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 602f2cf4a6527..ce04910002561 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} /** @@ -78,17 +78,25 @@ case class HiveScriptTransformationExec( stderrBuffer, "Thread-ScriptTransformation-STDERR-Consumer").start() - val outputProjection = new InterpretedProjection(input, child.output) - // This nullability is a performance optimization in order to avoid an Option.foreach() call // inside of a loop @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null + // We will use StringBuffer to pass data, in this case, we should cast data as string too. + val finalInput = if (inputSerde == null) { + input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + } else { + input + } + + val outputProjection = new InterpretedProjection(finalInput, child.output) + // This new thread will consume the ScriptTransformation's input rows and write them to the // external process. That process's output will be read by this current thread. 
val writerThread = HiveScriptTransformationWriterThread( inputIterator.map(outputProjection), - input.map(_.dataType), + finalInput.map(_.dataType), inputSerde, inputSoi, ioschema, @@ -178,11 +186,17 @@ case class HiveScriptTransformationExec( if (!ioschema.schemaLess) { new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .map(CatalystTypeConverters.convertToCatalyst)) + .zip(output) + .map { case (data, dataType) => + CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) + }) } else { new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .map(CatalystTypeConverters.convertToCatalyst)) + .zip(output) + .map { case (data, dataType) => + CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) + }) } } else { val raw = outputSerde.deserialize(scriptOutputWritable) From a2b12a108dc6b14fe576d4f87b1304e6ffa37804 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 15 Jul 2020 18:05:29 +0800 Subject: [PATCH 06/42] follow comment --- .../BaseScriptTransformationExec.scala | 70 ++++++++++++------ .../SparkScriptTransformationExec.scala | 49 +++---------- .../HiveScriptTransformationExec.scala | 35 +-------- sql/hive/src/test/resources/test_script.py | 4 +- .../BaseScriptTransformationSuite.scala | 71 +++++++++++++------ 5 files changed, 115 insertions(+), 114 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 035b1a276134f..4645b1d559752 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution -import java.io.OutputStream +import java.io.{BufferedReader, InputStream, OutputStream} import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -29,16 +30,21 @@ import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { + def input: Seq[Expression] + def script: String + def output: Seq[Attribute] + def child: SparkPlan + def ioschema: BaseScriptTransformIOSchema override def producedAttributes: AttributeSet = outputSet -- inputSet @@ -59,10 +65,49 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } + def initProc(name: String): 
(OutputStream, Process, InputStream, CircularBuffer) = { + val cmd = List("/bin/bash", "-c", script) + val builder = new ProcessBuilder(cmd.asJava) + + val proc = builder.start() + val inputStream = proc.getInputStream + val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream + + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + s"Thread-$name-STDERR-Consumer").start() + (outputStream, proc, inputStream, stderrBuffer) + } + def processIterator( inputIterator: Iterator[InternalRow], hadoopConf: Configuration): Iterator[InternalRow] + def processOutputWithoutSerde(prevLine: String, reader: BufferedReader): InternalRow = { + if (!ioschema.schemaLess) { + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .zip(output) + .map { case (data, dataType) => + CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) + }) + } else { + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .zip(output) + .map { case (data, dataType) => + CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) + }) + } + } + protected def checkFailureAndPropagate( writerThread: BaseScriptTransformationWriterThread, cause: Throwable = null, @@ -91,7 +136,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } - def wrapper(data: String, dt: DataType): Any = { + protected def wrapper(data: String, dt: DataType): Any = { dt match { case StringType => data case ByteType => JavaUtils.stringToBytes(data) @@ -130,19 +175,12 @@ trait BaseScriptTransformationExec extends UnaryExecNode { abstract class BaseScriptTransformationWriterThread extends Thread with Logging { def iter: Iterator[InternalRow] - def inputSchema: Seq[DataType] - def ioSchema: BaseScriptTransformIOSchema - def outputStream: OutputStream - def proc: Process - def stderrBuffer: CircularBuffer - def taskContext: TaskContext - def conf: Configuration setDaemon(true) @@ -219,21 +257,13 @@ abstract class BaseScriptTransformIOSchema extends Serializable { import ScriptIOSchema._ def inputRowFormat: Seq[(String, String)] - def outputRowFormat: Seq[(String, String)] - def inputSerdeClass: Option[String] - def outputSerdeClass: Option[String] - def inputSerdeProps: Seq[(String, String)] - def outputSerdeProps: Seq[(String, String)] - def recordReaderClass: Option[String] - def recordWriterClass: Option[String] - def schemaLess: Boolean val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala index a44c13fd3747d..3264d8e9678ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.types._ @@ -47,25 +47,11 @@ case class SparkScriptTransformationExec( ioschema: SparkScriptIOSchema) extends BaseScriptTransformationExec { - override def processIterator(inputIterator: Iterator[InternalRow], hadoopConf: Configuration) - : Iterator[InternalRow] = { - val cmd = List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd.asJava) - - val proc = builder.start() - val inputStream = proc.getInputStream - val outputStream = proc.getOutputStream - val errorStream = proc.getErrorStream - - // In order to avoid deadlocks, we need to consume the error output of the child process. - // To avoid issues caused by large error output, we use a circular buffer to limit the amount - // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang - // that motivates this. - val stderrBuffer = new CircularBuffer(2048) - new RedirectThread( - errorStream, - stderrBuffer, - "Thread-ScriptTransformation-STDERR-Consumer").start() + override def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration): Iterator[InternalRow] = { + + val (outputStream, proc, inputStream, stderrBuffer) = initProc(this.getClass.getSimpleName) val finalInput = input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) @@ -73,9 +59,9 @@ case class SparkScriptTransformationExec( // This new thread will consume the ScriptTransformation's input rows and write them to the // external process. That process's output will be read by this current thread. - val writerThread = ScriptTransformationWriterThread( + val writerThread = SparkScriptTransformationWriterThread( inputIterator.map(outputProjection), - input.map(_.dataType), + finalInput.map(_.dataType), ioschema, outputStream, proc, @@ -87,7 +73,6 @@ case class SparkScriptTransformationExec( val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] { var curLine: String = null - val mutableRow = new SpecificInternalRow(output.map(_.dataType)) override def hasNext: Boolean = { try { @@ -115,21 +100,7 @@ case class SparkScriptTransformationExec( } val prevLine = curLine curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .zip(output) - .map { case (data, dataType) => - CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) - }) - } else { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .zip(output) - .map { case (data, dataType) => - CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) - }) - } + processOutputWithoutSerde(prevLine, reader) } } @@ -139,7 +110,7 @@ case class SparkScriptTransformationExec( } } -case class ScriptTransformationWriterThread( +case class SparkScriptTransformationWriterThread( iter: Iterator[InternalRow], inputSchema: Seq[DataType], ioSchema: SparkScriptIOSchema, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index ce04910002561..2bfdfaf2c2343 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -60,23 +60,8 @@ case class HiveScriptTransformationExec( override def processIterator( inputIterator: Iterator[InternalRow], hadoopConf: Configuration): Iterator[InternalRow] = { - val cmd = List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd.asJava) - - val proc = builder.start() - val inputStream = proc.getInputStream - val outputStream = proc.getOutputStream - val errorStream = proc.getErrorStream - - // In order to avoid deadlocks, we need to consume the error output of the child process. - // To avoid issues caused by large error output, we use a circular buffer to limit the amount - // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang - // that motivates this. - val stderrBuffer = new CircularBuffer(2048) - new RedirectThread( - errorStream, - stderrBuffer, - "Thread-ScriptTransformation-STDERR-Consumer").start() + + val (outputStream, proc, inputStream, stderrBuffer) = initProc(this.getClass.getSimpleName) // This nullability is a performance optimization in order to avoid an Option.foreach() call // inside of a loop @@ -183,21 +168,7 @@ case class HiveScriptTransformationExec( if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .zip(output) - .map { case (data, dataType) => - CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) - }) - } else { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .zip(output) - .map { case (data, dataType) => - CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) - }) - } + processOutputWithoutSerde(prevLine, reader) } else { val raw = outputSerde.deserialize(scriptOutputWritable) scriptOutputWritable = null diff --git a/sql/hive/src/test/resources/test_script.py b/sql/hive/src/test/resources/test_script.py index 82ef7b38f0c1b..4b8e4bd108884 100644 --- a/sql/hive/src/test/resources/test_script.py +++ b/sql/hive/src/test/resources/test_script.py @@ -16,6 +16,6 @@ import sys for line in sys.stdin: - (a, b, c, d, e) = line.split('\t') - sys.stdout.write('\t'.join([a, b, c, d, e])) + (a, b, c, d, e, f, g, h, i, j) = line.split('\t') + sys.stdout.write('\t'.join([a, b, c, d, e, f, g, h, i, j])) sys.stdout.flush() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala index 4869872273832..24ea514ffacf6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.util.Locale import org.scalatest.Assertions._ @@ -31,10 +31,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ +import org.apache.spark.sql.functions._ import 
org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
 
 abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils
   with TestHiveSingleton with BeforeAndAfterEach {
@@ -133,36 +135,63 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
   test("SPARK-25990: TRANSFORM should handle different data types correctly") {
     assume(TestUtils.testCommandAvailable("python"))
     val scriptFilePath = getTestResourcePath("test_script.py")
-
+    case class Struct(d: Int, str: String)
     withTempView("v") {
       val df = Seq(
-        (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)),
-        (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)),
-        (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3))
-      ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18)
+        (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1),
+          new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2),
+          Map("a" -> 1)),
+        (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2),
+          new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5),
+          Map("b" -> 2)),
+        (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3),
+          new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8),
+          Map("c" -> 3))
+      ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i")
+        .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, struct('a, 'b).as("j"))
+      // Note column d's data type is Decimal(38, 18)
       df.createTempView("v")
+
+      assert(spark.table("v").schema ==
+        StructType(Seq(StructField("a", IntegerType, false),
+          StructField("b", StringType, true),
+          StructField("c", DoubleType, false),
+          StructField("d", DecimalType(38, 18), true),
+          StructField("e", TimestampType, true),
+          StructField("f", DateType, true),
+          StructField("g", CalendarIntervalType, true),
+          StructField("h", ArrayType(IntegerType, false), true),
+          StructField("i", MapType(StringType, IntegerType, false), true),
+          StructField("j", StructType(
+            Seq(StructField("a", IntegerType, false),
+              StructField("b", StringType, true))), false))))
+
+      // Can't support convert script output data to ArrayType/MapType/StructType now,
+      // return these column still as string
       val query = sql(
         s"""
           |SELECT
-          |TRANSFORM(a, b, c, d, e)
-          |USING 'python $scriptFilePath' AS (a, b, c, d, e)
+          |TRANSFORM(a, b, c, d, e, f, g, h, i, j)
+          |USING 'python $scriptFilePath'
+          |AS (
+          |  a int,
+          |  b string,
+          |  c double,
+          |  d decimal(1, 0),
+          |  e timestamp,
+          |  f date,
+          |  g interval,
+          |  h string,
+          |  i string,
+          |  j string)
           |FROM v
         """.stripMargin)
-      // In Hive 1.2, the string representation of a decimal omits trailing zeroes.
-      // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary.
- val decimalToString: Column => Column = if (HiveUtils.isHive23) { - c => c.cast("string") - } else { - c => c.cast("decimal(1, 0)").cast("string") - } checkAnswer(query, identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - decimalToString('d), - 'e.cast("string")).collect()) + 'a, 'b, 'c, 'd, 'e, 'f, 'g, + 'h.cast("string"), + 'i.cast("string"), + 'j.cast("string")).collect()) } } From c3dc66bdda87ffb72b0debe2e7821f0ad53648d4 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 15 Jul 2020 18:21:06 +0800 Subject: [PATCH 07/42] follow comment --- .../BaseScriptTransformationExec.scala | 62 +++++++++---------- .../BaseScriptTransformationSuite.scala | 2 - 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 4645b1d559752..85072a2eccc85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -94,17 +94,13 @@ trait BaseScriptTransformationExec extends UnaryExecNode { if (!ioschema.schemaLess) { new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .zip(output) - .map { case (data, dataType) => - CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) - }) + .zip(fieldWriters) + .map { case (data, writer) => writer(data) }) } else { new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .zip(output) - .map { case (data, dataType) => - CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType)) - }) + .zip(fieldWriters) + .map { case (data, writer) => writer(data) }) } } @@ -136,38 +132,40 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } - protected def wrapper(data: String, dt: DataType): Any = { - dt match { - case StringType => data - case ByteType => JavaUtils.stringToBytes(data) - case IntegerType => data.toInt - case ShortType => data.toShort - case LongType => data.toLong - case FloatType => data.toFloat - case DoubleType => data.toDouble - case dt: DecimalType => BigDecimal(data) - case DateType if conf.datetimeJava8ApiEnabled => - DateTimeUtils.stringToDate( + private lazy val fieldWriters: Seq[String => Any] = output.map { attr => + val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) + attr.dataType match { + case StringType => (data: String) => converter(data) + case ByteType => (data: String) => converter(JavaUtils.stringToBytes(data)) + case IntegerType => (data: String) => converter(data.toInt) + case ShortType => (data: String) => converter(data.toShort) + case LongType => (data: String) => converter(data.toLong) + case FloatType => (data: String) => converter(data.toFloat) + case DoubleType => (data: String) => converter(data.toDouble) + case dt: DecimalType => (data: String) => converter(BigDecimal(data)) + case DateType if conf.datetimeJava8ApiEnabled => (data: String) => + converter(DateTimeUtils.stringToDate( UTF8String.fromString(data), DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.daysToLocalDate).orNull - case DateType => - DateTimeUtils.stringToDate( + .map(DateTimeUtils.daysToLocalDate).orNull) + case DateType => (data: String) => + converter(DateTimeUtils.stringToDate( UTF8String.fromString(data), 
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaDate).orNull - case TimestampType if conf.datetimeJava8ApiEnabled => - DateTimeUtils.stringToTimestamp( + .map(DateTimeUtils.toJavaDate).orNull) + case TimestampType if conf.datetimeJava8ApiEnabled => (data: String) => + converter(DateTimeUtils.stringToTimestamp( UTF8String.fromString(data), DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.microsToInstant).orNull - case TimestampType => - DateTimeUtils.stringToTimestamp( + .map(DateTimeUtils.microsToInstant).orNull) + case TimestampType => (data: String) => + converter(DateTimeUtils.stringToTimestamp( UTF8String.fromString(data), DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaTimestamp).orNull - case CalendarIntervalType => IntervalUtils.stringToInterval(UTF8String.fromString(data)) - case dataType: DataType => data + .map(DateTimeUtils.toJavaTimestamp).orNull) + case CalendarIntervalType => (data: String) => + converter(IntervalUtils.stringToInterval(UTF8String.fromString(data))) + case dataType: DataType => (data: String) => converter(data) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala index 24ea514ffacf6..1719865d85f1d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala @@ -26,13 +26,11 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TaskContext, TestUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ From cb19b7b96e194dcd2951959a06f0639062ac04db Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 17 Jul 2020 11:03:25 +0800 Subject: [PATCH 08/42] follow comment --- .../BaseScriptTransformationExec.scala | 51 +++-- .../SparkScriptTransformationExec.scala | 41 +--- .../spark/sql/execution/SparkSqlParser.scala | 11 +- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 7 +- .../HiveScriptTransformationExec.scala | 211 ++++++++---------- sql/hive/src/test/resources/test_script.py | 4 +- .../src/test/resources/test_spark_script.py | 21 ++ .../BaseScriptTransformationSuite.scala | 80 +++---- .../HiveScriptTransformationSuite.scala | 11 +- .../SparkScriptTransformationSuite.scala | 87 +++++++- 11 files changed, 285 insertions(+), 241 deletions(-) create mode 100644 sql/hive/src/test/resources/test_spark_script.py diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 85072a2eccc85..1d7899ccf3a3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf @@ -44,7 +45,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { def script: String def output: Seq[Attribute] def child: SparkPlan - def ioschema: BaseScriptTransformIOSchema + def ioschema: ScriptTransformationIOSchema override def producedAttributes: AttributeSet = outputSet -- inputSet @@ -65,7 +66,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } - def initProc(name: String): (OutputStream, Process, InputStream, CircularBuffer) = { + protected def initProc: (OutputStream, Process, InputStream, CircularBuffer) = { val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd.asJava) @@ -82,15 +83,15 @@ trait BaseScriptTransformationExec extends UnaryExecNode { new RedirectThread( errorStream, stderrBuffer, - s"Thread-$name-STDERR-Consumer").start() + s"Thread-${this.getClass.getSimpleName}-STDERR-Consumer").start() (outputStream, proc, inputStream, stderrBuffer) } - def processIterator( + protected def processIterator( inputIterator: Iterator[InternalRow], hadoopConf: Configuration): Iterator[InternalRow] - def processOutputWithoutSerde(prevLine: String, reader: BufferedReader): InternalRow = { + protected def processOutputWithoutSerde(prevLine: String, reader: BufferedReader): InternalRow = { if (!ioschema.schemaLess) { new GenericInternalRow( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) @@ -174,7 +175,7 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging def iter: Iterator[InternalRow] def inputSchema: Seq[DataType] - def ioSchema: BaseScriptTransformIOSchema + def ioSchema: ScriptTransformationIOSchema def outputStream: OutputStream def proc: Process def stderrBuffer: CircularBuffer @@ -251,26 +252,38 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging /** * The wrapper class of input and output schema properties */ -abstract class BaseScriptTransformIOSchema extends Serializable { - import ScriptIOSchema._ - - def inputRowFormat: Seq[(String, String)] - def outputRowFormat: Seq[(String, String)] - def inputSerdeClass: Option[String] - def outputSerdeClass: Option[String] - def inputSerdeProps: Seq[(String, String)] - def outputSerdeProps: Seq[(String, String)] - def recordReaderClass: Option[String] - def recordWriterClass: Option[String] - def schemaLess: Boolean +case class ScriptTransformationIOSchema( + inputRowFormat: Seq[(String, String)], + outputRowFormat: Seq[(String, String)], + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], + inputSerdeProps: Seq[(String, String)], + outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], + schemaLess: Boolean) extends Serializable { + import ScriptTransformationIOSchema._ val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = 
outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
 }
 
-object ScriptIOSchema {
+object ScriptTransformationIOSchema {
   val defaultFormat = Map(
     ("TOK_TABLEROWFORMATFIELD", "\t"),
     ("TOK_TABLEROWFORMATLINES", "\n")
   )
+
+  def apply(input: ScriptInputOutputSchema): ScriptTransformationIOSchema = {
+    ScriptTransformationIOSchema(
+      input.inputRowFormat,
+      input.outputRowFormat,
+      input.inputSerdeClass,
+      input.outputSerdeClass,
+      input.inputSerdeProps,
+      input.outputSerdeProps,
+      input.recordReaderClass,
+      input.recordWriterClass,
+      input.schemaLess)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
index 3264d8e9678ce..4909feae20ad5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.io._
 import java.nio.charset.StandardCharsets
 
-import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
@@ -28,9 +27,8 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.types._
-import org.apache.spark.util.{CircularBuffer, RedirectThread}
+import org.apache.spark.util.CircularBuffer
 
 /**
  * Transforms the input by forking and running the specified script.
@@ -44,14 +42,14 @@ case class SparkScriptTransformationExec(
     script: String,
     output: Seq[Attribute],
     child: SparkPlan,
-    ioschema: SparkScriptIOSchema)
+    ioschema: ScriptTransformationIOSchema)
   extends BaseScriptTransformationExec {
 
   override def processIterator(
       inputIterator: Iterator[InternalRow],
       hadoopConf: Configuration): Iterator[InternalRow] = {
 
-    val (outputStream, proc, inputStream, stderrBuffer) = initProc(this.getClass.getSimpleName)
+    val (outputStream, proc, inputStream, stderrBuffer) = initProc
 
     val finalInput = input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
@@ -113,7 +111,7 @@ case class SparkScriptTransformationExec(
 case class SparkScriptTransformationWriterThread(
     iter: Iterator[InternalRow],
     inputSchema: Seq[DataType],
-    ioSchema: SparkScriptIOSchema,
+    ioSchema: ScriptTransformationIOSchema,
     outputStream: OutputStream,
     proc: Process,
     stderrBuffer: CircularBuffer,
@@ -121,38 +119,7 @@ case class SparkScriptTransformationWriterThread(
     conf: Configuration)
   extends BaseScriptTransformationWriterThread {
 
-  setDaemon(true)
-
   override def processRows(): Unit = {
     processRowsWithoutSerde()
   }
 }
-
-object SparkScriptIOSchema {
-  def apply(input: ScriptInputOutputSchema): SparkScriptIOSchema = {
-    SparkScriptIOSchema(
-      input.inputRowFormat,
-      input.outputRowFormat,
-      input.inputSerdeClass,
-      input.outputSerdeClass,
-      input.inputSerdeProps,
-      input.outputSerdeProps,
-      input.recordReaderClass,
-      input.recordWriterClass,
-      input.schemaLess)
-  }
-}
-
-/**
- * The wrapper class of Spark script transformation input and output schema properties
- */
-case class SparkScriptIOSchema (
-    inputRowFormat: Seq[(String, String)],
-    outputRowFormat: Seq[(String, String)],
-    inputSerdeClass: Option[String],
-    outputSerdeClass: Option[String],
-    inputSerdeProps: Seq[(String, String)],
- outputSerdeProps: Seq[(String, String)], - recordReaderClass: Option[String], - recordWriterClass: Option[String], - schemaLess: Boolean) extends BaseScriptTransformIOSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 5724610744aba..37bd3022ba4bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -713,8 +713,17 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { None } (Seq.empty, Option(name), props.toSeq, recordHandler) + + case null if conf.getConf(CATALOG_IMPLEMENTATION).equals("hive") => + // Use default (serde) format. + val name = conf.getConfString("hive.script.serde", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + val props = Seq("field.delim" -> "\t") + val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) + (Nil, Option(name), props, recordHandler) + // SPARK-32106: When there is no definition about format, we return empty result - // then we finally execute with SparkScriptTransformationExec + // to use a built-in default Serde in SparkScriptTransformationExec. case null => (Nil, None, Seq.empty, None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d5366de2ea704..21ddea51df4a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -541,7 +541,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { script, output, planLater(child), - SparkScriptIOSchema(ioschema) + ScriptTransformationIOSchema(ioschema) ) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 6bac2a2203713..97e1dee5913a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.execution.{HiveScriptIOSchema, HiveScriptTransformationExec} +import org.apache.spark.sql.hive.execution.HiveScriptTransformationExec import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -243,9 +243,8 @@ private[hive] trait HiveStrategies { object HiveScripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ScriptTransformation(input, script, output, child, ioschema) - if ioschema.inputSerdeClass.nonEmpty || ioschema.outputSerdeClass.nonEmpty => - val hiveIoSchema = HiveScriptIOSchema(ioschema) + case ScriptTransformation(input, script, output, child, ioschema) => + val hiveIoSchema = ScriptTransformationIOSchema(ioschema) HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala 
index 2bfdfaf2c2343..4f17d1f40afdc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -33,14 +33,13 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types.{DataType, StringType} -import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, Utils} /** * Transforms the input by forking and running the specified script. @@ -54,18 +53,91 @@ case class HiveScriptTransformationExec( script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: HiveScriptIOSchema) - extends BaseScriptTransformationExec { + ioschema: ScriptTransformationIOSchema) + extends BaseScriptTransformationExec with HiveInspectors { + + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { + ioschema.inputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(input) + val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.inputSerdeProps) + val fieldObjectInspectors = columnTypes.map(toInspector) + val objectInspector = ObjectInspectorFactory + .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) + (serde, objectInspector) + } + } + + def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + ioschema.outputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.outputSerdeProps) + val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] + (serde, structObjectInspector) + } + } + + private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { + val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") + val columnTypes = attrs.map(_.dataType) + (columns, columnTypes) + } + + private def initSerDe( + serdeClassName: String, + columns: Seq[String], + columnTypes: Seq[DataType], + serdeProps: Seq[(String, String)]): AbstractSerDe = { + + val serde = Utils.classForName[AbstractSerDe](serdeClassName).getConstructor(). + newInstance() + + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") + + var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + + val properties = new Properties() + // Can not use properties.putAll(propsMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + propsMap.foreach { case (k, v) => properties.put(k, v) } + serde.initialize(null, properties) + + serde + } + + def recordReader( + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + ioschema.recordReaderClass.map { klass => + val instance = Utils.classForName[RecordReader](klass).getConstructor(). 
+ newInstance() + val props = new Properties() + // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + ioschema.outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { + ioschema.recordWriterClass.map { klass => + val instance = Utils.classForName[RecordWriter](klass).getConstructor(). + newInstance() + instance.initialize(outputStream, conf) + instance + } + } override def processIterator( inputIterator: Iterator[InternalRow], hadoopConf: Configuration): Iterator[InternalRow] = { - val (outputStream, proc, inputStream, stderrBuffer) = initProc(this.getClass.getSimpleName) + val (outputStream, proc, inputStream, stderrBuffer) = initProc // This nullability is a performance optimization in order to avoid an Option.foreach() call // inside of a loop - @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + @Nullable val (inputSerde, inputSoi) = initInputSerDe(input).getOrElse((null, null)) // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null // We will use StringBuffer to pass data, in this case, we should cast data as string too. @@ -86,6 +158,7 @@ case class HiveScriptTransformationExec( inputSoi, ioschema, outputStream, + recordWriter, proc, stderrBuffer, TaskContext.get(), @@ -95,7 +168,7 @@ case class HiveScriptTransformationExec( // This nullability is a performance optimization in order to avoid an Option.foreach() call // inside of a loop @Nullable val (outputSerde, outputSoi) = { - ioschema.initOutputSerDe(output).getOrElse((null, null)) + initOutputSerDe(output).getOrElse((null, null)) } val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) @@ -103,8 +176,7 @@ case class HiveScriptTransformationExec( var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) - @Nullable val scriptOutputReader = - ioschema.recordReader(scriptOutputStream, hadoopConf).orNull + @Nullable val scriptOutputReader = recordReader(scriptOutputStream, hadoopConf).orNull var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { @@ -165,11 +237,17 @@ case class HiveScriptTransformationExec( if (!hasNext) { throw new NoSuchElementException } - if (outputSerde == null) { + nextRow() + } + + val nextRow: () => InternalRow = if (outputSerde == null) { + () => { val prevLine = curLine curLine = reader.readLine() processOutputWithoutSerde(prevLine, reader) - } else { + } + } else { + () => { val raw = outputSerde.deserialize(scriptOutputWritable) scriptOutputWritable = null val dataList = outputSoi.getStructFieldsDataAsList(raw) @@ -198,8 +276,9 @@ case class HiveScriptTransformationWriterThread( inputSchema: Seq[DataType], @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: StructObjectInspector, - ioSchema: HiveScriptIOSchema, + ioSchema: ScriptTransformationIOSchema, outputStream: OutputStream, + recordWriter: (OutputStream, Configuration) => Option[RecordWriter], proc: Process, stderrBuffer: CircularBuffer, taskContext: TaskContext, @@ -208,7 +287,7 @@ case class HiveScriptTransformationWriterThread( override def processRows(): Unit = { val dataOutputStream = new DataOutputStream(outputStream) - @Nullable val scriptInputWriter = ioSchema.recordWriter(dataOutputStream, conf).orNull + 
@Nullable val scriptInputWriter = recordWriter(dataOutputStream, conf).orNull if (inputSerde == null) { processRowsWithoutSerde() @@ -235,107 +314,3 @@ case class HiveScriptTransformationWriterThread( } } } - -object HiveScriptIOSchema { - def apply(input: ScriptInputOutputSchema): HiveScriptIOSchema = { - HiveScriptIOSchema( - input.inputRowFormat, - input.outputRowFormat, - input.inputSerdeClass, - input.outputSerdeClass, - input.inputSerdeProps, - input.outputSerdeProps, - input.recordReaderClass, - input.recordWriterClass, - input.schemaLess) - } -} - -/** - * The wrapper class of Hive script transformation input and output schema properties - */ -case class HiveScriptIOSchema ( - inputRowFormat: Seq[(String, String)], - outputRowFormat: Seq[(String, String)], - inputSerdeClass: Option[String], - outputSerdeClass: Option[String], - inputSerdeProps: Seq[(String, String)], - outputSerdeProps: Seq[(String, String)], - recordReaderClass: Option[String], - recordWriterClass: Option[String], - schemaLess: Boolean) - extends BaseScriptTransformIOSchema with HiveInspectors { - - def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { - inputSerdeClass.map { serdeClass => - val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) - val fieldObjectInspectors = columnTypes.map(toInspector) - val objectInspector = ObjectInspectorFactory - .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) - (serde, objectInspector) - } - } - - def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { - outputSerdeClass.map { serdeClass => - val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps) - val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] - (serde, structObjectInspector) - } - } - - private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") - val columnTypes = attrs.map(_.dataType) - (columns, columnTypes) - } - - private def initSerDe( - serdeClassName: String, - columns: Seq[String], - columnTypes: Seq[DataType], - serdeProps: Seq[(String, String)]): AbstractSerDe = { - - val serde = Utils.classForName[AbstractSerDe](serdeClassName).getConstructor(). - newInstance() - - val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - - var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) - propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - - val properties = new Properties() - // Can not use properties.putAll(propsMap.asJava) in scala-2.12 - // See https://github.com/scala/bug/issues/10418 - propsMap.foreach { case (k, v) => properties.put(k, v) } - serde.initialize(null, properties) - - serde - } - - def recordReader( - inputStream: InputStream, - conf: Configuration): Option[RecordReader] = { - recordReaderClass.map { klass => - val instance = Utils.classForName[RecordReader](klass).getConstructor(). 
- newInstance() - val props = new Properties() - // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 - // See https://github.com/scala/bug/issues/10418 - outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } - instance.initialize(inputStream, conf, props) - instance - } - } - - def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { - recordWriterClass.map { klass => - val instance = Utils.classForName[RecordWriter](klass).getConstructor(). - newInstance() - instance.initialize(outputStream, conf) - instance - } - } -} diff --git a/sql/hive/src/test/resources/test_script.py b/sql/hive/src/test/resources/test_script.py index 4b8e4bd108884..82ef7b38f0c1b 100644 --- a/sql/hive/src/test/resources/test_script.py +++ b/sql/hive/src/test/resources/test_script.py @@ -16,6 +16,6 @@ import sys for line in sys.stdin: - (a, b, c, d, e, f, g, h, i, j) = line.split('\t') - sys.stdout.write('\t'.join([a, b, c, d, e, f, g, h, i, j])) + (a, b, c, d, e) = line.split('\t') + sys.stdout.write('\t'.join([a, b, c, d, e])) sys.stdout.flush() diff --git a/sql/hive/src/test/resources/test_spark_script.py b/sql/hive/src/test/resources/test_spark_script.py new file mode 100644 index 0000000000000..4b8e4bd108884 --- /dev/null +++ b/sql/hive/src/test/resources/test_spark_script.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +for line in sys.stdin: + (a, b, c, d, e, f, g, h, i, j) = line.split('\t') + sys.stdout.write('\t'.join([a, b, c, d, e, f, g, h, i, j])) + sys.stdout.flush() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala index 1719865d85f1d..b0af624ea1301 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import java.util.Locale import org.scalatest.Assertions._ @@ -26,24 +26,25 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TaskContext, TestUtils} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ -import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { def scriptType: String + def isHive23OrSpark: Boolean = true + import spark.implicits._ - var noSerdeIOSchema: BaseScriptTransformIOSchema = _ + var noSerdeIOSchema: ScriptTransformationIOSchema = _ private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _ @@ -70,21 +71,21 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: BaseScriptTransformIOSchema): BaseScriptTransformationExec = { + ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = { scriptType.toUpperCase(Locale.ROOT) match { case "SPARK" => new SparkScriptTransformationExec( input = input, script = script, output = output, child = child, - ioschema = ioschema.asInstanceOf[SparkScriptIOSchema] + ioschema = ioschema ) case "HIVE" => new HiveScriptTransformationExec( input = input, script = script, output = output, child = child, - ioschema = ioschema.asInstanceOf[HiveScriptIOSchema] + ioschema = ioschema ) case _ => throw new TestFailedException( "Test class implement from BaseScriptTransformationSuite" + @@ -133,63 +134,36 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU test("SPARK-25990: TRANSFORM should handle different data types correctly") { assume(TestUtils.testCommandAvailable("python")) val scriptFilePath = getTestResourcePath("test_script.py") - case class Struct(d: Int, str: String) + withTempView("v") { val df = Seq( - (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1), - new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), - Map("a" -> 1)), - (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2), - new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), - Map("b" -> 2)), - (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3), - new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), - Map("c" -> 3)) - 
).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, struct('a, 'b).as("j")) - // Note column d's data type is Decimal(38, 18) + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) + ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) df.createTempView("v") - assert(spark.table("v").schema == - StructType(Seq(StructField("a", IntegerType, false), - StructField("b", StringType, true), - StructField("c", DoubleType, false), - StructField("d", DecimalType(38, 18), true), - StructField("e", TimestampType, true), - StructField("f", DateType, true), - StructField("g", CalendarIntervalType, true), - StructField("h", ArrayType(IntegerType, false), true), - StructField("i", MapType(StringType, IntegerType, false), true), - StructField("j", StructType( - Seq(StructField("a", IntegerType, false), - StructField("b", StringType, true))), false)))) - - // Can't support convert script output data to ArrayType/MapType/StructType now, - // return these column still as string val query = sql( s""" |SELECT - |TRANSFORM(a, b, c, d, e, f, g, h, i, j) - |USING 'python $scriptFilePath' - |AS ( - | a int, - | b string, - | c double, - | d decimal(1, 0), - | e timestamp, - | f date, - | g interval, - | h string, - | i string, - | j string) + |TRANSFORM(a, b, c, d, e) + |USING 'python $scriptFilePath' AS (a, b, c, d, e) |FROM v """.stripMargin) + // In Hive 1.2, the string representation of a decimal omits trailing zeroes. + // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary. + val decimalToString: Column => Column = if (isHive23OrSpark) { + c => c.cast("string") + } else { + c => c.cast("decimal(1, 0)").cast("string") + } checkAnswer(query, identity, df.select( - 'a, 'b, 'c, 'd, 'e, 'f, 'g, - 'h.cast("string"), - 'i.cast("string"), - 'j.cast("string")).collect()) + 'a.cast("string"), + 'b.cast("string"), + 'c.cast("string"), + decimalToString('d), + 'e.cast("string")).collect()) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 3caeca4a0eb30..b58f5ae60f01c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -22,15 +22,18 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.execution.{BaseScriptTransformIOSchema, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{ScriptTransformationIOSchema, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.types.StringType class HiveScriptTransformationSuite extends BaseScriptTransformationSuite { override def scriptType: String = "HIVE" + override def isHive23OrSpark: Boolean = HiveUtils.isHive23 + import spark.implicits._ - noSerdeIOSchema = HiveScriptIOSchema( + noSerdeIOSchema = ScriptTransformationIOSchema( inputRowFormat = Seq.empty, outputRowFormat = Seq.empty, inputSerdeClass = None, @@ -42,8 +45,8 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite { schemaLess = false ) - private val serdeIOSchema: 
BaseScriptTransformIOSchema = { - noSerdeIOSchema.asInstanceOf[HiveScriptIOSchema].copy( + private val serdeIOSchema: ScriptTransformationIOSchema = { + noSerdeIOSchema.copy( inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala index 372e2a8054cda..586ca1b98c304 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala @@ -17,13 +17,22 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.execution.SparkScriptIOSchema +import java.sql.{Date, Timestamp} + +import org.apache.spark.TestUtils +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.{ScriptTransformationIOSchema, SparkPlan} +import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { + import spark.implicits._ + override def scriptType: String = "SPARK" - noSerdeIOSchema = SparkScriptIOSchema( + noSerdeIOSchema = ScriptTransformationIOSchema( inputRowFormat = Seq.empty, outputRowFormat = Seq.empty, inputSerdeClass = None, @@ -34,4 +43,78 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { recordWriterClass = None, schemaLess = false ) + + test("SPARK-32106: SparkScriptTransformExec should handle different data types correctly") { + assume(TestUtils.testCommandAvailable("python")) + val scriptFilePath = getTestResourcePath("test_spark_script.py") + case class Struct(d: Int, str: String) + withTempView("v") { + val df = Seq( + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1), + new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), + Map("a" -> 1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2), + new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), + Map("b" -> 2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3), + new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), + Map("c" -> 3)) + ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i") + .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, struct('a, 'b).as("j")) + // Note column d's data type is Decimal(38, 18) + df.createTempView("v") + + assert(spark.table("v").schema == + StructType(Seq(StructField("a", IntegerType, false), + StructField("b", StringType, true), + StructField("c", DoubleType, false), + StructField("d", DecimalType(38, 18), true), + StructField("e", TimestampType, true), + StructField("f", DateType, true), + StructField("g", CalendarIntervalType, true), + StructField("h", ArrayType(IntegerType, false), true), + StructField("i", MapType(StringType, IntegerType, false), true), + StructField("j", StructType( + Seq(StructField("a", IntegerType, false), + StructField("b", StringType, true))), false)))) + + // Can't support convert script output data to ArrayType/MapType/StructType now, + // return these column still as string + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + 
df.col("f").expr, + df.col("g").expr, + df.col("h").expr, + df.col("i").expr, + df.col("j").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(1, 0))(), + AttributeReference("e", TimestampType)(), + AttributeReference("f", DateType)(), + AttributeReference("g", CalendarIntervalType)(), + AttributeReference("h", StringType)(), + AttributeReference("i", StringType)(), + AttributeReference("j", StringType)() + ), + child = child, + ioschema = noSerdeIOSchema + ), + df.select( + 'a, 'b, 'c, 'd, 'e, 'f, 'g, + 'h.cast("string"), + 'i.cast("string"), + 'j.cast("string")).collect()) + } + } } From ce8a0a547cb6574acb30f35fbad74ed72bbe5a9f Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 17 Jul 2020 11:37:15 +0800 Subject: [PATCH 09/42] fix bytetype and add it in UT --- .../BaseScriptTransformationExec.scala | 3 +- .../src/test/resources/test_spark_script.py | 21 -------- .../SparkScriptTransformationSuite.scala | 49 ++++++++++--------- 3 files changed, 26 insertions(+), 47 deletions(-) delete mode 100644 sql/hive/src/test/resources/test_spark_script.py diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 1d7899ccf3a3e..bbb4db3868c12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -28,7 +28,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, UnsafeProjection} @@ -137,7 +136,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) attr.dataType match { case StringType => (data: String) => converter(data) - case ByteType => (data: String) => converter(JavaUtils.stringToBytes(data)) + case ByteType => (data: String) => converter(data.toByte) case IntegerType => (data: String) => converter(data.toInt) case ShortType => (data: String) => converter(data.toShort) case LongType => (data: String) => converter(data.toLong) diff --git a/sql/hive/src/test/resources/test_spark_script.py b/sql/hive/src/test/resources/test_spark_script.py deleted file mode 100644 index 4b8e4bd108884..0000000000000 --- a/sql/hive/src/test/resources/test_spark_script.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import sys -for line in sys.stdin: - (a, b, c, d, e, f, g, h, i, j) = line.split('\t') - sys.stdout.write('\t'.join([a, b, c, d, e, f, g, h, i, j])) - sys.stdout.flush() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala index 586ca1b98c304..2bfa70e6410f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala @@ -46,21 +46,20 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { test("SPARK-32106: SparkScriptTransformExec should handle different data types correctly") { assume(TestUtils.testCommandAvailable("python")) - val scriptFilePath = getTestResourcePath("test_spark_script.py") case class Struct(d: Int, str: String) withTempView("v") { val df = Seq( - (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1), + (1, "1", 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), Map("a" -> 1)), - (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2), + (2, "2", 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), Map("b" -> 2)), - (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3), + (3, "3", 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), Map("c" -> 3)) - ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, struct('a, 'b).as("j")) + ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, struct('a, 'b).as("k")) // Note column d's data type is Decimal(38, 18) df.createTempView("v") @@ -68,13 +67,14 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, true), StructField("c", DoubleType, false), - StructField("d", DecimalType(38, 18), true), - StructField("e", TimestampType, true), - StructField("f", DateType, true), - StructField("g", CalendarIntervalType, true), - StructField("h", ArrayType(IntegerType, false), true), - StructField("i", MapType(StringType, IntegerType, false), true), - StructField("j", StructType( + StructField("d", ByteType, false), + StructField("e", DecimalType(38, 18), true), + StructField("f", TimestampType, true), + StructField("g", DateType, true), + StructField("h", CalendarIntervalType, true), + StructField("i", ArrayType(IntegerType, false), true), + StructField("j", MapType(StringType, IntegerType, false), true), + StructField("k", StructType( Seq(StructField("a", IntegerType, false), StructField("b", StringType, true))), false)))) @@ -93,28 +93,29 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { df.col("g").expr, df.col("h").expr, df.col("i").expr, - df.col("j").expr), + df.col("j").expr, + 
df.col("k").expr), script = "cat", output = Seq( AttributeReference("a", IntegerType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(1, 0))(), - AttributeReference("e", TimestampType)(), - AttributeReference("f", DateType)(), - AttributeReference("g", CalendarIntervalType)(), - AttributeReference("h", StringType)(), + AttributeReference("d", ByteType)(), + AttributeReference("e", DecimalType(38, 18))(), + AttributeReference("f", TimestampType)(), + AttributeReference("g", DateType)(), + AttributeReference("h", CalendarIntervalType)(), AttributeReference("i", StringType)(), - AttributeReference("j", StringType)() - ), + AttributeReference("j", StringType)(), + AttributeReference("k", StringType)()), child = child, ioschema = noSerdeIOSchema ), df.select( - 'a, 'b, 'c, 'd, 'e, 'f, 'g, - 'h.cast("string"), + 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i.cast("string"), - 'j.cast("string")).collect()) + 'j.cast("string"), + 'k.cast("string")).collect()) } } } From d37ef8673f66705b185f7fdf7ad001fdc7f9f504 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 17 Jul 2020 17:59:07 +0800 Subject: [PATCH 10/42] format code --- .../BaseScriptTransformationExec.scala | 28 +++++++++++-------- .../BaseScriptTransformationSuite.scala | 2 +- .../HiveScriptTransformationSuite.scala | 12 -------- .../SparkScriptTransformationSuite.scala | 14 +--------- 4 files changed, 19 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index bbb4db3868c12..61d0360249781 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -91,17 +91,11 @@ trait BaseScriptTransformationExec extends UnaryExecNode { hadoopConf: Configuration): Iterator[InternalRow] protected def processOutputWithoutSerde(prevLine: String, reader: BufferedReader): InternalRow = { - if (!ioschema.schemaLess) { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .zip(fieldWriters) - .map { case (data, writer) => writer(data) }) - } else { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .zip(fieldWriters) - .map { case (data, writer) => writer(data) }) - } + val limit = if (ioschema.schemaLess) 2 else 0 + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), limit) + .zip(fieldWriters) + .map { case (data, writer) => writer(data) }) } protected def checkFailureAndPropagate( @@ -273,6 +267,18 @@ object ScriptTransformationIOSchema { ("TOK_TABLEROWFORMATLINES", "\n") ) + val defaultIOSchema = ScriptTransformationIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, + schemaLess = false + ) + def apply(input: ScriptInputOutputSchema): ScriptTransformationIOSchema = { ScriptTransformationIOSchema( input.inputRowFormat, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala index b0af624ea1301..8398b82da0c2e 
100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala @@ -44,7 +44,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU import spark.implicits._ - var noSerdeIOSchema: ScriptTransformationIOSchema = _ + var noSerdeIOSchema: ScriptTransformationIOSchema = ScriptTransformationIOSchema.defaultIOSchema private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index b58f5ae60f01c..11928fbc2ef0f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -33,18 +33,6 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite { import spark.implicits._ - noSerdeIOSchema = ScriptTransformationIOSchema( - inputRowFormat = Seq.empty, - outputRowFormat = Seq.empty, - inputSerdeClass = None, - outputSerdeClass = None, - inputSerdeProps = Seq.empty, - outputSerdeProps = Seq.empty, - recordReaderClass = None, - recordWriterClass = None, - schemaLess = false - ) - private val serdeIOSchema: ScriptTransformationIOSchema = { noSerdeIOSchema.copy( inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala index 2bfa70e6410f2..381679075c0f1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.TestUtils import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.execution.{ScriptTransformationIOSchema, SparkPlan} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions.struct import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -32,18 +32,6 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { override def scriptType: String = "SPARK" - noSerdeIOSchema = ScriptTransformationIOSchema( - inputRowFormat = Seq.empty, - outputRowFormat = Seq.empty, - inputSerdeClass = None, - outputSerdeClass = None, - inputSerdeProps = Seq.empty, - outputSerdeProps = Seq.empty, - recordReaderClass = None, - recordWriterClass = None, - schemaLess = false - ) - test("SPARK-32106: SparkScriptTransformExec should handle different data types correctly") { assume(TestUtils.testCommandAvailable("python")) case class Struct(d: Int, str: String) From fce25ffb97289ec236f0de8483c2465ea1a933fe Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 17 Jul 2020 22:11:05 +0900 Subject: [PATCH 11/42] Fix --- .../BaseScriptTransformationSuite.scala | 51 +++++-------------- .../SparkScriptTransformationSuite.scala | 29 ++++++++--- .../TestUncaughtExceptionHandler.scala | 2 +- .../HiveScriptTransformationSuite.scala | 27 +++++++--- .../sql/hive/execution/SQLQuerySuite.scala | 1 + 5 files 
changed, 57 insertions(+), 53 deletions(-) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql}/execution/BaseScriptTransformationSuite.scala (83%) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql}/execution/SparkScriptTransformationSuite.scala (84%) rename sql/{hive/src/test/scala/org/apache/spark/sql/hive => core/src/test/scala/org/apache/spark/sql}/execution/TestUncaughtExceptionHandler.scala (96%) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala similarity index 83% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 8398b82da0c2e..26c08c3f513c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.execution +package org.apache.spark.sql.execution import java.sql.Timestamp -import java.util.Locale import org.scalatest.Assertions._ import org.scalatest.BeforeAndAfterEach @@ -30,26 +29,18 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils - with TestHiveSingleton with BeforeAndAfterEach { + with BeforeAndAfterEach { + import testImplicits._ + import ScriptTransformationIOSchema._ - def scriptType: String - - def isHive23OrSpark: Boolean = true - - import spark.implicits._ - - var noSerdeIOSchema: ScriptTransformationIOSchema = ScriptTransformationIOSchema.defaultIOSchema + protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _ - protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler - protected override def beforeAll(): Unit = { super.beforeAll() defaultUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler @@ -66,32 +57,14 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU uncaughtExceptionHandler.cleanStatus() } + def isHive23OrSpark: Boolean + def createScriptTransformationExec( input: Seq[Expression], script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = { - scriptType.toUpperCase(Locale.ROOT) match { - case "SPARK" => new SparkScriptTransformationExec( - input = input, - script = script, - output = output, - child = child, - ioschema = ioschema - ) - case "HIVE" => new HiveScriptTransformationExec( - input = input, - script = script, - output = output, - child = child, - ioschema = ioschema - ) - case _ => throw new TestFailedException( - "Test class implement from BaseScriptTransformationSuite" + - " should override method `scriptType` to Spark 
or Hive", 0) - } - } + ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec test("cat without SerDe") { assume(TestUtils.testCommandAvailable("/bin/bash")) @@ -104,7 +77,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU script = "cat", output = Seq(AttributeReference("a", StringType)()), child = child, - ioschema = noSerdeIOSchema + ioschema = defaultIOSchema ), rowsDf.collect()) assert(uncaughtExceptionHandler.exception.isEmpty) @@ -122,7 +95,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU script = "cat", output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), - ioschema = noSerdeIOSchema + ioschema = defaultIOSchema ), rowsDf.collect()) } @@ -178,8 +151,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), child = rowsDf.queryExecution.sparkPlan, - ioschema = noSerdeIOSchema) - SparkPlanTest.executePlan(plan, hiveContext) + ioschema = defaultIOSchema) + SparkPlanTest.executePlan(plan, spark.sqlContext) } assert(e.getMessage.contains("Subprocess exited with status")) assert(uncaughtExceptionHandler.exception.isEmpty) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala similarity index 84% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index 381679075c0f1..1abf298a6123e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -15,22 +15,37 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.execution +package org.apache.spark.sql.execution import java.sql.{Date, Timestamp} import org.apache.spark.TestUtils -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { +class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with SharedSparkSession { + import testImplicits._ + import ScriptTransformationIOSchema._ - import spark.implicits._ + override def isHive23OrSpark: Boolean = true - override def scriptType: String = "SPARK" + override def createScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = { + SparkScriptTransformationExec( + input = input, + script = script, + output = output, + child = child, + ioschema = ioschema + ) + } test("SPARK-32106: SparkScriptTransformExec should handle different data types correctly") { assume(TestUtils.testCommandAvailable("python")) @@ -97,7 +112,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite { AttributeReference("j", StringType)(), AttributeReference("k", StringType)()), child = child, - ioschema = noSerdeIOSchema + ioschema = defaultIOSchema ), df.select( 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala similarity index 96% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala index 681eb4e255dbc..360f4658345e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.execution +package org.apache.spark.sql.execution class TestUncaughtExceptionHandler extends Thread.UncaughtExceptionHandler { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 11928fbc2ef0f..7ba1deb101a65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -21,20 +21,35 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TestUtils} -import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.execution.{ScriptTransformationIOSchema, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType -class HiveScriptTransformationSuite extends BaseScriptTransformationSuite { - override def scriptType: String = "HIVE" +class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with TestHiveSingleton { + import testImplicits._ + import ScriptTransformationIOSchema._ override def isHive23OrSpark: Boolean = HiveUtils.isHive23 - import spark.implicits._ + override def createScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = { + HiveScriptTransformationExec( + input = input, + script = script, + output = output, + child = child, + ioschema = ioschema + ) + } private val serdeIOSchema: ScriptTransformationIOSchema = { - noSerdeIOSchema.copy( + defaultIOSchema.copy( inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) ) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 920f6385f8e19..24b9d25ed94f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, Functio import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.execution.TestUncaughtExceptionHandler import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} import org.apache.spark.sql.execution.command.{FunctionsCommand, LoadDataCommand} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} From f3e05c6e1ea1e195ff2cbc9e3aa70c45cf9cc79f Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 17 Jul 2020 22:50:38 +0900 Subject: [PATCH 12/42] Fix --- .../BaseScriptTransformationExec.scala | 42 +++++- .../SparkScriptTransformationExec.scala | 38 +----- 
.../HiveScriptTransformationExec.scala | 126 +++++++++--------- 3 files changed, 104 insertions(+), 102 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 61d0360249781..e243acd7def80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.io.{BufferedReader, InputStream, OutputStream} +import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -98,6 +98,46 @@ trait BaseScriptTransformationExec extends UnaryExecNode { .map { case (data, writer) => writer(data) }) } + protected def createOutputIteratorWithoutSerde( + writerThread: BaseScriptTransformationWriterThread, + inputStream: InputStream, + proc: Process, + stderrBuffer: CircularBuffer): Iterator[InternalRow] = { + new Iterator[InternalRow] { + var curLine: String = null + val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + + override def hasNext: Boolean = { + try { + if (curLine == null) { + curLine = reader.readLine() + if (curLine == null) { + checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) + return false + } + } + true + } catch { + case NonFatal(e) => + // If this exception is due to abrupt / unclean termination of `proc`, + // then detect it and propagate a better exception message for end users + checkFailureAndPropagate(writerThread, e, proc, stderrBuffer) + + throw e + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException + } + val prevLine = curLine + curLine = reader.readLine() + processOutputWithoutSerde(prevLine, reader) + } + } + } + protected def checkFailureAndPropagate( writerThread: BaseScriptTransformationWriterThread, cause: Throwable = null, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala index 4909feae20ad5..103eaf869039d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -18,9 +18,6 @@ package org.apache.spark.sql.execution import java.io._ -import java.nio.charset.StandardCharsets - -import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -68,39 +65,8 @@ case class SparkScriptTransformationExec( hadoopConf ) - val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) - val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] { - var curLine: String = null - - override def hasNext: Boolean = { - try { - if (curLine == null) { - curLine = reader.readLine() - if (curLine == null) { - checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) - return false - } - } - true - } catch { - case NonFatal(e) => - // If this exception is due to abrupt / unclean termination of `proc`, - // then detect it and propagate a better exception message for end users - checkFailureAndPropagate(writerThread, e, proc, stderrBuffer) - - throw e - } - } - - override def next(): InternalRow 
= { - if (!hasNext) { - throw new NoSuchElementException - } - val prevLine = curLine - curLine = reader.readLine() - processOutputWithoutSerde(prevLine, reader) - } - } + val outputIterator = createOutputIteratorWithoutSerde( + writerThread, inputStream, proc, stderrBuffer) writerThread.start() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 4f17d1f40afdc..37a3789205b5d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution import java.io._ -import java.nio.charset.StandardCharsets import java.util.Properties import javax.annotation.Nullable @@ -129,61 +128,22 @@ case class HiveScriptTransformationExec( } } - override def processIterator( - inputIterator: Iterator[InternalRow], + private def createOtputIteratorWithSerde( + writerThread: BaseScriptTransformationWriterThread, + inputStream: InputStream, + proc: Process, + stderrBuffer: CircularBuffer, + outputSerde: AbstractSerDe, + outputSoi: StructObjectInspector, hadoopConf: Configuration): Iterator[InternalRow] = { - - val (outputStream, proc, inputStream, stderrBuffer) = initProc - - // This nullability is a performance optimization in order to avoid an Option.foreach() call - // inside of a loop - @Nullable val (inputSerde, inputSoi) = initInputSerDe(input).getOrElse((null, null)) - - // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null - // We will use StringBuffer to pass data, in this case, we should cast data as string too. - val finalInput = if (inputSerde == null) { - input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) - } else { - input - } - - val outputProjection = new InterpretedProjection(finalInput, child.output) - - // This new thread will consume the ScriptTransformation's input rows and write them to the - // external process. That process's output will be read by this current thread. 
- val writerThread = HiveScriptTransformationWriterThread( - inputIterator.map(outputProjection), - finalInput.map(_.dataType), - inputSerde, - inputSoi, - ioschema, - outputStream, - recordWriter, - proc, - stderrBuffer, - TaskContext.get(), - hadoopConf - ) - - // This nullability is a performance optimization in order to avoid an Option.foreach() call - // inside of a loop - @Nullable val (outputSerde, outputSoi) = { - initOutputSerDe(output).getOrElse((null, null)) - } - - val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) - val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) @Nullable val scriptOutputReader = recordReader(scriptOutputStream, hadoopConf).orNull var scriptOutputWritable: Writable = null - val reusedWritableObject: Writable = if (null != outputSerde) { - outputSerde.getSerializedClass().getConstructor().newInstance() - } else { - null - } + val reusedWritableObject = outputSerde.getSerializedClass.getConstructor().newInstance() val mutableRow = new SpecificInternalRow(output.map(_.dataType)) @transient @@ -191,15 +151,7 @@ case class HiveScriptTransformationExec( override def hasNext: Boolean = { try { - if (outputSerde == null) { - if (curLine == null) { - curLine = reader.readLine() - if (curLine == null) { - checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) - return false - } - } - } else if (scriptOutputWritable == null) { + if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject if (scriptOutputReader != null) { @@ -240,13 +192,7 @@ case class HiveScriptTransformationExec( nextRow() } - val nextRow: () => InternalRow = if (outputSerde == null) { - () => { - val prevLine = curLine - curLine = reader.readLine() - processOutputWithoutSerde(prevLine, reader) - } - } else { + val nextRow: () => InternalRow = { () => { val raw = outputSerde.deserialize(scriptOutputWritable) scriptOutputWritable = null @@ -264,6 +210,56 @@ case class HiveScriptTransformationExec( } } } + } + + override def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration): Iterator[InternalRow] = { + + val (outputStream, proc, inputStream, stderrBuffer) = initProc + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = initInputSerDe(input).getOrElse((null, null)) + + // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null + // We will use StringBuffer to pass data, in this case, we should cast data as string too. + val finalInput = if (inputSerde == null) { + input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + } else { + input + } + + val outputProjection = new InterpretedProjection(finalInput, child.output) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. 
+ val writerThread = HiveScriptTransformationWriterThread( + inputIterator.map(outputProjection), + finalInput.map(_.dataType), + inputSerde, + inputSoi, + ioschema, + outputStream, + recordWriter, + proc, + stderrBuffer, + TaskContext.get(), + hadoopConf + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + initOutputSerDe(output).getOrElse((null, null)) + } + + val outputIterator = if (outputSerde == null) { + createOutputIteratorWithoutSerde(writerThread, inputStream, proc, stderrBuffer) + } else { + createOtputIteratorWithSerde( + writerThread, inputStream, proc, stderrBuffer, outputSerde, outputSoi, hadoopConf) + } writerThread.start() From 04684a89f2080dddb7b39ad5564eb42f7b0d00b9 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sat, 18 Jul 2020 12:44:41 +0800 Subject: [PATCH 13/42] fix UT and follow comment --- .../BaseScriptTransformationExec.scala | 16 +++---- .../SparkScriptTransformationExec.scala | 4 +- .../src/test/resources/test_script.py | 0 .../HiveScriptTransformationExec.scala | 44 +++++++++---------- 4 files changed, 29 insertions(+), 35 deletions(-) rename sql/{hive => core}/src/test/resources/test_script.py (100%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index e243acd7def80..d15395dbc33dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -90,14 +90,6 @@ trait BaseScriptTransformationExec extends UnaryExecNode { inputIterator: Iterator[InternalRow], hadoopConf: Configuration): Iterator[InternalRow] - protected def processOutputWithoutSerde(prevLine: String, reader: BufferedReader): InternalRow = { - val limit = if (ioschema.schemaLess) 2 else 0 - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), limit) - .zip(fieldWriters) - .map { case (data, writer) => writer(data) }) - } - protected def createOutputIteratorWithoutSerde( writerThread: BaseScriptTransformationWriterThread, inputStream: InputStream, @@ -105,6 +97,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { stderrBuffer: CircularBuffer): Iterator[InternalRow] = { new Iterator[InternalRow] { var curLine: String = null + val splitLimit = if (ioschema.schemaLess) 2 else 0 val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) override def hasNext: Boolean = { @@ -133,7 +126,10 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } val prevLine = curLine curLine = reader.readLine() - processOutputWithoutSerde(prevLine, reader) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), splitLimit) + .zip(fieldWriters) + .map { case (data, writer) => writer(data) }) } } } @@ -176,7 +172,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { case LongType => (data: String) => converter(data.toLong) case FloatType => (data: String) => converter(data.toFloat) case DoubleType => (data: String) => converter(data.toDouble) - case dt: DecimalType => (data: String) => converter(BigDecimal(data)) + case decimal: DecimalType => (data: String) => converter(BigDecimal(data)) case DateType if conf.datetimeJava8ApiEnabled => (data: String) => 
converter(DateTimeUtils.stringToDate( UTF8String.fromString(data), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala index 103eaf869039d..1f6c52efe8715 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -65,8 +65,8 @@ case class SparkScriptTransformationExec( hadoopConf ) - val outputIterator = createOutputIteratorWithoutSerde( - writerThread, inputStream, proc, stderrBuffer) + val outputIterator = + createOutputIteratorWithoutSerde(writerThread, inputStream, proc, stderrBuffer) writerThread.start() diff --git a/sql/hive/src/test/resources/test_script.py b/sql/core/src/test/resources/test_script.py similarity index 100% rename from sql/hive/src/test/resources/test_script.py rename to sql/core/src/test/resources/test_script.py diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 37a3789205b5d..350fed8ce70a2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -55,7 +55,8 @@ case class HiveScriptTransformationExec( ioschema: ScriptTransformationIOSchema) extends BaseScriptTransformationExec with HiveInspectors { - def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { + private def initInputSerDe( + input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { ioschema.inputSerdeClass.map { serdeClass => val (columns, columnTypes) = parseAttrs(input) val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.inputSerdeProps) @@ -66,7 +67,8 @@ case class HiveScriptTransformationExec( } } - def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + private def initOutputSerDe( + output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { ioschema.outputSerdeClass.map { serdeClass => val (columns, columnTypes) = parseAttrs(output) val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.outputSerdeProps) @@ -104,7 +106,7 @@ case class HiveScriptTransformationExec( serde } - def recordReader( + private def recordReader( inputStream: InputStream, conf: Configuration): Option[RecordReader] = { ioschema.recordReaderClass.map { klass => @@ -119,7 +121,9 @@ case class HiveScriptTransformationExec( } } - def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { + private def recordWriter( + outputStream: OutputStream, + conf: Configuration): Option[RecordWriter] = { ioschema.recordWriterClass.map { klass => val instance = Utils.classForName[RecordWriter](klass).getConstructor(). 
newInstance() @@ -128,7 +132,7 @@ case class HiveScriptTransformationExec( } } - private def createOtputIteratorWithSerde( + private def createOutputIteratorWithSerde( writerThread: BaseScriptTransformationWriterThread, inputStream: InputStream, proc: Process, @@ -189,25 +193,19 @@ case class HiveScriptTransformationExec( if (!hasNext) { throw new NoSuchElementException } - nextRow() - } - - val nextRow: () => InternalRow = { - () => { - val raw = outputSerde.deserialize(scriptOutputWritable) - scriptOutputWritable = null - val dataList = outputSoi.getStructFieldsDataAsList(raw) - var i = 0 - while (i < dataList.size()) { - if (dataList.get(i) == null) { - mutableRow.setNullAt(i) - } else { - unwrappers(i)(dataList.get(i), mutableRow, i) - } - i += 1 + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + var i = 0 + while (i < dataList.size()) { + if (dataList.get(i) == null) { + mutableRow.setNullAt(i) + } else { + unwrappers(i)(dataList.get(i), mutableRow, i) } - mutableRow + i += 1 } + mutableRow } } } @@ -257,7 +255,7 @@ case class HiveScriptTransformationExec( val outputIterator = if (outputSerde == null) { createOutputIteratorWithoutSerde(writerThread, inputStream, proc, stderrBuffer) } else { - createOtputIteratorWithSerde( + createOutputIteratorWithSerde( writerThread, inputStream, proc, stderrBuffer, outputSerde, outputSoi, hadoopConf) } From 6811721cb77c9f05d51ce1d6e269400265117bf3 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sat, 18 Jul 2020 13:19:52 +0800 Subject: [PATCH 14/42] move ut and add ut for schema less --- .../BaseScriptTransformationExec.scala | 20 ++- .../BaseScriptTransformationSuite.scala | 124 +++++++++++++++++- .../SparkScriptTransformationSuite.scala | 85 +----------- 3 files changed, 139 insertions(+), 90 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index d15395dbc33dd..6b89dea39c342 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -97,9 +97,22 @@ trait BaseScriptTransformationExec extends UnaryExecNode { stderrBuffer: CircularBuffer): Iterator[InternalRow] = { new Iterator[InternalRow] { var curLine: String = null - val splitLimit = if (ioschema.schemaLess) 2 else 0 val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + val processRowWithoutSerde = if (!ioschema.schemaLess) { + prevLine: String => + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .zip(fieldWriters) + .map { case (data, writer) => writer(data) }) + } else { + prevLine: String => + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .map(CatalystTypeConverters.convertToCatalyst)) + } + + override def hasNext: Boolean = { try { if (curLine == null) { @@ -126,10 +139,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } val prevLine = curLine curLine = reader.readLine() - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), splitLimit) - .zip(fieldWriters) - .map { case (data, writer) => writer(data) }) + processRowWithoutSerde(prevLine) } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 26c08c3f513c0..1eea9a3370a69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.scalatest.Assertions._ import org.scalatest.BeforeAndAfterEach @@ -29,8 +29,10 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with BeforeAndAfterEach { @@ -140,6 +142,51 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } + test("SPARK-25990: TRANSFORM should handle schema less correctly") { + assume(TestUtils.testCommandAvailable("python")) + val scriptFilePath = getTestResourcePath("test_script.py") + + withTempView("v") { + val df = Seq( + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) + ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) + + // In Hive 1.2, the string representation of a decimal omits trailing zeroes. + // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary. 
+ val decimalToString: Column => Column = if (isHive23OrSpark) { + c => c.cast("string") + } else { + c => c.cast("decimal(1, 0)").cast("string") + } + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr), + script = s"python $scriptFilePath", + output = Seq( + AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), + child = child, + ioschema = defaultIOSchema.copy(schemaLess = true) + ), + df.select( + 'a.cast("string").as("key"), + concat_ws("\t", + 'b.cast("string"), + 'c.cast("string"), + decimalToString('d), + 'e.cast("string"))).collect()) + } + } + test("SPARK-30973: TRANSFORM should wait for the termination of the script (no serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) @@ -157,6 +204,81 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU assert(e.getMessage.contains("Subprocess exited with status")) assert(uncaughtExceptionHandler.exception.isEmpty) } + + test("SPARK-32106: TRANSFORM should handle different data types correctly") { + assume(TestUtils.testCommandAvailable("python")) + case class Struct(d: Int, str: String) + withTempView("v") { + val df = Seq( + (1, "1", 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), + new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), + Map("a" -> 1)), + (2, "2", 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), + new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), + Map("b" -> 2)), + (3, "3", 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), + new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), + Map("c" -> 3)) + ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, struct('a, 'b).as("k")) + // Note column d's data type is Decimal(38, 18) + df.createTempView("v") + + assert(spark.table("v").schema == + StructType(Seq(StructField("a", IntegerType, false), + StructField("b", StringType, true), + StructField("c", DoubleType, false), + StructField("d", ByteType, false), + StructField("e", DecimalType(38, 18), true), + StructField("f", TimestampType, true), + StructField("g", DateType, true), + StructField("h", CalendarIntervalType, true), + StructField("i", ArrayType(IntegerType, false), true), + StructField("j", MapType(StringType, IntegerType, false), true), + StructField("k", StructType( + Seq(StructField("a", IntegerType, false), + StructField("b", StringType, true))), false)))) + + // Can't support convert script output data to ArrayType/MapType/StructType now, + // return these column still as string + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + df.col("f").expr, + df.col("g").expr, + df.col("h").expr, + df.col("i").expr, + df.col("j").expr, + df.col("k").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", ByteType)(), + AttributeReference("e", DecimalType(38, 18))(), + AttributeReference("f", TimestampType)(), + AttributeReference("g", DateType)(), + AttributeReference("h", CalendarIntervalType)(), + AttributeReference("i", StringType)(), + AttributeReference("j", StringType)(), + AttributeReference("k", StringType)()), + child 
= child, + ioschema = defaultIOSchema + ), + df.select( + 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, + 'i.cast("string"), + 'j.cast("string"), + 'k.cast("string")).collect()) + } + } } case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index 1abf298a6123e..68f070a85a12d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -17,18 +17,10 @@ package org.apache.spark.sql.execution -import java.sql.{Date, Timestamp} - -import org.apache.spark.TestUtils -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} -import org.apache.spark.sql.functions.struct +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with SharedSparkSession { - import testImplicits._ - import ScriptTransformationIOSchema._ override def isHive23OrSpark: Boolean = true @@ -46,79 +38,4 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with ioschema = ioschema ) } - - test("SPARK-32106: SparkScriptTransformExec should handle different data types correctly") { - assume(TestUtils.testCommandAvailable("python")) - case class Struct(d: Int, str: String) - withTempView("v") { - val df = Seq( - (1, "1", 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), - new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), - Map("a" -> 1)), - (2, "2", 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), - new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), - Map("b" -> 2)), - (3, "3", 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), - new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), - Map("c" -> 3)) - ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, struct('a, 'b).as("k")) - // Note column d's data type is Decimal(38, 18) - df.createTempView("v") - - assert(spark.table("v").schema == - StructType(Seq(StructField("a", IntegerType, false), - StructField("b", StringType, true), - StructField("c", DoubleType, false), - StructField("d", ByteType, false), - StructField("e", DecimalType(38, 18), true), - StructField("f", TimestampType, true), - StructField("g", DateType, true), - StructField("h", CalendarIntervalType, true), - StructField("i", ArrayType(IntegerType, false), true), - StructField("j", MapType(StringType, IntegerType, false), true), - StructField("k", StructType( - Seq(StructField("a", IntegerType, false), - StructField("b", StringType, true))), false)))) - - // Can't support convert script output data to ArrayType/MapType/StructType now, - // return these column still as string - checkAnswer( - df, - (child: SparkPlan) => createScriptTransformationExec( - input = Seq( - df.col("a").expr, - df.col("b").expr, - df.col("c").expr, - df.col("d").expr, - df.col("e").expr, - df.col("f").expr, - df.col("g").expr, - df.col("h").expr, - df.col("i").expr, - df.col("j").expr, - df.col("k").expr), - script = "cat", - output = Seq( - AttributeReference("a", IntegerType)(), - 
AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", ByteType)(), - AttributeReference("e", DecimalType(38, 18))(), - AttributeReference("f", TimestampType)(), - AttributeReference("g", DateType)(), - AttributeReference("h", CalendarIntervalType)(), - AttributeReference("i", StringType)(), - AttributeReference("j", StringType)(), - AttributeReference("k", StringType)()), - child = child, - ioschema = defaultIOSchema - ), - df.select( - 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, - 'i.cast("string"), - 'j.cast("string"), - 'k.cast("string")).collect()) - } - } } From fc96e1fa2d224616b31813e7ef827f08ef4b9967 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sat, 18 Jul 2020 17:22:10 +0800 Subject: [PATCH 15/42] follow comment --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 1 - .../spark/sql/execution/BaseScriptTransformationSuite.scala | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 6b89dea39c342..f4e0722e0be95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -112,7 +112,6 @@ trait BaseScriptTransformationExec extends UnaryExecNode { .map(CatalystTypeConverters.convertToCatalyst)) } - override def hasNext: Boolean = { try { if (curLine == null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 1eea9a3370a69..eaf4d881abd84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -205,7 +205,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-32106: TRANSFORM should handle different data types correctly") { + test("SPARK-32106: TRANSFORM should support more data types (no serde)") { assume(TestUtils.testCommandAvailable("python")) case class Struct(d: Int, str: String) withTempView("v") { From ed901afc3ef115af06fef6047a21b472b7a867e5 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sat, 18 Jul 2020 17:29:03 +0800 Subject: [PATCH 16/42] Update BaseScriptTransformationExec.scala --- .../sql/execution/BaseScriptTransformationExec.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index f4e0722e0be95..e9bfbd2101a9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -182,21 +182,11 @@ trait BaseScriptTransformationExec extends UnaryExecNode { case FloatType => (data: String) => converter(data.toFloat) case DoubleType => (data: String) => converter(data.toDouble) case decimal: DecimalType => (data: String) => converter(BigDecimal(data)) - case DateType if conf.datetimeJava8ApiEnabled => (data: String) => - converter(DateTimeUtils.stringToDate( - 
UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.daysToLocalDate).orNull) case DateType => (data: String) => converter(DateTimeUtils.stringToDate( UTF8String.fromString(data), DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) .map(DateTimeUtils.toJavaDate).orNull) - case TimestampType if conf.datetimeJava8ApiEnabled => (data: String) => - converter(DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.microsToInstant).orNull) case TimestampType => (data: String) => converter(DateTimeUtils.stringToTimestamp( UTF8String.fromString(data), From a6f1e7d88de983031aa5d3235a7ac497c70946a1 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sat, 18 Jul 2020 19:28:28 +0800 Subject: [PATCH 17/42] catch data convert exception --- .../BaseScriptTransformationExec.scala | 52 +++++++++++-------- .../BaseScriptTransformationSuite.scala | 19 ++++++- 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index e9bfbd2101a9f..5aa4cc6b9614d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -174,29 +174,39 @@ trait BaseScriptTransformationExec extends UnaryExecNode { private lazy val fieldWriters: Seq[String => Any] = output.map { attr => val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) attr.dataType match { - case StringType => (data: String) => converter(data) - case ByteType => (data: String) => converter(data.toByte) - case IntegerType => (data: String) => converter(data.toInt) - case ShortType => (data: String) => converter(data.toShort) - case LongType => (data: String) => converter(data.toLong) - case FloatType => (data: String) => converter(data.toFloat) - case DoubleType => (data: String) => converter(data.toDouble) - case decimal: DecimalType => (data: String) => converter(BigDecimal(data)) - case DateType => (data: String) => - converter(DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaDate).orNull) - case TimestampType => (data: String) => - converter(DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaTimestamp).orNull) - case CalendarIntervalType => (data: String) => - converter(IntervalUtils.stringToInterval(UTF8String.fromString(data))) - case dataType: DataType => (data: String) => converter(data) + case StringType => wrapperConvertException(data => data, converter) + case ByteType => wrapperConvertException(data => data.toByte, converter) + case IntegerType => wrapperConvertException(data => data.toInt, converter) + case ShortType => wrapperConvertException(data => data.toShort, converter) + case LongType => wrapperConvertException(data => data.toLong, converter) + case FloatType => wrapperConvertException(data => data.toFloat, converter) + case DoubleType => wrapperConvertException(data => data.toDouble, converter) + case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + 
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull, converter) + case CalendarIntervalType => wrapperConvertException( + data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), + converter) + case _: DataType => wrapperConvertException(data => data, converter) } } + + // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null + val wrapperConvertException: (String => Any, Any => Any) => String => Any = + (f: String => Any, converter: Any => Any) => + (data: String) => converter { + try { + f(data) + } catch { + case _: Exception => null + } + } } abstract class BaseScriptTransformationWriterThread extends Thread with Logging { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index eaf4d881abd84..6f88e23847848 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TaskContext, TestUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Column +import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.physical.Partitioning @@ -279,6 +279,23 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU 'k.cast("string")).collect()) } } + + test("SPARK-32106: TRANSFORM shoud return null when return string incompatible(no serde)") { + checkAnswer( + sql( + """ + |SELECT TRANSFORM(a, b, c) + |USING 'cat' as (a int, b int , c int) + |FROM ( + |SELECT + |1 AS a, + |"a" AS b, + |CAST(2000 AS timestamp) AS c + |) tmp + """.stripMargin), + identity, + Row(1, null, null) :: Nil) + } } case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { From e367c0544298f6639e8898029eab5e29ea1f91ea Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sat, 18 Jul 2020 22:07:20 +0800 Subject: [PATCH 18/42] add UTD support --- .../BaseScriptTransformationExec.scala | 2 + .../BaseScriptTransformationSuite.scala | 93 ++++++++++++++++--- 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 5aa4cc6b9614d..ddd7a5ec0fb72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -193,6 +193,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { case CalendarIntervalType => wrapperConvertException( data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), converter) + case udt: UserDefinedType[_] => + wrapperConvertException(data => udt.deserialize(data), converter) case _: DataType => wrapperConvertException(data => data, converter) } } diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 6f88e23847848..914dc3b61d9a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.execution import java.sql.{Date, Timestamp} +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ import org.scalatest.Assertions._ import org.scalatest.BeforeAndAfterEach import org.scalatest.exceptions.TestFailedException @@ -27,7 +30,7 @@ import org.apache.spark.{SparkException, TaskContext, TestUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils @@ -212,20 +215,21 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU val df = Seq( (1, "1", 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), - Map("a" -> 1)), + Map("a" -> 1), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), (2, "2", 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), - Map("b" -> 2)), + Map("b" -> 2), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), (3, "3", 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), - Map("c" -> 3)) - ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, struct('a, 'b).as("k")) + Map("c" -> 3), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)) + ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l") + .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, struct('a, 'b).as("m")) // Note column d's data type is Decimal(38, 18) df.createTempView("v") assert(spark.table("v").schema == - StructType(Seq(StructField("a", IntegerType, false), + StructType(Seq( + StructField("a", IntegerType, false), StructField("b", StringType, true), StructField("c", DoubleType, false), StructField("d", ByteType, false), @@ -235,12 +239,17 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU StructField("h", CalendarIntervalType, true), StructField("i", ArrayType(IntegerType, false), true), StructField("j", MapType(StringType, IntegerType, false), true), - StructField("k", StructType( + StructField("k", new TestUDT.MyDenseVectorUDT, true), + StructField("l", new SimpleTupleUDT, true), + StructField("m", StructType( Seq(StructField("a", IntegerType, false), StructField("b", StringType, true))), false)))) // Can't support convert script output data to ArrayType/MapType/StructType now, - // return these column still as string + // return these column still as string. 
+ // For UserDefinedType, if user defined deserialize method to support convert string + // to UserType like [[SimpleTupleUDT]], we can support convert to this UDT, else we + // will return null value as column. checkAnswer( df, (child: SparkPlan) => createScriptTransformationExec( @@ -255,7 +264,9 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU df.col("h").expr, df.col("i").expr, df.col("j").expr, - df.col("k").expr), + df.col("k").expr, + df.col("l").expr, + df.col("m").expr), script = "cat", output = Seq( AttributeReference("a", IntegerType)(), @@ -268,7 +279,9 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU AttributeReference("h", CalendarIntervalType)(), AttributeReference("i", StringType)(), AttributeReference("j", StringType)(), - AttributeReference("k", StringType)()), + AttributeReference("k", StringType)(), + AttributeReference("l", new SimpleTupleUDT)(), + AttributeReference("m", StringType)()), child = child, ioschema = defaultIOSchema ), @@ -276,7 +289,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i.cast("string"), 'j.cast("string"), - 'k.cast("string")).collect()) + 'k.cast("string"), + 'l, 'm.cast("string")).collect()) } } @@ -311,3 +325,58 @@ case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { override def outputPartitioning: Partitioning = child.outputPartitioning } + +@SQLUserDefinedType(udt = classOf[SimpleTupleUDT]) +private class SimpleTuple(val id: Int, val size: Long) extends Serializable { + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other match { + case v: SimpleTuple => this.id == v.id && this.size == v.size + case _ => false + } + + override def toString: String = + compact(render( + ("id" -> id) ~ + ("size" -> size) + )) +} + +private class SimpleTupleUDT extends UserDefinedType[SimpleTuple] { + + override def sqlType: DataType = StructType( + StructField("id", IntegerType, false) :: + StructField("size", LongType, false) :: + Nil) + + override def serialize(sql: SimpleTuple): Any = { + val row = new GenericInternalRow(2) + row.setInt(0, sql.id) + row.setLong(1, sql.size) + row + } + + override def deserialize(datum: Any): SimpleTuple = { + datum match { + case str: String => + implicit val format = DefaultFormats + val json = parse(str) + new SimpleTuple((json \ "id").extract[Int], (json \ "size").extract[Long]) + case data: InternalRow if data.numFields == 2 => + new SimpleTuple(data.getInt(0), data.getLong(1)) + case _ => null + } + } + + override def userClass: Class[SimpleTuple] = classOf[SimpleTuple] + + override def asNullable: SimpleTupleUDT = this + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = { + other.isInstanceOf[SimpleTupleUDT] + } +} + From e74d04c6b20f1f0becf6a0a41d9afc380f0e6f4d Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sat, 18 Jul 2020 22:44:18 +0800 Subject: [PATCH 19/42] add test --- .../resources/sql-tests/inputs/transform.sql | 47 +++++++++ .../sql-tests/results/transform.sql.out | 97 +++++++++++++++++++ 2 files changed, 144 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/transform.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/transform.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql new file mode 100644 index 
0000000000000..6c6d5fc51583b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -0,0 +1,47 @@ +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +('a'), ('b'), ('v') +as t1(a); + +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES +(1, '1', 1.0, Decimal(1.0), timestamp(1)), +(2, '2', 2.0, Decimal(2.0), timestamp(2)), +(2, '2', 3.0, Decimal(3.0), timestamp(3)) +as t2(a,b,c,d,e); + +SELECT TRANSFORM(a) +USING 'cat' AS (a) +FROM t1; + + +-- with non-exist command +SELECT TRANSFORM(a) +USING 'some_non_existent_command' AS (a) +FROM t1; + +-- with non-exist file +SELECT TRANSFORM(a) +USING 'python some_non_existent_file' AS (a) +FROM t1; + + +-- support different data type +SELECT TRANSFORM(a, b, c, d, e) +USING 'CAT' AS (a, b, c, d, e) +FROM t2; + + +-- handle schema less +SELECT TRANSFORM(a, b) +USING 'CAT' +FROM t2; + +-- return null when return string incompatible(no serde) +SELECT TRANSFORM(a, b, c) +USING 'cat' as (a int, b int , c int) +FROM ( +SELECT +1 AS a, +"a" AS b, +CAST(2000 AS timestamp) AS c +) tmp; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out new file mode 100644 index 0000000000000..2ddf37e8f1edf --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -0,0 +1,97 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +('a'), ('b'), ('v') +as t1(a) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES +(1, '1', 1.0, Decimal(1.0), timestamp(1)), +(2, '2', 2.0, Decimal(2.0), timestamp(2)), +(2, '2', 3.0, Decimal(3.0), timestamp(3)) +as t2(a,b,c,d,e) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT TRANSFORM(a) +USING 'cat' AS (a) +FROM t1 +-- !query schema +struct +-- !query output +a +b +v + + +-- !query +SELECT TRANSFORM(a) +USING 'some_non_existent_command' AS (a) +FROM t1 +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkException +Subprocess exited with status 127. Error: /bin/bash: some_non_existent_command: command not found + + +-- !query +SELECT TRANSFORM(a) +USING 'python some_non_existent_file' AS (a) +FROM t1 +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkException +Subprocess exited with status 2. 
Error: python: can't open file 'some_non_existent_file': [Errno 2] No such file or directory + + +-- !query +SELECT TRANSFORM(a, b, c, d, e) +USING 'CAT' AS (a, b, c, d, e) +FROM t2 +-- !query schema +struct +-- !query output +1 1 1.0 1 1969-12-31 16:00:01 +2 2 2.0 2 1969-12-31 16:00:02 +2 2 3.0 3 1969-12-31 16:00:03 + + +-- !query +SELECT TRANSFORM(a, b) +USING 'CAT' +FROM t2 +-- !query schema +struct +-- !query output +1 1 +2 2 +2 2 + + +-- !query +SELECT TRANSFORM(a, b, c) +USING 'cat' as (a int, b int , c int) +FROM ( +SELECT +1 AS a, +"a" AS b, +CAST(2000 AS timestamp) AS c +) tmp +-- !query schema +struct +-- !query output +1 NULL NULL From 4ef4d76bfd0e044dc4d5a0a9d674770a35ede408 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Sun, 19 Jul 2020 13:00:15 +0800 Subject: [PATCH 20/42] add data type --- .../BaseScriptTransformationExec.scala | 2 + .../resources/sql-tests/inputs/transform.sql | 24 ++++++------ .../sql-tests/results/transform.sql.out | 38 ++++++++++--------- .../BaseScriptTransformationSuite.scala | 18 ++++++--- 4 files changed, 48 insertions(+), 34 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index ddd7a5ec0fb72..42da37c160cfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -175,7 +175,9 @@ trait BaseScriptTransformationExec extends UnaryExecNode { val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) attr.dataType match { case StringType => wrapperConvertException(data => data, converter) + case BooleanType => wrapperConvertException(data => data.toBoolean, converter) case ByteType => wrapperConvertException(data => data.toByte, converter) + case BinaryType => wrapperConvertException(data => data.getBytes, converter) case IntegerType => wrapperConvertException(data => data.toInt, converter) case ShortType => wrapperConvertException(data => data.toShort, converter) case LongType => wrapperConvertException(data => data.toLong, converter) diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 6c6d5fc51583b..dc30dd43b36ea 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -4,10 +4,10 @@ CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES as t1(a); CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES -(1, '1', 1.0, Decimal(1.0), timestamp(1)), -(2, '2', 2.0, Decimal(2.0), timestamp(2)), -(2, '2', 3.0, Decimal(3.0), timestamp(3)) -as t2(a,b,c,d,e); +('1', true, unhex('537061726B2053514C'), tinyint(1), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp(1), current_date), +('2', false, unhex('537061726B2053514C'), tinyint(2), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp(2), current_date), +('3', true, unhex('537061726B2053514C'), tinyint(3), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp(3), current_date) +as t2(a,b,c,d,e,f,g,h,i,j); SELECT TRANSFORM(a) USING 'cat' AS (a) @@ -26,9 +26,11 @@ FROM t1; -- support different data type -SELECT TRANSFORM(a, b, c, d, e) -USING 'CAT' AS (a, b, c, d, e) -FROM t2; +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j FROM ( + SELECT 
TRANSFORM(a, b, c, d, e, f, g, h, i, j) + USING 'CAT' AS (a string, b boolean, c binary, d tinyint, e long, f float, g double, h decimal(38, 18), i timestamp, j date) + FROM t2 +) tmp; -- handle schema less @@ -40,8 +42,8 @@ FROM t2; SELECT TRANSFORM(a, b, c) USING 'cat' as (a int, b int , c int) FROM ( -SELECT -1 AS a, -"a" AS b, -CAST(2000 AS timestamp) AS c + SELECT + 1 AS a, + "a" AS b, + CAST(2000 AS timestamp) AS c ) tmp; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 2ddf37e8f1edf..b86d40d27cc6a 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -14,10 +14,10 @@ struct<> -- !query CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES -(1, '1', 1.0, Decimal(1.0), timestamp(1)), -(2, '2', 2.0, Decimal(2.0), timestamp(2)), -(2, '2', 3.0, Decimal(3.0), timestamp(3)) -as t2(a,b,c,d,e) +('1', true, unhex('537061726B2053514C'), tinyint(1), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp(1), current_date), +('2', false, unhex('537061726B2053514C'), tinyint(2), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp(2), current_date), +('3', true, unhex('537061726B2053514C'), tinyint(3), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp(3), current_date) +as t2(a,b,c,d,e,f,g,h,i,j) -- !query schema struct<> -- !query output @@ -59,15 +59,17 @@ Subprocess exited with status 2. Error: python: can't open file 'some_non_existe -- !query -SELECT TRANSFORM(a, b, c, d, e) -USING 'CAT' AS (a, b, c, d, e) -FROM t2 +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j) + USING 'CAT' AS (a string, b boolean, c binary, d tinyint, e long, f float, g double, h decimal(38, 18), i timestamp, j date) + FROM t2 +) tmp -- !query schema -struct +struct -- !query output -1 1 1.0 1 1969-12-31 16:00:01 -2 2 2.0 2 1969-12-31 16:00:02 -2 2 3.0 3 1969-12-31 16:00:03 +1 true Spark SQL 1 3 1.0 1.0 1.000000000000000000 1969-12-31 16:00:01 2020-07-18 +2 false Spark SQL 2 2 2.0 2.0 2.000000000000000000 1969-12-31 16:00:02 2020-07-18 +3 true Spark SQL 3 3 3.0 3.0 3.000000000000000000 1969-12-31 16:00:03 2020-07-18 -- !query @@ -77,19 +79,19 @@ FROM t2 -- !query schema struct -- !query output -1 1 -2 2 -2 2 +1 true +2 false +3 true -- !query SELECT TRANSFORM(a, b, c) USING 'cat' as (a int, b int , c int) FROM ( -SELECT -1 AS a, -"a" AS b, -CAST(2000 AS timestamp) AS c + SELECT + 1 AS a, + "a" AS b, + CAST(2000 AS timestamp) AS c ) tmp -- !query schema struct diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 914dc3b61d9a1..296486ad77a6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -223,7 +223,9 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), Map("c" -> 3), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)) ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, struct('a, 
'b).as("m")) + .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, + struct('a, 'b).as("m"), unhex('a).as("n"), lit(true).as("o") + ) // Note column d's data type is Decimal(38, 18) df.createTempView("v") @@ -243,7 +245,9 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU StructField("l", new SimpleTupleUDT, true), StructField("m", StructType( Seq(StructField("a", IntegerType, false), - StructField("b", StringType, true))), false)))) + StructField("b", StringType, true))), false), + StructField("n", BinaryType, true), + StructField("o", BooleanType, false)))) // Can't support convert script output data to ArrayType/MapType/StructType now, // return these column still as string. @@ -266,7 +270,9 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU df.col("j").expr, df.col("k").expr, df.col("l").expr, - df.col("m").expr), + df.col("m").expr, + df.col("n").expr, + df.col("o").expr), script = "cat", output = Seq( AttributeReference("a", IntegerType)(), @@ -281,7 +287,9 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU AttributeReference("j", StringType)(), AttributeReference("k", StringType)(), AttributeReference("l", new SimpleTupleUDT)(), - AttributeReference("m", StringType)()), + AttributeReference("m", StringType)(), + AttributeReference("n", BinaryType)(), + AttributeReference("o", BooleanType)()), child = child, ioschema = defaultIOSchema ), @@ -290,7 +298,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU 'i.cast("string"), 'j.cast("string"), 'k.cast("string"), - 'l, 'm.cast("string")).collect()) + 'l, 'm.cast("string"), 'n, 'o).collect()) } } From 22d223c92435afca043d86d1f02041da58d79679 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 20 Jul 2020 06:51:46 +0800 Subject: [PATCH 21/42] fix ut --- sql/core/src/test/resources/sql-tests/inputs/transform.sql | 4 ++-- .../src/test/resources/sql-tests/results/transform.sql.out | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index dc30dd43b36ea..5ee355601addd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -28,14 +28,14 @@ FROM t1; -- support different data type SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j FROM ( SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j) - USING 'CAT' AS (a string, b boolean, c binary, d tinyint, e long, f float, g double, h decimal(38, 18), i timestamp, j date) + USING 'cat' AS (a string, b boolean, c binary, d tinyint, e long, f float, g double, h decimal(38, 18), i timestamp, j date) FROM t2 ) tmp; -- handle schema less SELECT TRANSFORM(a, b) -USING 'CAT' +USING 'cat' FROM t2; -- return null when return string incompatible(no serde) diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index b86d40d27cc6a..57226bdd2b7b4 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -61,7 +61,7 @@ Subprocess exited with status 2. 
Error: python: can't open file 'some_non_existe -- !query SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j FROM ( SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j) - USING 'CAT' AS (a string, b boolean, c binary, d tinyint, e long, f float, g double, h decimal(38, 18), i timestamp, j date) + USING 'cat' AS (a string, b boolean, c binary, d tinyint, e long, f float, g double, h decimal(38, 18), i timestamp, j date) FROM t2 ) tmp -- !query schema @@ -74,7 +74,7 @@ struct From 72b215558b5d3e326ebe2416367a9d33455f9d58 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 20 Jul 2020 11:43:19 +0800 Subject: [PATCH 22/42] added UT --- .../BaseScriptTransformationExec.scala | 15 ++- .../SparkScriptTransformationExec.scala | 6 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 3 +- .../BaseScriptTransformationSuite.scala | 27 +---- .../HiveScriptTransformationExec.scala | 2 +- .../HiveScriptTransformationSuite.scala | 100 +++++++++++++++++- 6 files changed, 116 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 42da37c160cfa..8866ec74ec535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -30,7 +30,7 @@ import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} @@ -46,6 +46,10 @@ trait BaseScriptTransformationExec extends UnaryExecNode { def child: SparkPlan def ioschema: ScriptTransformationIOSchema + protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { + input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + } + override def producedAttributes: AttributeSet = outputSet -- inputSet override def outputPartitioning: Partitioning = child.outputPartitioning @@ -99,16 +103,17 @@ trait BaseScriptTransformationExec extends UnaryExecNode { var curLine: String = null val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + val outputRowFormat = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD") val processRowWithoutSerde = if (!ioschema.schemaLess) { prevLine: String => new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + prevLine.split(outputRowFormat) .zip(fieldWriters) .map { case (data, writer) => writer(data) }) } else { prevLine: String => new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + prevLine.split(outputRowFormat, 2) .map(CatalystTypeConverters.convertToCatalyst)) } @@ -197,12 +202,12 @@ trait BaseScriptTransformationExec extends UnaryExecNode { converter) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) - case _: DataType => 
wrapperConvertException(data => data, converter) + case _ => wrapperConvertException(data => data, converter) } } // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null - val wrapperConvertException: (String => Any, Any => Any) => String => Any = + private val wrapperConvertException: (String => Any, Any => Any) => String => Any = (f: String => Any, converter: Any => Any) => (data: String) => converter { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala index 1f6c52efe8715..b87c20e6a5656 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -48,15 +48,13 @@ case class SparkScriptTransformationExec( val (outputStream, proc, inputStream, stderrBuffer) = initProc - val finalInput = input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) - - val outputProjection = new InterpretedProjection(finalInput, child.output) + val outputProjection = new InterpretedProjection(inputExpressionsWithoutSerde, child.output) // This new thread will consume the ScriptTransformation's input rows and write them to the // external process. That process's output will be read by this current thread. val writerThread = SparkScriptTransformationWriterThread( inputIterator.map(outputProjection), - finalInput.map(_.dataType), + inputExpressionsWithoutSerde.map(_.dataType), ioschema, outputStream, proc, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index f0522dfeafaac..36d7eeef44868 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -23,7 +23,7 @@ import java.util.Locale import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, TestUtils} import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.SQLHelper @@ -258,6 +258,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER") } + assume(TestUtils.testCommandAvailable("/bin/bash")) val input = fileToString(new File(testCase.inputFile)) val (comments, code) = splitCommentsAndCodes(input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 296486ad77a6c..a4cc44e7c774f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -210,7 +210,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU test("SPARK-32106: TRANSFORM should support more data types (no serde)") { assume(TestUtils.testCommandAvailable("python")) - case class Struct(d: Int, str: String) withTempView("v") { val df = Seq( (1, "1", 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), @@ 
-225,29 +224,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l") .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, struct('a, 'b).as("m"), unhex('a).as("n"), lit(true).as("o") - ) - // Note column d's data type is Decimal(38, 18) - df.createTempView("v") - - assert(spark.table("v").schema == - StructType(Seq( - StructField("a", IntegerType, false), - StructField("b", StringType, true), - StructField("c", DoubleType, false), - StructField("d", ByteType, false), - StructField("e", DecimalType(38, 18), true), - StructField("f", TimestampType, true), - StructField("g", DateType, true), - StructField("h", CalendarIntervalType, true), - StructField("i", ArrayType(IntegerType, false), true), - StructField("j", MapType(StringType, IntegerType, false), true), - StructField("k", new TestUDT.MyDenseVectorUDT, true), - StructField("l", new SimpleTupleUDT, true), - StructField("m", StructType( - Seq(StructField("a", IntegerType, false), - StructField("b", StringType, true))), false), - StructField("n", BinaryType, true), - StructField("o", BooleanType, false)))) + ) // Note column d's data type is Decimal(38, 18) // Can't support convert script output data to ArrayType/MapType/StructType now, // return these column still as string. @@ -302,7 +279,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } - test("SPARK-32106: TRANSFORM shoud return null when return string incompatible(no serde)") { + test("SPARK-32106: TRANSFORM should return null when return string incompatible(no serde)") { checkAnswer( sql( """ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 350fed8ce70a2..69b5b493394be 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -223,7 +223,7 @@ case class HiveScriptTransformationExec( // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null // We will use StringBuffer to pass data, in this case, we should cast data as string too. 
val finalInput = if (inputSerde == null) { - input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + inputExpressionsWithoutSerde } else { input } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 7ba1deb101a65..f557b7683c617 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.sql.hive.execution +import java.sql.Timestamp + import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with TestHiveSingleton { import testImplicits._ + import ScriptTransformationIOSchema._ override def isHive23OrSpark: Boolean = HiveUtils.isHive23 @@ -149,5 +154,98 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T assert(e.getMessage.contains("Subprocess exited with status")) assert(uncaughtExceptionHandler.exception.isEmpty) } + + test("SPARK-25990: TRANSFORM should handle schema less correctly (with hive serde)") { + assume(TestUtils.testCommandAvailable("python")) + val scriptFilePath = getTestResourcePath("test_script.py") + + withTempView("v") { + val df = Seq( + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) + ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) + df.createTempView("v") + + val query = sql( + s""" + |SELECT TRANSFORM(a, b, c, d, e) + |USING 'python ${scriptFilePath}' + |FROM v + """.stripMargin) + + // In hive default serde mode, if we don't define output schema, it will choose first + // two column as output schema (key: String, value: String) + checkAnswer( + query, + identity, + df.select( + 'a.cast("string").as("key"), + 'b.cast("string").as("value")).collect()) + } + } + + test("SPARK-32106: TRANSFORM support complex data types as input and ouput type (hive serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (1, "1", Array(0, 1, 2), Map("a" -> 1)), + (2, "2", Array(3, 4, 5), Map("b" -> 2))) + .toDF("a", "b", "c", "d") + .select('a, 'b, 'c, 'd, struct('a, 'b).as("e")) + df.createTempView("v") + + // Hive serde support ArrayType/MapType/StructType as input and output data type + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("c").expr, + df.col("d").expr, + df.col("e").expr), + script = "cat", + output = Seq( + AttributeReference("c", ArrayType(IntegerType))(), + AttributeReference("d", MapType(StringType, IntegerType))(), + AttributeReference("e", StructType( + Seq( + StructField("col1", IntegerType, false), + StructField("col2", StringType, true))))()), + child 
= child, + ioschema = serdeIOSchema + ), + df.select('c, 'd, 'e).collect()) + } + } + + test("SPARK-32106: TRANSFORM don't support CalenderIntervalType/UserDefinedType (hive serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (1, new CalendarInterval(7, 1, 1000), new TestUDT.MyDenseVector(Array(1, 2, 3))), + (1, new CalendarInterval(7, 1, 1000), new TestUDT.MyDenseVector(Array(1, 2, 3)))) + .toDF("a", "b", "c") + df.createTempView("v") + + val e1 = intercept[Exception] { + sql( + """ + |SELECT TRANSFORM(a, b) USING 'cat' AS (a, b) + |FROM v + """.stripMargin).collect() + } + assert(e1.getMessage.contains("scala.MatchError: CalendarIntervalType")) + + val e2 = intercept[Exception] { + sql( + """ + |SELECT TRANSFORM(a, c) USING 'cat' AS (a, c) + |FROM v + """.stripMargin).collect() + } + assert(e2.getMessage.contains( + "scala.MatchError: org.apache.spark.sql.types.TestUDT$MyDenseVectorUDT")) + } + } } From a3628ac576ef9fbe06e87ad4ff36043897e0056a Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 20 Jul 2020 11:56:30 +0800 Subject: [PATCH 23/42] update --- .../resources/sql-tests/inputs/transform.sql | 53 ++++++++--- .../sql-tests/results/transform.sql.out | 95 +++++++++++++------ 2 files changed, 106 insertions(+), 42 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 5ee355601addd..196341c26bc9d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -1,13 +1,9 @@ -- Test data. CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES -('a'), ('b'), ('v') -as t1(a); - -CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES -('1', true, unhex('537061726B2053514C'), tinyint(1), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp(1), current_date), -('2', false, unhex('537061726B2053514C'), tinyint(2), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp(2), current_date), -('3', true, unhex('537061726B2053514C'), tinyint(3), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp(3), current_date) -as t2(a,b,c,d,e,f,g,h,i,j); +('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), +('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), +('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) +as t1(a, b, c, d, e, f, g, h, i, j, k); SELECT TRANSFORM(a) USING 'cat' AS (a) @@ -26,19 +22,30 @@ FROM t1; -- support different data type -SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j FROM ( - SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j) - USING 'cat' AS (a string, b boolean, c binary, d tinyint, e long, f float, g double, h decimal(38, 18), i timestamp, j date) - FROM t2 +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k) + USING 'cat' AS ( + a string, + b boolean, + c binary, + d tinyint, + e smallint, + f long, + g float, + h double, + i decimal(38, 18), + j timestamp, + k date) + FROM t1 ) tmp; -- handle schema less SELECT 
TRANSFORM(a, b) USING 'cat' -FROM t2; +FROM t1; --- return null when return string incompatible(no serde) +-- return null when return string incompatible (no serde) SELECT TRANSFORM(a, b, c) USING 'cat' as (a int, b int , c int) FROM ( @@ -46,4 +53,20 @@ FROM ( 1 AS a, "a" AS b, CAST(2000 AS timestamp) AS c -) tmp; \ No newline at end of file +) tmp; + + +-- transform can't run with aggregation +SELECT TRANSFORM(b, max(a), sum(f)) +USING 'cat' AS (a, b) +FROM t1 +GROUP BY b; + +-- transform use MAP +MAP a, b USING 'cat' AS (a, b) FROM t1; + + +-- transform use REDUCE +REDUCE a, b USING 'cat' AS (a, b) FROM t1; + + diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 57226bdd2b7b4..8e35efdf3fd2a 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -1,23 +1,13 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 10 -- !query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES -('a'), ('b'), ('v') -as t1(a) --- !query schema -struct<> --- !query output - - - --- !query -CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES -('1', true, unhex('537061726B2053514C'), tinyint(1), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp(1), current_date), -('2', false, unhex('537061726B2053514C'), tinyint(2), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp(2), current_date), -('3', true, unhex('537061726B2053514C'), tinyint(3), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp(3), current_date) -as t2(a,b,c,d,e,f,g,h,i,j) +('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), +('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), +('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) +as t1(a, b, c, d, e, f, g, h, i, j, k) -- !query schema struct<> -- !query output @@ -31,9 +21,9 @@ FROM t1 -- !query schema struct -- !query output -a -b -v +1 +2 +3 -- !query @@ -59,23 +49,34 @@ Subprocess exited with status 2. 
Error: python: can't open file 'some_non_existe -- !query -SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j FROM ( - SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j) - USING 'cat' AS (a string, b boolean, c binary, d tinyint, e long, f float, g double, h decimal(38, 18), i timestamp, j date) - FROM t2 +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k) + USING 'cat' AS ( + a string, + b boolean, + c binary, + d tinyint, + e smallint, + f long, + g float, + h double, + i decimal(38, 18), + j timestamp, + k date) + FROM t1 ) tmp -- !query schema -struct +struct -- !query output -1 true Spark SQL 1 3 1.0 1.0 1.000000000000000000 1969-12-31 16:00:01 2020-07-18 -2 false Spark SQL 2 2 2.0 2.0 2.000000000000000000 1969-12-31 16:00:02 2020-07-18 -3 true Spark SQL 3 3 3.0 3.0 3.000000000000000000 1969-12-31 16:00:03 2020-07-18 +1 true Spark SQL 1 100 3 1.0 1.0 1.000000000000000000 1997-01-02 00:00:00 2000-04-01 +2 false Spark SQL 2 200 2 2.0 2.0 2.000000000000000000 1997-01-02 03:04:05 2000-04-02 +3 true Spark SQL 3 300 3 3.0 3.0 3.000000000000000000 1997-02-10 17:32:01 2000-04-03 -- !query SELECT TRANSFORM(a, b) USING 'cat' -FROM t2 +FROM t1 -- !query schema struct -- !query output @@ -97,3 +98,43 @@ FROM ( struct -- !query output 1 NULL NULL + + +-- !query +SELECT TRANSFORM(b, max(a), sum(f)) +USING 'cat' AS (a, b) +FROM t1 +GROUP BY b +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'GROUP' expecting {, ';'}(line 4, pos 0) + +== SQL == +SELECT TRANSFORM(b, max(a), sum(f)) +USING 'cat' AS (a, b) +FROM t1 +GROUP BY b +^^^ + + +-- !query +MAP a, b USING 'cat' AS (a, b) FROM t1 +-- !query schema +struct +-- !query output +1 true +2 false +3 true + + +-- !query +REDUCE a, b USING 'cat' AS (a, b) FROM t1 +-- !query schema +struct +-- !query output +1 true +2 false +3 true From e16c13620032f8062cb0fcd6ecad9836c97febf7 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 20 Jul 2020 12:04:38 +0800 Subject: [PATCH 24/42] update title --- .../spark/sql/execution/BaseScriptTransformationSuite.scala | 6 +++--- .../sql/hive/execution/HiveScriptTransformationSuite.scala | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index a4cc44e7c774f..6e4362eba025f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -145,7 +145,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } - test("SPARK-25990: TRANSFORM should handle schema less correctly") { + test("SPARK-25990: TRANSFORM should handle schema less correctly (no serde)") { assume(TestUtils.testCommandAvailable("python")) val scriptFilePath = getTestResourcePath("test_script.py") @@ -208,7 +208,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-32106: TRANSFORM should support more data types (no serde)") { + test("SPARK-32106: TRANSFORM should support all data types as input (no serde)") { assume(TestUtils.testCommandAvailable("python")) withTempView("v") { val df = Seq( @@ -279,7 +279,7 @@ abstract class BaseScriptTransformationSuite extends 
SparkPlanTest with SQLTestU } } - test("SPARK-32106: TRANSFORM should return null when return string incompatible(no serde)") { + test("SPARK-32106: TRANSFORM should return null when return string incompatible") { checkAnswer( sql( """ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index f557b7683c617..ae2d581a73b04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -155,7 +155,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-25990: TRANSFORM should handle schema less correctly (with hive serde)") { + test("SPARK-25990: TRANSFORM should handle schema less correctly (hive serde)") { assume(TestUtils.testCommandAvailable("python")) val scriptFilePath = getTestResourcePath("test_script.py") From 858f4e5327735e7b8a0ef5d1b581c8c825d29eb8 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 20 Jul 2020 16:19:57 +0800 Subject: [PATCH 25/42] Update BaseScriptTransformationExec.scala --- .../sql/execution/BaseScriptTransformationExec.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 8866ec74ec535..b0368cb83114d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -189,10 +189,20 @@ trait BaseScriptTransformationExec extends UnaryExecNode { case FloatType => wrapperConvertException(data => data.toFloat, converter) case DoubleType => wrapperConvertException(data => data.toDouble, converter) case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull, converter) case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( UTF8String.fromString(data), DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull, converter) case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( UTF8String.fromString(data), DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) From cfecc90861ecae94a90e37654412fb31e934d14e Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 21 Jul 2020 14:32:09 +0800 Subject: [PATCH 26/42] support array map struct --- .../BaseScriptTransformationExec.scala | 106 ++++---- .../spark/sql/execution/SparkInspectors.scala | 250 ++++++++++++++++++ .../resources/sql-tests/inputs/transform.sql | 20 +- .../sql-tests/results/transform.sql.out | 28 +- .../BaseScriptTransformationSuite.scala | 16 +- 5 files changed, 339 insertions(+), 81 
deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index b0368cb83114d..c94b02b43398c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} import java.nio.charset.StandardCharsets +import java.util.Map.Entry import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -33,10 +34,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { @@ -47,7 +46,13 @@ trait BaseScriptTransformationExec extends UnaryExecNode { def ioschema: ScriptTransformationIOSchema protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { - input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + input.map { in: Expression => + in.dataType match { + case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in + case _ => Cast(in, StringType) + .withTimeZone(conf.sessionLocalTimeZone).asInstanceOf[Expression] + } + } } override def producedAttributes: AttributeSet = outputSet -- inputSet @@ -177,55 +182,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } private lazy val fieldWriters: Seq[String => Any] = output.map { attr => - val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) - attr.dataType match { - case StringType => wrapperConvertException(data => data, converter) - case BooleanType => wrapperConvertException(data => data.toBoolean, converter) - case ByteType => wrapperConvertException(data => data.toByte, converter) - case BinaryType => wrapperConvertException(data => data.getBytes, converter) - case IntegerType => wrapperConvertException(data => data.toInt, converter) - case ShortType => wrapperConvertException(data => data.toShort, converter) - case LongType => wrapperConvertException(data => data.toLong, converter) - case FloatType => wrapperConvertException(data => data.toFloat, converter) - case DoubleType => wrapperConvertException(data => data.toDouble, converter) - case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) - case DateType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.daysToLocalDate).orNull, converter) - case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - 
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaDate).orNull, converter) - case TimestampType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.microsToInstant).orNull, converter) - case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaTimestamp).orNull, converter) - case CalendarIntervalType => wrapperConvertException( - data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), - converter) - case udt: UserDefinedType[_] => - wrapperConvertException(data => udt.deserialize(data), converter) - case _ => wrapperConvertException(data => data, converter) - } + SparkInspectors.unwrapper(attr.dataType, conf, ioschema) } - - // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null - private val wrapperConvertException: (String => Any, Any => Any) => String => Any = - (f: String => Any, converter: Any => Any) => - (data: String) => converter { - try { - f(data) - } catch { - case _: Exception => null - } - } } abstract class BaseScriptTransformationWriterThread extends Thread with Logging { @@ -248,18 +206,23 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging protected def processRows(): Unit + val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt)) + protected def processRowsWithoutSerde(): Unit = { val len = inputSchema.length iter.foreach { row => + val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map { + case (value, wrapper) => wrapper(value) + } val data = if (len == 0) { ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") } else { val sb = new StringBuilder - sb.append(row.get(0, inputSchema(0))) + buildString(sb, values(0), inputSchema(0)) var i = 1 while (i < len) { sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - sb.append(row.get(i, inputSchema(i))) + buildString(sb, values(i), inputSchema(i)) i += 1 } sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) @@ -269,6 +232,38 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging } } + private def buildString(sb: StringBuilder, obj: Any, dataType: DataType): Unit = { + (obj, dataType) match { + case (arrayList: java.util.ArrayList[_], StructType(fields)) => + (0 until arrayList.size).foreach { i => + if (i > 0) { + sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATSTRUCTFIELD")) + } + buildString(sb, arrayList.get(i), fields(i).dataType) + } + case (list: java.util.List[_], ArrayType(typ, _)) => + (0 until list.size).foreach { i => + if (i > 0) { + sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS")) + } + buildString(sb, list.get(i), typ) + } + case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) => + val entries = map.entrySet().toArray() + (0 until entries.size).foreach { i => + if (i > 0) { + sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS")) + } + val entry = entries(i).asInstanceOf[Entry[_, _]] + buildString(sb, entry.getKey, keyType) + sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS")) + buildString(sb, entry.getValue, valueType) + } + case (other, _) => + sb.append(other.toString) + } + } + override def run(): Unit = Utils.logUncaughtExceptions { 
TaskContext.setTaskContext(taskContext) @@ -328,7 +323,10 @@ case class ScriptTransformationIOSchema( object ScriptTransformationIOSchema { val defaultFormat = Map( ("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n") + ("TOK_TABLEROWFORMATLINES", "\n"), + ("TOK_TABLEROWFORMATSTRUCTFIELD", "\u0001"), + ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"), + ("TOK_TABLEROWFORMATMAPKEYS", "\u0003") ) val defaultIOSchema = ScriptTransformationIOSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala new file mode 100644 index 0000000000000..f5a32eb401a7a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, IntervalUtils, MapData} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object SparkInspectors { + + def wrapper(dataType: DataType): Any => Any = dataType match { + case ArrayType(tpe, _) => + val wp = wrapper(tpe) + withNullSafe { o => + val array = o.asInstanceOf[ArrayData] + val values = new java.util.ArrayList[Any](array.numElements()) + array.foreach(tpe, (_, e) => values.add(wp(e))) + values + } + case MapType(keyType, valueType, _) => + val mt = dataType.asInstanceOf[MapType] + val keyWrapper = wrapper(keyType) + val valueWrapper = wrapper(valueType) + withNullSafe { o => + val map = o.asInstanceOf[MapData] + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + map.foreach(mt.keyType, mt.valueType, (k, v) => + jmap.put(keyWrapper(k), valueWrapper(v))) + jmap + } + case StringType => getStringWritable + case IntegerType => getIntWritable + case DoubleType => getDoubleWritable + case BooleanType => getBooleanWritable + case LongType => getLongWritable + case FloatType => getFloatWritable + case ShortType => getShortWritable + case ByteType => getByteWritable + case NullType => (_: Any) => null + case BinaryType => getBinaryWritable + case DateType => getDateWritable + case TimestampType => getTimestampWritable + // TODO decimal precision? 
+ case DecimalType() => getDecimalWritable + case StructType(fields) => + val structType = dataType.asInstanceOf[StructType] + val wrappers = fields.map(f => wrapper(f.dataType)) + withNullSafe { o => + val row = o.asInstanceOf[InternalRow] + val result = new java.util.ArrayList[AnyRef](wrappers.size) + wrappers.zipWithIndex.foreach { + case (wrapper, i) => + val tpe = structType(i).dataType + result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) + } + result + } + case _: UserDefinedType[_] => + val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType + wrapper(sqlType) + } + + private def withNullSafe(f: Any => Any): Any => Any = { + input => + if (input == null) { + null + } else { + f(input) + } + } + + private def getStringWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[UTF8String].toString + } + + private def getIntWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Int] + } + + private def getDoubleWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Double] + } + + private def getBooleanWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Boolean] + } + + private def getLongWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Long] + } + + private def getFloatWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Float] + } + + private def getShortWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Short] + } + + private def getByteWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Byte] + } + + private def getBinaryWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Array[Byte]] + } + + private def getDateWritable(value: Any): Any = + if (value == null) { + null + } else { + DateTimeUtils.toJavaDate(value.asInstanceOf[Int]) + } + + private def getTimestampWritable(value: Any): Any = + if (value == null) { + null + } else { + DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]) + } + + private def getDecimalWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Decimal] + } + + + def unwrapper( + dataType: DataType, + conf: SQLConf, + ioSchema: ScriptTransformationIOSchema): String => Any = { + val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + dataType match { + case StringType => wrapperConvertException(data => data, converter) + case BooleanType => wrapperConvertException(data => data.toBoolean, converter) + case ByteType => wrapperConvertException(data => data.toByte, converter) + case BinaryType => wrapperConvertException(data => data.getBytes, converter) + case IntegerType => wrapperConvertException(data => data.toInt, converter) + case ShortType => wrapperConvertException(data => data.toShort, converter) + case LongType => wrapperConvertException(data => data.toLong, converter) + case FloatType => wrapperConvertException(data => data.toFloat, converter) + case DoubleType => wrapperConvertException(data => data.toDouble, converter) + case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull, converter) + 
case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull, converter) + case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull, converter) + case CalendarIntervalType => wrapperConvertException( + data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), + converter) + case udt: UserDefinedType[_] => + wrapperConvertException(data => udt.deserialize(data), converter) + case ArrayType(tpe, _) => + val un = unwrapper(tpe, conf, ioSchema) + wrapperConvertException(data => { + data.split(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS")) + .map(un).toSeq + }, converter) + case MapType(keyType, valueType, _) => + val keyUnwrapper = unwrapper(keyType, conf, ioSchema) + val valueUnwrapper = unwrapper(valueType, conf, ioSchema) + wrapperConvertException(data => { + val list = data.split(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS")) + list.map { kv => + val kvList = kv.split(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS")) + keyUnwrapper(kvList(0)) -> valueUnwrapper(kvList(1)) + }.toMap + }, converter) + case StructType(fields) => + val unwrappers = fields.map(f => unwrapper(f.dataType, conf, ioSchema)) + wrapperConvertException(data => { + val list = data.split(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATSTRUCTFIELD")) + Row.fromSeq(list.zipWithIndex.map { case (data: String, i: Int) => unwrappers(i)(data) }) + }, converter) + case _ => wrapperConvertException(data => data, converter) + } + } + + // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null + private val wrapperConvertException: (String => Any, Any => Any) => String => Any = + (f: String => Any, converter: Any => Any) => + (data: String) => converter { + try { + f(data) + } catch { + case _: Exception => null + } + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 196341c26bc9d..cd3934cb30180 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -1,9 +1,12 @@ -- Test data. 
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES -('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), -('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), -('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) -as t1(a, b, c, d, e, f, g, h, i, j, k); +('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), + float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01'), array(1, 2, 3), map(1, '1'), struct(1, '1')), +('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), + float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02'), array(2, 3, 4), map(1, '1'), struct(1, '1')), +('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), + float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03'), array(3, 4, 5), map(1, '1'), struct(1, '1')) +as t1(a, b, c, d, e, f, g, h, i, j, k, l, m, n); SELECT TRANSFORM(a) USING 'cat' AS (a) @@ -22,8 +25,8 @@ FROM t1; -- support different data type -SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( - SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k) +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l, m, n FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l, m, n) USING 'cat' AS ( a string, b boolean, @@ -35,7 +38,10 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( h double, i decimal(38, 18), j timestamp, - k date) + k date, + l array, + m map, + n struct) FROM t1 ) tmp; diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 8e35efdf3fd2a..b3d0b7e32fdad 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -4,10 +4,13 @@ -- !query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES -('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), -('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), -('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) -as t1(a, b, c, d, e, f, g, h, i, j, k) +('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), + float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01'), array(1, 2, 3), map(1, '1'), struct(1, '1')), +('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), + float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02'), array(2, 3, 4), map(1, '1'), struct(1, '1')), +('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), 
array_position(array(3, 2, 1), 1), + float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03'), array(3, 4, 5), map(1, '1'), struct(1, '1')) +as t1(a, b, c, d, e, f, g, h, i, j, k, l, m, n) -- !query schema struct<> -- !query output @@ -49,8 +52,8 @@ Subprocess exited with status 2. Error: python: can't open file 'some_non_existe -- !query -SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( - SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k) +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l, m, n FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l, m, n) USING 'cat' AS ( a string, b boolean, @@ -62,15 +65,18 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( h double, i decimal(38, 18), j timestamp, - k date) + k date, + l array, + m map, + n struct) FROM t1 ) tmp -- !query schema -struct +struct,m:map,n:struct> -- !query output -1 true Spark SQL 1 100 3 1.0 1.0 1.000000000000000000 1997-01-02 00:00:00 2000-04-01 -2 false Spark SQL 2 200 2 2.0 2.0 2.000000000000000000 1997-01-02 03:04:05 2000-04-02 -3 true Spark SQL 3 300 3 3.0 3.0 3.000000000000000000 1997-02-10 17:32:01 2000-04-03 +1 true Spark SQL 1 100 3 1.0 1.0 1.000000000000000000 1997-01-02 00:00:00 2000-04-01 [1,2,3] {1:"1"} {"col1":1,"col2":"1"} +2 false Spark SQL 2 200 2 2.0 2.0 2.000000000000000000 1997-01-02 03:04:05 2000-04-02 [2,3,4] {1:"1"} {"col1":1,"col2":"1"} +3 true Spark SQL 3 300 3 3.0 3.0 3.000000000000000000 1997-02-10 17:32:01 2000-04-03 [3,4,5] {1:"1"} {"col1":1,"col2":"1"} -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 6e4362eba025f..a09dae5760940 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -260,22 +260,20 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU AttributeReference("f", TimestampType)(), AttributeReference("g", DateType)(), AttributeReference("h", CalendarIntervalType)(), - AttributeReference("i", StringType)(), - AttributeReference("j", StringType)(), + AttributeReference("i", ArrayType(IntegerType))(), + AttributeReference("j", MapType(StringType, IntegerType))(), AttributeReference("k", StringType)(), AttributeReference("l", new SimpleTupleUDT)(), - AttributeReference("m", StringType)(), + AttributeReference("m", StructType( + Seq(StructField("col1", IntegerType), + StructField("col2", StringType))))(), AttributeReference("n", BinaryType)(), AttributeReference("o", BooleanType)()), child = child, ioschema = defaultIOSchema ), - df.select( - 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, - 'i.cast("string"), - 'j.cast("string"), - 'k.cast("string"), - 'l, 'm.cast("string"), 'n, 'o).collect()) + df.select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, + 'i, 'j, 'k.cast("string"), 'l, 'm, 'n, 'o).collect()) } } From 43d0f24f2c769dc270cf7e5fa2c5c13c32d0a631 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 21 Jul 2020 15:16:32 +0800 Subject: [PATCH 27/42] Revert "support array map struct" This reverts commit cfecc90861ecae94a90e37654412fb31e934d14e. 
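With this revert, the no-serde path goes back to casting every input expression to its string
form, and the dedicated writers for array, map and struct values are dropped, so complex script
output has to be declared and read back as plain strings. A rough sketch of the intended usage
after the revert (t1 is the view from transform.sql; the struct column is only illustrative):

    -- complex values reach the script in their string form and come back as strings
    SELECT TRANSFORM(a, struct(a, b))
    USING 'cat' AS (a string, m string)
    FROM t1;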
--- .../BaseScriptTransformationExec.scala | 106 ++++---- .../spark/sql/execution/SparkInspectors.scala | 250 ------------------ .../resources/sql-tests/inputs/transform.sql | 20 +- .../sql-tests/results/transform.sql.out | 28 +- .../BaseScriptTransformationSuite.scala | 16 +- 5 files changed, 81 insertions(+), 339 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index c94b02b43398c..b0368cb83114d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} import java.nio.charset.StandardCharsets -import java.util.Map.Entry import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -34,8 +33,10 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { @@ -46,13 +47,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { def ioschema: ScriptTransformationIOSchema protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { - input.map { in: Expression => - in.dataType match { - case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in - case _ => Cast(in, StringType) - .withTimeZone(conf.sessionLocalTimeZone).asInstanceOf[Expression] - } - } + input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) } override def producedAttributes: AttributeSet = outputSet -- inputSet @@ -182,8 +177,55 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } private lazy val fieldWriters: Seq[String => Any] = output.map { attr => - SparkInspectors.unwrapper(attr.dataType, conf, ioschema) + val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) + attr.dataType match { + case StringType => wrapperConvertException(data => data, converter) + case BooleanType => wrapperConvertException(data => data.toBoolean, converter) + case ByteType => wrapperConvertException(data => data.toByte, converter) + case BinaryType => wrapperConvertException(data => data.getBytes, converter) + case IntegerType => wrapperConvertException(data => data.toInt, converter) + case ShortType => wrapperConvertException(data => data.toShort, converter) + case LongType => wrapperConvertException(data => data.toLong, converter) + case FloatType => wrapperConvertException(data => data.toFloat, converter) + case DoubleType => wrapperConvertException(data => data.toDouble, converter) + case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType if 
conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull, converter) + case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull, converter) + case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull, converter) + case CalendarIntervalType => wrapperConvertException( + data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), + converter) + case udt: UserDefinedType[_] => + wrapperConvertException(data => udt.deserialize(data), converter) + case _ => wrapperConvertException(data => data, converter) + } } + + // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null + private val wrapperConvertException: (String => Any, Any => Any) => String => Any = + (f: String => Any, converter: Any => Any) => + (data: String) => converter { + try { + f(data) + } catch { + case _: Exception => null + } + } } abstract class BaseScriptTransformationWriterThread extends Thread with Logging { @@ -206,23 +248,18 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging protected def processRows(): Unit - val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt)) - protected def processRowsWithoutSerde(): Unit = { val len = inputSchema.length iter.foreach { row => - val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map { - case (value, wrapper) => wrapper(value) - } val data = if (len == 0) { ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") } else { val sb = new StringBuilder - buildString(sb, values(0), inputSchema(0)) + sb.append(row.get(0, inputSchema(0))) var i = 1 while (i < len) { sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - buildString(sb, values(i), inputSchema(i)) + sb.append(row.get(i, inputSchema(i))) i += 1 } sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) @@ -232,38 +269,6 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging } } - private def buildString(sb: StringBuilder, obj: Any, dataType: DataType): Unit = { - (obj, dataType) match { - case (arrayList: java.util.ArrayList[_], StructType(fields)) => - (0 until arrayList.size).foreach { i => - if (i > 0) { - sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATSTRUCTFIELD")) - } - buildString(sb, arrayList.get(i), fields(i).dataType) - } - case (list: java.util.List[_], ArrayType(typ, _)) => - (0 until list.size).foreach { i => - if (i > 0) { - sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS")) - } - buildString(sb, list.get(i), typ) - } - case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val entries = map.entrySet().toArray() - (0 until entries.size).foreach { i => - if (i > 0) { - sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS")) - } - val entry = entries(i).asInstanceOf[Entry[_, _]] - 
buildString(sb, entry.getKey, keyType) - sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS")) - buildString(sb, entry.getValue, valueType) - } - case (other, _) => - sb.append(other.toString) - } - } - override def run(): Unit = Utils.logUncaughtExceptions { TaskContext.setTaskContext(taskContext) @@ -323,10 +328,7 @@ case class ScriptTransformationIOSchema( object ScriptTransformationIOSchema { val defaultFormat = Map( ("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n"), - ("TOK_TABLEROWFORMATSTRUCTFIELD", "\u0001"), - ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"), - ("TOK_TABLEROWFORMATMAPKEYS", "\u0003") + ("TOK_TABLEROWFORMATLINES", "\n") ) val defaultIOSchema = ScriptTransformationIOSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala deleted file mode 100644 index f5a32eb401a7a..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, IntervalUtils, MapData} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -object SparkInspectors { - - def wrapper(dataType: DataType): Any => Any = dataType match { - case ArrayType(tpe, _) => - val wp = wrapper(tpe) - withNullSafe { o => - val array = o.asInstanceOf[ArrayData] - val values = new java.util.ArrayList[Any](array.numElements()) - array.foreach(tpe, (_, e) => values.add(wp(e))) - values - } - case MapType(keyType, valueType, _) => - val mt = dataType.asInstanceOf[MapType] - val keyWrapper = wrapper(keyType) - val valueWrapper = wrapper(valueType) - withNullSafe { o => - val map = o.asInstanceOf[MapData] - val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(mt.keyType, mt.valueType, (k, v) => - jmap.put(keyWrapper(k), valueWrapper(v))) - jmap - } - case StringType => getStringWritable - case IntegerType => getIntWritable - case DoubleType => getDoubleWritable - case BooleanType => getBooleanWritable - case LongType => getLongWritable - case FloatType => getFloatWritable - case ShortType => getShortWritable - case ByteType => getByteWritable - case NullType => (_: Any) => null - case BinaryType => getBinaryWritable - case DateType => getDateWritable - case TimestampType => getTimestampWritable - // TODO decimal precision? 
- case DecimalType() => getDecimalWritable - case StructType(fields) => - val structType = dataType.asInstanceOf[StructType] - val wrappers = fields.map(f => wrapper(f.dataType)) - withNullSafe { o => - val row = o.asInstanceOf[InternalRow] - val result = new java.util.ArrayList[AnyRef](wrappers.size) - wrappers.zipWithIndex.foreach { - case (wrapper, i) => - val tpe = structType(i).dataType - result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) - } - result - } - case _: UserDefinedType[_] => - val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType - wrapper(sqlType) - } - - private def withNullSafe(f: Any => Any): Any => Any = { - input => - if (input == null) { - null - } else { - f(input) - } - } - - private def getStringWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[UTF8String].toString - } - - private def getIntWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Int] - } - - private def getDoubleWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Double] - } - - private def getBooleanWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Boolean] - } - - private def getLongWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Long] - } - - private def getFloatWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Float] - } - - private def getShortWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Short] - } - - private def getByteWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Byte] - } - - private def getBinaryWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Array[Byte]] - } - - private def getDateWritable(value: Any): Any = - if (value == null) { - null - } else { - DateTimeUtils.toJavaDate(value.asInstanceOf[Int]) - } - - private def getTimestampWritable(value: Any): Any = - if (value == null) { - null - } else { - DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]) - } - - private def getDecimalWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Decimal] - } - - - def unwrapper( - dataType: DataType, - conf: SQLConf, - ioSchema: ScriptTransformationIOSchema): String => Any = { - val converter = CatalystTypeConverters.createToCatalystConverter(dataType) - dataType match { - case StringType => wrapperConvertException(data => data, converter) - case BooleanType => wrapperConvertException(data => data.toBoolean, converter) - case ByteType => wrapperConvertException(data => data.toByte, converter) - case BinaryType => wrapperConvertException(data => data.getBytes, converter) - case IntegerType => wrapperConvertException(data => data.toInt, converter) - case ShortType => wrapperConvertException(data => data.toShort, converter) - case LongType => wrapperConvertException(data => data.toLong, converter) - case FloatType => wrapperConvertException(data => data.toFloat, converter) - case DoubleType => wrapperConvertException(data => data.toDouble, converter) - case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) - case DateType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.daysToLocalDate).orNull, converter) - 
case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaDate).orNull, converter) - case TimestampType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.microsToInstant).orNull, converter) - case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaTimestamp).orNull, converter) - case CalendarIntervalType => wrapperConvertException( - data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), - converter) - case udt: UserDefinedType[_] => - wrapperConvertException(data => udt.deserialize(data), converter) - case ArrayType(tpe, _) => - val un = unwrapper(tpe, conf, ioSchema) - wrapperConvertException(data => { - data.split(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS")) - .map(un).toSeq - }, converter) - case MapType(keyType, valueType, _) => - val keyUnwrapper = unwrapper(keyType, conf, ioSchema) - val valueUnwrapper = unwrapper(valueType, conf, ioSchema) - wrapperConvertException(data => { - val list = data.split(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS")) - list.map { kv => - val kvList = kv.split(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS")) - keyUnwrapper(kvList(0)) -> valueUnwrapper(kvList(1)) - }.toMap - }, converter) - case StructType(fields) => - val unwrappers = fields.map(f => unwrapper(f.dataType, conf, ioSchema)) - wrapperConvertException(data => { - val list = data.split(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATSTRUCTFIELD")) - Row.fromSeq(list.zipWithIndex.map { case (data: String, i: Int) => unwrappers(i)(data) }) - }, converter) - case _ => wrapperConvertException(data => data, converter) - } - } - - // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null - private val wrapperConvertException: (String => Any, Any => Any) => String => Any = - (f: String => Any, converter: Any => Any) => - (data: String) => converter { - try { - f(data) - } catch { - case _: Exception => null - } - } -} diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index cd3934cb30180..196341c26bc9d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -1,12 +1,9 @@ -- Test data. 
CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES -('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), - float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01'), array(1, 2, 3), map(1, '1'), struct(1, '1')), -('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), - float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02'), array(2, 3, 4), map(1, '1'), struct(1, '1')), -('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), - float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03'), array(3, 4, 5), map(1, '1'), struct(1, '1')) -as t1(a, b, c, d, e, f, g, h, i, j, k, l, m, n); +('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), +('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), +('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) +as t1(a, b, c, d, e, f, g, h, i, j, k); SELECT TRANSFORM(a) USING 'cat' AS (a) @@ -25,8 +22,8 @@ FROM t1; -- support different data type -SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l, m, n FROM ( - SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l, m, n) +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k) USING 'cat' AS ( a string, b boolean, @@ -38,10 +35,7 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l, m, n FROM ( h double, i decimal(38, 18), j timestamp, - k date, - l array, - m map, - n struct) + k date) FROM t1 ) tmp; diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index b3d0b7e32fdad..8e35efdf3fd2a 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -4,13 +4,10 @@ -- !query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES -('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), - float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01'), array(1, 2, 3), map(1, '1'), struct(1, '1')), -('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), - float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02'), array(2, 3, 4), map(1, '1'), struct(1, '1')), -('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), - float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03'), array(3, 4, 5), map(1, '1'), struct(1, '1')) -as t1(a, b, c, d, e, f, g, h, i, j, k, l, m, n) +('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), +('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), +('3', 
true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) +as t1(a, b, c, d, e, f, g, h, i, j, k) -- !query schema struct<> -- !query output @@ -52,8 +49,8 @@ Subprocess exited with status 2. Error: python: can't open file 'some_non_existe -- !query -SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l, m, n FROM ( - SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l, m, n) +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k) USING 'cat' AS ( a string, b boolean, @@ -65,18 +62,15 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l, m, n FROM ( h double, i decimal(38, 18), j timestamp, - k date, - l array, - m map, - n struct) + k date) FROM t1 ) tmp -- !query schema -struct,m:map,n:struct> +struct -- !query output -1 true Spark SQL 1 100 3 1.0 1.0 1.000000000000000000 1997-01-02 00:00:00 2000-04-01 [1,2,3] {1:"1"} {"col1":1,"col2":"1"} -2 false Spark SQL 2 200 2 2.0 2.0 2.000000000000000000 1997-01-02 03:04:05 2000-04-02 [2,3,4] {1:"1"} {"col1":1,"col2":"1"} -3 true Spark SQL 3 300 3 3.0 3.0 3.000000000000000000 1997-02-10 17:32:01 2000-04-03 [3,4,5] {1:"1"} {"col1":1,"col2":"1"} +1 true Spark SQL 1 100 3 1.0 1.0 1.000000000000000000 1997-01-02 00:00:00 2000-04-01 +2 false Spark SQL 2 200 2 2.0 2.0 2.000000000000000000 1997-01-02 03:04:05 2000-04-02 +3 true Spark SQL 3 300 3 3.0 3.0 3.000000000000000000 1997-02-10 17:32:01 2000-04-03 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index a09dae5760940..6e4362eba025f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -260,20 +260,22 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU AttributeReference("f", TimestampType)(), AttributeReference("g", DateType)(), AttributeReference("h", CalendarIntervalType)(), - AttributeReference("i", ArrayType(IntegerType))(), - AttributeReference("j", MapType(StringType, IntegerType))(), + AttributeReference("i", StringType)(), + AttributeReference("j", StringType)(), AttributeReference("k", StringType)(), AttributeReference("l", new SimpleTupleUDT)(), - AttributeReference("m", StructType( - Seq(StructField("col1", IntegerType), - StructField("col2", StringType))))(), + AttributeReference("m", StringType)(), AttributeReference("n", BinaryType)(), AttributeReference("o", BooleanType)()), child = child, ioschema = defaultIOSchema ), - df.select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, - 'i, 'j, 'k.cast("string"), 'l, 'm, 'n, 'o).collect()) + df.select( + 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, + 'i.cast("string"), + 'j.cast("string"), + 'k.cast("string"), + 'l, 'm.cast("string"), 'n, 'o).collect()) } } From 9e18fa8d8e75bd938d20923af0f4f3902c09985b Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 12:53:17 +0800 Subject: [PATCH 28/42] fix SQLQueryTestSuite --- .../resources/sql-tests/inputs/transform.sql | 54 +++++++++---------- .../sql-tests/results/transform.sql.out | 51 +++++++++--------- .../apache/spark/sql/SQLQueryTestSuite.scala | 2 + .../BaseScriptTransformationSuite.scala | 17 ------ 4 files changed, 53 insertions(+), 71 deletions(-) diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 196341c26bc9d..989e358afdc21 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -1,15 +1,14 @@ -- Test data. CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES -('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), -('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), -('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) -as t1(a, b, c, d, e, f, g, h, i, j, k); +('1', true, unhex('537061726B2053514C'), tinyint(1), 1, smallint(100), bigint(1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), +('2', false, unhex('537061726B2053514C'), tinyint(2), 2, smallint(200), bigint(2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), +('3', true, unhex('537061726B2053514C'), tinyint(3), 3, smallint(300), bigint(3), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) +as t1(a, b, c, d, e, f, g, h, i, j, k, l); SELECT TRANSFORM(a) USING 'cat' AS (a) FROM t1; - -- with non-exist command SELECT TRANSFORM(a) USING 'some_non_existent_command' AS (a) @@ -20,43 +19,45 @@ SELECT TRANSFORM(a) USING 'python some_non_existent_file' AS (a) FROM t1; - --- support different data type -SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( - SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k) +-- common supported data types between no serde and serde transform +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l) USING 'cat' AS ( a string, b boolean, c binary, d tinyint, - e smallint, - f long, - g float, - h double, - i decimal(38, 18), - j timestamp, - k date) + e int, + f smallint, + g long, + h float, + i double, + j decimal(38, 18), + k timestamp, + l date) FROM t1 ) tmp; - -- handle schema less +SELECT TRANSFORM(a) +USING 'cat' +FROM t1; + SELECT TRANSFORM(a, b) USING 'cat' FROM t1; --- return null when return string incompatible (no serde) SELECT TRANSFORM(a, b, c) -USING 'cat' as (a int, b int , c int) -FROM ( - SELECT - 1 AS a, - "a" AS b, - CAST(2000 AS timestamp) AS c -) tmp; +USING 'cat' +FROM t1; +-- return null when return string incompatible (no serde) +SELECT TRANSFORM(a, b, c, d, e, f, g, h, i) +USING 'cat' as (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) +FROM VALUES +('a','','1231a','a','213.21a','213.21a','0a.21d','2000-04-01123','1997-0102 00:00:') tmp(a, b, c, d, e, f, g, h, i); --- transform can't run with aggregation +-- SPARK-28227: transform can't run with aggregation SELECT TRANSFORM(b, max(a), sum(f)) USING 'cat' AS (a, b) FROM t1 @@ -65,7 +66,6 @@ GROUP BY b; -- transform use MAP MAP a, b USING 'cat' AS (a, b) FROM t1; - -- transform use REDUCE REDUCE a, b USING 'cat' AS (a, b) FROM t1; diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 
8e35efdf3fd2a..7c3d36ad9ccaa 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -4,10 +4,10 @@ -- !query CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES -('1', true, unhex('537061726B2053514C'), tinyint(1), smallint(100), array_position(array(3, 2, 1), 1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), -('2', false, unhex('537061726B2053514C'), tinyint(2), smallint(200), array_position(array(3, 2, 1), 2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), -('3', true, unhex('537061726B2053514C'), tinyint(3), smallint(300), array_position(array(3, 2, 1), 1), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) -as t1(a, b, c, d, e, f, g, h, i, j, k) +('1', true, unhex('537061726B2053514C'), tinyint(1), 1, smallint(100), bigint(1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), +('2', false, unhex('537061726B2053514C'), tinyint(2), 2, smallint(200), bigint(2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), +('3', true, unhex('537061726B2053514C'), tinyint(3), 3, smallint(300), bigint(3), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) +as t1(a, b, c, d, e, f, g, h, i, j, k, l) -- !query schema struct<> -- !query output @@ -49,28 +49,29 @@ Subprocess exited with status 2. Error: python: can't open file 'some_non_existe -- !query -SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k FROM ( - SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k) +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l) USING 'cat' AS ( a string, b boolean, c binary, d tinyint, - e smallint, - f long, - g float, - h double, - i decimal(38, 18), - j timestamp, - k date) + e int, + f smallint, + g long, + h float, + i double, + j decimal(38, 18), + k timestamp, + l date) FROM t1 ) tmp -- !query schema -struct +struct -- !query output -1 true Spark SQL 1 100 3 1.0 1.0 1.000000000000000000 1997-01-02 00:00:00 2000-04-01 -2 false Spark SQL 2 200 2 2.0 2.0 2.000000000000000000 1997-01-02 03:04:05 2000-04-02 -3 true Spark SQL 3 300 3 3.0 3.0 3.000000000000000000 1997-02-10 17:32:01 2000-04-03 +1 true Spark SQL 1 1 100 1 1.0 1.0 1.000000000000000000 1997-01-02 00:00:00 2000-04-01 +2 false Spark SQL 2 2 200 2 2.0 2.0 2.000000000000000000 1997-01-02 03:04:05 2000-04-02 +3 true Spark SQL 3 3 300 3 3.0 3.0 3.000000000000000000 1997-02-10 17:32:01 2000-04-03 -- !query @@ -86,18 +87,14 @@ struct -- !query -SELECT TRANSFORM(a, b, c) -USING 'cat' as (a int, b int , c int) -FROM ( - SELECT - 1 AS a, - "a" AS b, - CAST(2000 AS timestamp) AS c -) tmp +SELECT TRANSFORM(a, b, c, d, e, f, g, h, i) +USING 'cat' as (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) +FROM VALUES +('a','','1231a','a','213.21a','213.21a','0a.21d','2000-04-01123','1997-0102 00:00:') tmp(a, b, c, d, e, f, g, h, i) -- !query schema -struct +struct -- !query output -1 NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL NULL -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 36d7eeef44868..8f18468e36bcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -258,6 +258,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER") } + // SPARK-32106 Since we add SQL test 'transform.sql' will use `cat` command, + // here we need to check command available assume(TestUtils.testCommandAvailable("/bin/bash")) val input = fileToString(new File(testCase.inputFile)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 6e4362eba025f..6519f8c12cfbf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -278,23 +278,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU 'l, 'm.cast("string"), 'n, 'o).collect()) } } - - test("SPARK-32106: TRANSFORM should return null when return string incompatible") { - checkAnswer( - sql( - """ - |SELECT TRANSFORM(a, b, c) - |USING 'cat' as (a int, b int , c int) - |FROM ( - |SELECT - |1 AS a, - |"a" AS b, - |CAST(2000 AS timestamp) AS c - |) tmp - """.stripMargin), - identity, - Row(1, null, null) :: Nil) - } } case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { From 9537d9bc67a2d9c19dcfb8831b196ffe2a2c417d Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 13:07:36 +0800 Subject: [PATCH 29/42] address comment --- .../sql-tests/results/transform.sql.out | 25 ++++++++++++++++++- .../BaseScriptTransformationSuite.scala | 6 +---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 7c3d36ad9ccaa..ad3b24a40da00 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 12 -- !query @@ -74,6 +74,17 @@ struct +-- !query output +java.lang.ArrayIndexOutOfBoundsException +1 + + -- !query SELECT TRANSFORM(a, b) USING 'cat' @@ -86,6 +97,18 @@ struct 3 true +-- !query +SELECT TRANSFORM(a, b, c) +USING 'cat' +FROM t1 +-- !query schema +struct +-- !query output +1 true +2 false +3 true + + -- !query SELECT TRANSFORM(a, b, c, d, e, f, g, h, i) USING 'cat' as (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 6519f8c12cfbf..567603e259089 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -182,11 +182,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ), df.select( 'a.cast("string").as("key"), - concat_ws("\t", - 'b.cast("string"), - 'c.cast("string"), - decimalToString('d), - 'e.cast("string"))).collect()) + 'b.cast("string").as("value")).collect()) } } From 52274418027b3a46eb0b287ea677616ab4b567e4 Mon Sep 17 00:00:00 2001 From: angerszhu 
Date: Wed, 22 Jul 2020 13:07:40 +0800 Subject: [PATCH 30/42] Update BaseScriptTransformationExec.scala --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index b0368cb83114d..6127f83a9218d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -111,10 +111,13 @@ trait BaseScriptTransformationExec extends UnaryExecNode { .zip(fieldWriters) .map { case (data, writer) => writer(data) }) } else { + // In schema less mode, hive default serde will choose first two output column as output + // if output column size less then 2, it will throw ArrayIndexOutOfBoundsException. + // Here we change spark's behavior same as hive's default serde prevLine: String => new GenericInternalRow( - prevLine.split(outputRowFormat, 2) - .map(CatalystTypeConverters.convertToCatalyst)) + prevLine.split(outputRowFormat).slice(0, 2) + .map(CatalystTypeConverters.createToCatalystConverter(StringType))) } override def hasNext: Boolean = { From 670f21bbd49267a7d8b2927fa166f05098291f5b Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 13:20:18 +0800 Subject: [PATCH 31/42] Update BaseScriptTransformationSuite.scala --- .../BaseScriptTransformationSuite.scala | 147 +++++++++--------- 1 file changed, 72 insertions(+), 75 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 567603e259089..131008d8619f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -28,11 +28,12 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TaskContext, TestUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -156,14 +157,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) - // In Hive 1.2, the string representation of a decimal omits trailing zeroes. - // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary. 
- val decimalToString: Column => Column = if (isHive23OrSpark) { - c => c.cast("string") - } else { - c => c.cast("decimal(1, 0)").cast("string") - } - checkAnswer( df, (child: SparkPlan) => createScriptTransformationExec( @@ -206,72 +199,76 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU test("SPARK-32106: TRANSFORM should support all data types as input (no serde)") { assume(TestUtils.testCommandAvailable("python")) - withTempView("v") { - val df = Seq( - (1, "1", 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), - new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), - Map("a" -> 1), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), - (2, "2", 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), - new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), - Map("b" -> 2), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), - (3, "3", 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), - new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), - Map("c" -> 3), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)) - ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, - struct('a, 'b).as("m"), unhex('a).as("n"), lit(true).as("o") - ) // Note column d's data type is Decimal(38, 18) - - // Can't support convert script output data to ArrayType/MapType/StructType now, - // return these column still as string. - // For UserDefinedType, if user defined deserialize method to support convert string - // to UserType like [[SimpleTupleUDT]], we can support convert to this UDT, else we - // will return null value as column. - checkAnswer( - df, - (child: SparkPlan) => createScriptTransformationExec( - input = Seq( - df.col("a").expr, - df.col("b").expr, - df.col("c").expr, - df.col("d").expr, - df.col("e").expr, - df.col("f").expr, - df.col("g").expr, - df.col("h").expr, - df.col("i").expr, - df.col("j").expr, - df.col("k").expr, - df.col("l").expr, - df.col("m").expr, - df.col("n").expr, - df.col("o").expr), - script = "cat", - output = Seq( - AttributeReference("a", IntegerType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", ByteType)(), - AttributeReference("e", DecimalType(38, 18))(), - AttributeReference("f", TimestampType)(), - AttributeReference("g", DateType)(), - AttributeReference("h", CalendarIntervalType)(), - AttributeReference("i", StringType)(), - AttributeReference("j", StringType)(), - AttributeReference("k", StringType)(), - AttributeReference("l", new SimpleTupleUDT)(), - AttributeReference("m", StringType)(), - AttributeReference("n", BinaryType)(), - AttributeReference("o", BooleanType)()), - child = child, - ioschema = defaultIOSchema - ), - df.select( - 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, - 'i.cast("string"), - 'j.cast("string"), - 'k.cast("string"), - 'l, 'm.cast("string"), 'n, 'o).collect()) + Array(false, true).foreach { java8AapiEnable => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8AapiEnable.toString) { + withTempView("v") { + val df = Seq( + (1, "1", 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), + new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), + Map("a" -> 1), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), + (2, "2", 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), + new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), + Map("b" 
-> 2), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), + (3, "3", 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), + new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), + Map("c" -> 3), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)) + ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l") + .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, + struct('a, 'b).as("m"), unhex('a).as("n"), lit(true).as("o") + ) // Note column d's data type is Decimal(38, 18) + + // Can't support convert script output data to ArrayType/MapType/StructType now, + // return these column still as string. + // For UserDefinedType, if user defined deserialize method to support convert string + // to UserType like [[SimpleTupleUDT]], we can support convert to this UDT, else we + // will return null value as column. + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + df.col("f").expr, + df.col("g").expr, + df.col("h").expr, + df.col("i").expr, + df.col("j").expr, + df.col("k").expr, + df.col("l").expr, + df.col("m").expr, + df.col("n").expr, + df.col("o").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", ByteType)(), + AttributeReference("e", DecimalType(38, 18))(), + AttributeReference("f", TimestampType)(), + AttributeReference("g", DateType)(), + AttributeReference("h", CalendarIntervalType)(), + AttributeReference("i", StringType)(), + AttributeReference("j", StringType)(), + AttributeReference("k", StringType)(), + AttributeReference("l", new SimpleTupleUDT)(), + AttributeReference("m", StringType)(), + AttributeReference("n", BinaryType)(), + AttributeReference("o", BooleanType)()), + child = child, + ioschema = defaultIOSchema + ), + df.select( + 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, + 'i.cast("string"), + 'j.cast("string"), + 'k.cast("string"), + 'l, 'm.cast("string"), 'n, 'o).collect()) + } + } } } } From ce8184a89a14ad9431327cadf32317fd646f9f92 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 17:21:29 +0800 Subject: [PATCH 32/42] address comment --- .../sql/catalyst/parser/AstBuilder.scala | 78 +++++++++++++++++- .../BaseScriptTransformationExec.scala | 5 +- .../spark/sql/execution/SparkSqlParser.scala | 82 ------------------- .../resources/sql-tests/inputs/transform.sql | 47 +++++++---- .../sql-tests/results/transform.sql.out | 56 +++++++++---- .../BaseScriptTransformationSuite.scala | 46 +++++------ .../SparkScriptTransformationSuite.scala | 24 ++++++ .../HiveScriptTransformationSuite.scala | 76 +++++++++++++++-- 8 files changed, 267 insertions(+), 147 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f8261c293782d..6a9dd44734276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => 
V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.random.RandomSampler @@ -745,7 +746,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Create a (Hive based) [[ScriptInputOutputSchema]]. + * Create a [[ScriptInputOutputSchema]]. */ protected def withScriptIOSchema( ctx: ParserRuleContext, @@ -754,7 +755,80 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging outRowFormat: RowFormatContext, recordReader: Token, schemaLess: Boolean): ScriptInputOutputSchema = { - throw new ParseException("Script Transform is not supported", ctx) + if (recordWriter != null || recordReader != null) { + // TODO: what does this message mean? + throw new ParseException( + "Unsupported operation: Used defined record reader/writer classes.", ctx) + } + + // Decode and input/output format. + type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + + def format( + fmt: RowFormatContext, + configKey: String, + defaultConfigValue: String): Format = fmt match { + case c: RowFormatDelimitedContext => + // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).map(t => key -> t.getText).toSeq + } + + val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs) + + (entries, None, Seq.empty, None) + + case c: RowFormatContext if !conf.getConf(CATALOG_IMPLEMENTATION).equals("hive") => + throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) + + case c: RowFormatSerdeContext => + // Use a serde format. + val CatalogStorageFormat(None, None, None, Some(name), _, props) = visitRowFormatSerde(c) + + // SPARK-10310: Special cases LazySimpleSerDe + val recordHandler = if (name == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { + Option(conf.getConfString(configKey, defaultConfigValue)) + } else { + None + } + (Seq.empty, Option(name), props.toSeq, recordHandler) + + case null if conf.getConf(CATALOG_IMPLEMENTATION).equals("hive") => + // Use default (serde) format. + val name = conf.getConfString("hive.script.serde", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + val props = Seq("field.delim" -> "\t") + val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) + (Nil, Option(name), props, recordHandler) + + // SPARK-32106: When there is no definition about format, we return empty result + // to use a built-in default Serde in SparkScriptTransformationExec. 
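+      // e.g. `SELECT TRANSFORM(a) USING 'cat' AS (a) FROM t1` carries no ROW FORMAT clause,
+      // so the format context is null and, outside hive mode, we fall through to this case.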
+ case null => + (Nil, None, Seq.empty, None) + } + + val (inFormat, inSerdeClass, inSerdeProps, reader) = + format( + inRowFormat, "hive.script.recordreader", "org.apache.hadoop.hive.ql.exec.TextRecordReader") + + val (outFormat, outSerdeClass, outSerdeProps, writer) = + format( + outRowFormat, "hive.script.recordwriter", + "org.apache.hadoop.hive.ql.exec.TextRecordWriter") + + ScriptInputOutputSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 6127f83a9218d..ba5bc83af512f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -114,10 +114,11 @@ trait BaseScriptTransformationExec extends UnaryExecNode { // In schema less mode, hive default serde will choose first two output column as output // if output column size less then 2, it will throw ArrayIndexOutOfBoundsException. // Here we change spark's behavior same as hive's default serde + val kvWriter = CatalystTypeConverters.createToCatalystConverter(StringType) prevLine: String => new GenericInternalRow( prevLine.split(outputRowFormat).slice(0, 2) - .map(CatalystTypeConverters.createToCatalystConverter(StringType))) + .map(kvWriter)) } override def hasNext: Boolean = { @@ -226,7 +227,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { try { f(data) } catch { - case _: Exception => null + case NonFatal(_) => null } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 37bd3022ba4bc..a813b04700543 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType /** @@ -664,87 +663,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } - /** - * Create a [[ScriptInputOutputSchema]]. - */ - override protected def withScriptIOSchema( - ctx: ParserRuleContext, - inRowFormat: RowFormatContext, - recordWriter: Token, - outRowFormat: RowFormatContext, - recordReader: Token, - schemaLess: Boolean): ScriptInputOutputSchema = { - if (recordWriter != null || recordReader != null) { - // TODO: what does this message mean? - throw new ParseException( - "Unsupported operation: Used defined record reader/writer classes.", ctx) - } - - // Decode and input/output format. - type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) - def format( - fmt: RowFormatContext, - configKey: String, - defaultConfigValue: String): Format = fmt match { - case c: RowFormatDelimitedContext => - // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema - // expects a seq of pairs in which the old parsers' token names are used as keys. 
- // Transforming the result of visitRowFormatDelimited would be quite a bit messier than - // retrieving the key value pairs ourselves. - def entry(key: String, value: Token): Seq[(String, String)] = { - Option(value).map(t => key -> t.getText).toSeq - } - val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++ - entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++ - entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++ - entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++ - entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs) - - (entries, None, Seq.empty, None) - - case c: RowFormatSerdeContext => - // Use a serde format. - val CatalogStorageFormat(None, None, None, Some(name), _, props) = visitRowFormatSerde(c) - - // SPARK-10310: Special cases LazySimpleSerDe - val recordHandler = if (name == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { - Option(conf.getConfString(configKey, defaultConfigValue)) - } else { - None - } - (Seq.empty, Option(name), props.toSeq, recordHandler) - - case null if conf.getConf(CATALOG_IMPLEMENTATION).equals("hive") => - // Use default (serde) format. - val name = conf.getConfString("hive.script.serde", - "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") - val props = Seq("field.delim" -> "\t") - val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) - (Nil, Option(name), props, recordHandler) - - // SPARK-32106: When there is no definition about format, we return empty result - // to use a built-in default Serde in SparkScriptTransformationExec. - case null => - (Nil, None, Seq.empty, None) - } - - val (inFormat, inSerdeClass, inSerdeProps, reader) = - format( - inRowFormat, "hive.script.recordreader", "org.apache.hadoop.hive.ql.exec.TextRecordReader") - - val (outFormat, outSerdeClass, outSerdeProps, writer) = - format( - outRowFormat, "hive.script.recordwriter", - "org.apache.hadoop.hive.ql.exec.TextRecordWriter") - - ScriptInputOutputSchema( - inFormat, outFormat, - inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, - reader, writer, - schemaLess) - } - /** * Create a clause for DISTRIBUTE BY. */ diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 989e358afdc21..d7ad3d4d73a24 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -1,23 +1,23 @@ -- Test data. 
-CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES ('1', true, unhex('537061726B2053514C'), tinyint(1), 1, smallint(100), bigint(1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), ('2', false, unhex('537061726B2053514C'), tinyint(2), 2, smallint(200), bigint(2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), ('3', true, unhex('537061726B2053514C'), tinyint(3), 3, smallint(300), bigint(3), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) -as t1(a, b, c, d, e, f, g, h, i, j, k, l); +AS t(a, b, c, d, e, f, g, h, i, j, k, l); SELECT TRANSFORM(a) USING 'cat' AS (a) -FROM t1; +FROM t; -- with non-exist command SELECT TRANSFORM(a) USING 'some_non_existent_command' AS (a) -FROM t1; +FROM t; -- with non-exist file SELECT TRANSFORM(a) USING 'python some_non_existent_file' AS (a) -FROM t1; +FROM t; -- common supported data types between no serde and serde transform SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( @@ -35,38 +35,55 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( j decimal(38, 18), k timestamp, l date) - FROM t1 + FROM t +) tmp; + +-- common supported data types between no serde and serde transform +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l) + USING 'cat' AS ( + a string, + b string, + c string, + d string, + e string, + f string, + g string, + h string, + i string, + j string, + k string, + l string) + FROM t ) tmp; -- handle schema less SELECT TRANSFORM(a) USING 'cat' -FROM t1; +FROM t; SELECT TRANSFORM(a, b) USING 'cat' -FROM t1; +FROM t; SELECT TRANSFORM(a, b, c) USING 'cat' -FROM t1; +FROM t; -- return null when return string incompatible (no serde) SELECT TRANSFORM(a, b, c, d, e, f, g, h, i) -USING 'cat' as (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) +USING 'cat' AS (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) FROM VALUES ('a','','1231a','a','213.21a','213.21a','0a.21d','2000-04-01123','1997-0102 00:00:') tmp(a, b, c, d, e, f, g, h, i); -- SPARK-28227: transform can't run with aggregation SELECT TRANSFORM(b, max(a), sum(f)) USING 'cat' AS (a, b) -FROM t1 +FROM t GROUP BY b; -- transform use MAP -MAP a, b USING 'cat' AS (a, b) FROM t1; +MAP a, b USING 'cat' AS (a, b) FROM t; -- transform use REDUCE -REDUCE a, b USING 'cat' AS (a, b) FROM t1; - - +REDUCE a, b USING 'cat' AS (a, b) FROM t; diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index ad3b24a40da00..37214b90b86fd 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -1,13 +1,13 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 13 -- !query -CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES ('1', true, unhex('537061726B2053514C'), tinyint(1), 1, smallint(100), bigint(1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), ('2', false, unhex('537061726B2053514C'), tinyint(2), 2, smallint(200), bigint(2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), ('3', true, 
unhex('537061726B2053514C'), tinyint(3), 3, smallint(300), bigint(3), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) -as t1(a, b, c, d, e, f, g, h, i, j, k, l) +AS t(a, b, c, d, e, f, g, h, i, j, k, l) -- !query schema struct<> -- !query output @@ -17,7 +17,7 @@ struct<> -- !query SELECT TRANSFORM(a) USING 'cat' AS (a) -FROM t1 +FROM t -- !query schema struct -- !query output @@ -29,7 +29,7 @@ struct -- !query SELECT TRANSFORM(a) USING 'some_non_existent_command' AS (a) -FROM t1 +FROM t -- !query schema struct<> -- !query output @@ -40,7 +40,7 @@ Subprocess exited with status 127. Error: /bin/bash: some_non_existent_command: -- !query SELECT TRANSFORM(a) USING 'python some_non_existent_file' AS (a) -FROM t1 +FROM t -- !query schema struct<> -- !query output @@ -64,7 +64,7 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( j decimal(38, 18), k timestamp, l date) - FROM t1 + FROM t ) tmp -- !query schema struct @@ -74,10 +74,36 @@ struct +-- !query output +1 true Spark SQL 1 1 100 1 1.0 1.0 1 1997-01-02 00:00:00 2000-04-01 +2 false Spark SQL 2 2 200 2 2.0 2.0 2 1997-01-02 03:04:05 2000-04-02 +3 true Spark SQL 3 3 300 3 3.0 3.0 3 1997-02-10 17:32:01 2000-04-03 + + -- !query SELECT TRANSFORM(a) USING 'cat' -FROM t1 +FROM t -- !query schema struct<> -- !query output @@ -88,7 +114,7 @@ java.lang.ArrayIndexOutOfBoundsException -- !query SELECT TRANSFORM(a, b) USING 'cat' -FROM t1 +FROM t -- !query schema struct -- !query output @@ -100,7 +126,7 @@ struct -- !query SELECT TRANSFORM(a, b, c) USING 'cat' -FROM t1 +FROM t -- !query schema struct -- !query output @@ -111,7 +137,7 @@ struct -- !query SELECT TRANSFORM(a, b, c, d, e, f, g, h, i) -USING 'cat' as (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) +USING 'cat' AS (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) FROM VALUES ('a','','1231a','a','213.21a','213.21a','0a.21d','2000-04-01123','1997-0102 00:00:') tmp(a, b, c, d, e, f, g, h, i) -- !query schema @@ -123,7 +149,7 @@ NULL NULL NULL NULL NULL NULL NULL NULL NULL -- !query SELECT TRANSFORM(b, max(a), sum(f)) USING 'cat' AS (a, b) -FROM t1 +FROM t GROUP BY b -- !query schema struct<> @@ -135,13 +161,13 @@ mismatched input 'GROUP' expecting {, ';'}(line 4, pos 0) == SQL == SELECT TRANSFORM(b, max(a), sum(f)) USING 'cat' AS (a, b) -FROM t1 +FROM t GROUP BY b ^^^ -- !query -MAP a, b USING 'cat' AS (a, b) FROM t1 +MAP a, b USING 'cat' AS (a, b) FROM t -- !query schema struct -- !query output @@ -151,7 +177,7 @@ struct -- !query -REDUCE a, b USING 'cat' AS (a, b) FROM t1 +REDUCE a, b USING 'cat' AS (a, b) FROM t -- !query schema struct -- !query output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 131008d8619f0..b8ee9daa82d10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -203,18 +203,18 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8AapiEnable.toString) { withTempView("v") { val df = Seq( - (1, "1", 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), + (1, "1", 1.0f, 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), new Date(2020, 7, 1), new 
CalendarInterval(7, 1, 1000), Array(0, 1, 2), Map("a" -> 1), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), - (2, "2", 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), + (2, "2", 2.0f, 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), Map("b" -> 2), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), - (3, "3", 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), + (3, "3", 3.0f, 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), Map("c" -> 3), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)) - ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, - struct('a, 'b).as("m"), unhex('a).as("n"), lit(true).as("o") + ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m") + .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, 'm, + struct('a, 'b).as("n"), unhex('a).as("o"), lit(true).as("p") ) // Note column d's data type is Decimal(38, 18) // Can't support convert script output data to ArrayType/MapType/StructType now, @@ -240,33 +240,33 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU df.col("l").expr, df.col("m").expr, df.col("n").expr, - df.col("o").expr), + df.col("o").expr, + df.col("p").expr), script = "cat", output = Seq( AttributeReference("a", IntegerType)(), AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", ByteType)(), - AttributeReference("e", DecimalType(38, 18))(), - AttributeReference("f", TimestampType)(), - AttributeReference("g", DateType)(), - AttributeReference("h", CalendarIntervalType)(), - AttributeReference("i", StringType)(), + AttributeReference("c", FloatType)(), + AttributeReference("d", DoubleType)(), + AttributeReference("e", ByteType)(), + AttributeReference("f", DecimalType(38, 18))(), + AttributeReference("g", TimestampType)(), + AttributeReference("h", DateType)(), + AttributeReference("i", CalendarIntervalType)(), AttributeReference("j", StringType)(), AttributeReference("k", StringType)(), - AttributeReference("l", new SimpleTupleUDT)(), - AttributeReference("m", StringType)(), - AttributeReference("n", BinaryType)(), - AttributeReference("o", BooleanType)()), + AttributeReference("l", StringType)(), + AttributeReference("m", new SimpleTupleUDT)(), + AttributeReference("n", StringType)(), + AttributeReference("o", BinaryType)(), + AttributeReference("p", BooleanType)()), child = child, ioschema = defaultIOSchema ), df.select( - 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, - 'i.cast("string"), - 'j.cast("string"), - 'k.cast("string"), - 'l, 'm.cast("string"), 'n, 'o).collect()) + 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, + 'j.cast("string"), 'k.cast("string"), + 'l.cast("string"), 'm, 'n.cast("string"), 'o, 'p).collect()) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index 68f070a85a12d..cb912dc4973e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution +import org.apache.spark.TestUtils import 
org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.test.SharedSparkSession class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with SharedSparkSession { + import testImplicits._ override def isHive23OrSpark: Boolean = true @@ -38,4 +41,25 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with ioschema = ioschema ) } + + test("SPARK-32106: TRANSFORM with serde without hive should throw exception") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + df.createTempView("v") + + val e = intercept[ParseException] { + sql( + """ + |SELECT TRANSFORM (a) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |USING 'cat' AS (a) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |FROM v + """.stripMargin) + }.getMessage + assert(e.contains("TRANSFORM with serde is only supported in hive mode")) + } + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index ae2d581a73b04..0dddfa39d6378 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -77,7 +77,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T assert(uncaughtExceptionHandler.exception.isEmpty) } - test("script transformation should not swallow errors from upstream operators (with serde)") { + test("script transformation should not swallow errors from upstream operators (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") @@ -98,7 +98,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-14400 script transformation should fail for bad script command") { + test("SPARK-14400 script transformation should fail for bad script command (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") @@ -117,7 +117,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-24339 verify the result after pruning the unused columns") { + test("SPARK-24339 verify the result after pruning the unused columns (hive serde)") { val rowsDf = Seq( ("Bob", 16, 176), ("Alice", 32, 164), @@ -137,7 +137,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-30973: TRANSFORM should wait for the termination of the script (with serde)") { + test("SPARK-30973: TRANSFORM should wait for the termination of the script (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") @@ -185,7 +185,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T } } - test("SPARK-32106: TRANSFORM support complex data types as input and ouput type (hive serde)") { + test("SPARK-32106: TRANSFORM supports complex data types type (hive serde)") { 
assume(TestUtils.testCommandAvailable("/bin/bash")) withTempView("v") { val df = Seq( @@ -218,7 +218,67 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T } } - test("SPARK-32106: TRANSFORM don't support CalenderIntervalType/UserDefinedType (hive serde)") { + test("SPARK-32106: TRANSFORM supports complex data types end to end (hive serde) ") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (1, "1", Array(0, 1, 2), Map("a" -> 1)), + (2, "2", Array(3, 4, 5), Map("b" -> 2))) + .toDF("a", "b", "c", "d") + .select('a, 'b, 'c, 'd, struct('a, 'b).as("e")) + df.createTempView("v") + + // Hive serde support ArrayType/MapType/StructType as input and output data type + val query = sql( + """ + |SELECT TRANSFORM (c, d, e) + |USING 'cat' AS (c array, d map, e struct) + |FROM v + """.stripMargin) + checkAnswer(query, identity, df.select('c, 'd, 'e).collect()) + } + } + + test("SPARK-32106: TRANSFORM doesn't support CalenderIntervalType/UserDefinedType (hive serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (1, new CalendarInterval(7, 1, 1000), new TestUDT.MyDenseVector(Array(1, 2, 3))), + (1, new CalendarInterval(7, 1, 1000), new TestUDT.MyDenseVector(Array(1, 2, 3)))) + .toDF("a", "b", "c") + df.createTempView("v") + + val e1 = intercept[SparkException] { + val plan = createScriptTransformationExec( + input = Seq(df.col("a").expr, df.col("b").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("b", CalendarIntervalType)()), + child = df.queryExecution.sparkPlan, + ioschema = serdeIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + } + assert(e1.getMessage.contains("scala.MatchError: CalendarIntervalType")) + + val e2 = intercept[SparkException] { + val plan = createScriptTransformationExec( + input = Seq(df.col("a").expr, df.col("c").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("c", new TestUDT.MyDenseVectorUDT)()), + child = df.queryExecution.sparkPlan, + ioschema = serdeIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + } + assert(e2.getMessage.contains( + "scala.MatchError: org.apache.spark.sql.types.TestUDT$MyDenseVectorUDT")) + } + } + + test("SPARK-32106: TRANSFORM doesn't support" + + " CalenderIntervalType/UserDefinedType end to end (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) withTempView("v") { val df = Seq( @@ -227,7 +287,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T .toDF("a", "b", "c") df.createTempView("v") - val e1 = intercept[Exception] { + val e1 = intercept[SparkException] { sql( """ |SELECT TRANSFORM(a, b) USING 'cat' AS (a, b) @@ -236,7 +296,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T } assert(e1.getMessage.contains("scala.MatchError: CalendarIntervalType")) - val e2 = intercept[Exception] { + val e2 = intercept[SparkException] { sql( """ |SELECT TRANSFORM(a, c) USING 'cat' AS (a, c) From 4615733b35dc9a2f1d36fc654774e8e3a525831f Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 17:27:10 +0800 Subject: [PATCH 33/42] Update SparkScriptTransformationSuite.scala --- .../spark/sql/execution/SparkScriptTransformationSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index cb912dc4973e5..183390a846056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -60,6 +60,5 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with }.getMessage assert(e.contains("TRANSFORM with serde is only supported in hive mode")) } - } } From 08d97c879412a68b199473034e11eb026cdf690c Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 18:02:24 +0800 Subject: [PATCH 34/42] throw exception when complex data type --- .../BaseScriptTransformationExec.scala | 3 ++ .../SparkScriptTransformationSuite.scala | 41 ++++++++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index ba5bc83af512f..3d44e2596fbae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -216,6 +216,9 @@ trait BaseScriptTransformationExec extends UnaryExecNode { converter) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) + case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => + throw new SparkException("TRANSFORM without serde don't support" + + " ArrayType/MapType/StructType as output data type") case _ => wrapperConvertException(data => data, converter) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index 183390a846056..0826ecceec931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.TestUtils +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.test.SharedSparkSession @@ -61,4 +61,43 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with assert(e.contains("TRANSFORM with serde is only supported in hive mode")) } } + + test("TRANSFORM don't support ArrayType/MapType/StructType as output data type (no serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + // check for ArrayType + val e1 = intercept[SparkException] { + sql( + """ + |SELECT TRANSFORM(a) + |USING 'cat' AS (a array) + |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) + """.stripMargin).collect() + }.getMessage + assert(e1.contains("TRANSFORM without serde don't support" + + " ArrayType/MapType/StructType as output data type")) + + // check for MapType + val e2 = intercept[SparkException] { + sql( + """ + |SELECT TRANSFORM(b) + |USING 'cat' AS (b map) + |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) + """.stripMargin).collect() + }.getMessage + assert(e2.contains("TRANSFORM without serde don't support" + + " ArrayType/MapType/StructType as output data type")) 
+ + // check for StructType + val e3 = intercept[SparkException] { + sql( + """ + |SELECT TRANSFORM(c) + |USING 'cat' AS (c struct) + |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) + """.stripMargin).collect() + }.getMessage + assert(e3.contains("TRANSFORM without serde don't support" + + " ArrayType/MapType/StructType as output data type")) + } } From 33923b671f3c6b7084852e64fba2c8a8bcb63235 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 18:24:50 +0800 Subject: [PATCH 35/42] https://github.com/apache/spark/pull/29085#discussion_r458676081 --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 3d44e2596fbae..7eab3fcd8c201 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -113,7 +113,9 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } else { // In schema less mode, hive default serde will choose first two output column as output // if output column size less then 2, it will throw ArrayIndexOutOfBoundsException. - // Here we change spark's behavior same as hive's default serde + // Here we change spark's behavior same as hive's default serde. + // But in hive, TRANSFORM with schema less behavior like origin spark, we will fix this + // to keep spark and hive behavior same in SPARK-32388 val kvWriter = CatalystTypeConverters.createToCatalystConverter(StringType) prevLine: String => new GenericInternalRow( From f5ec6560e13c92f76ead41cc3f47f76e61511031 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 18:33:23 +0800 Subject: [PATCH 36/42] https://github.com/apache/spark/pull/29085#discussion_r458687735 --- .../BaseScriptTransformationSuite.scala | 149 ++++++++++-------- 1 file changed, 85 insertions(+), 64 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index b8ee9daa82d10..c1f723d324d7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -198,75 +198,96 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } test("SPARK-32106: TRANSFORM should support all data types as input (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + (1, "1", 1.0f, 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), + new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), + Map("a" -> 1), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), + (2, "2", 2.0f, 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), + new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), + Map("b" -> 2), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), + (3, "3", 3.0f, 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), + new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), + Map("c" -> 3), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)) + ).toDF("a", "b", "c", "d", "e", "f", 
"g", "h", "i", "j", "k", "l", "m") + .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, 'm, + struct('a, 'b).as("n"), unhex('a).as("o"), lit(true).as("p") + ) // Note column d's data type is Decimal(38, 18) + + // Can't support convert script output data to ArrayType/MapType/StructType now, + // return these column still as string. + // For UserDefinedType, if user defined deserialize method to support convert string + // to UserType like [[SimpleTupleUDT]], we can support convert to this UDT, else we + // will return null value as column. + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + df.col("f").expr, + df.col("g").expr, + df.col("h").expr, + df.col("i").expr, + df.col("j").expr, + df.col("k").expr, + df.col("l").expr, + df.col("m").expr, + df.col("n").expr, + df.col("o").expr, + df.col("p").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", FloatType)(), + AttributeReference("d", DoubleType)(), + AttributeReference("e", ByteType)(), + AttributeReference("f", DecimalType(38, 18))(), + AttributeReference("g", TimestampType)(), + AttributeReference("h", DateType)(), + AttributeReference("i", CalendarIntervalType)(), + AttributeReference("j", StringType)(), + AttributeReference("k", StringType)(), + AttributeReference("l", StringType)(), + AttributeReference("m", new SimpleTupleUDT)(), + AttributeReference("n", StringType)(), + AttributeReference("o", BinaryType)(), + AttributeReference("p", BooleanType)()), + child = child, + ioschema = defaultIOSchema + ), + df.select( + 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, + 'j.cast("string"), 'k.cast("string"), + 'l.cast("string"), 'm, 'n.cast("string"), 'o, 'p).collect()) + } + } + + + test("SPARK-32106: TRANSFORM should respect DATETIME_JAVA8API_ENABLED (no serde)") { assume(TestUtils.testCommandAvailable("python")) Array(false, true).foreach { java8AapiEnable => withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8AapiEnable.toString) { withTempView("v") { val df = Seq( - (1, "1", 1.0f, 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), - new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), - Map("a" -> 1), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), - (2, "2", 2.0f, 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), - new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), - Map("b" -> 2), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), - (3, "3", 3.0f, 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), - new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), - Map("c" -> 3), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)) - ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, 'm, - struct('a, 'b).as("n"), unhex('a).as("o"), lit(true).as("p") - ) // Note column d's data type is Decimal(38, 18) - - // Can't support convert script output data to ArrayType/MapType/StructType now, - // return these column still as string. - // For UserDefinedType, if user defined deserialize method to support convert string - // to UserType like [[SimpleTupleUDT]], we can support convert to this UDT, else we - // will return null value as column. 
- checkAnswer( - df, - (child: SparkPlan) => createScriptTransformationExec( - input = Seq( - df.col("a").expr, - df.col("b").expr, - df.col("c").expr, - df.col("d").expr, - df.col("e").expr, - df.col("f").expr, - df.col("g").expr, - df.col("h").expr, - df.col("i").expr, - df.col("j").expr, - df.col("k").expr, - df.col("l").expr, - df.col("m").expr, - df.col("n").expr, - df.col("o").expr, - df.col("p").expr), - script = "cat", - output = Seq( - AttributeReference("a", IntegerType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", FloatType)(), - AttributeReference("d", DoubleType)(), - AttributeReference("e", ByteType)(), - AttributeReference("f", DecimalType(38, 18))(), - AttributeReference("g", TimestampType)(), - AttributeReference("h", DateType)(), - AttributeReference("i", CalendarIntervalType)(), - AttributeReference("j", StringType)(), - AttributeReference("k", StringType)(), - AttributeReference("l", StringType)(), - AttributeReference("m", new SimpleTupleUDT)(), - AttributeReference("n", StringType)(), - AttributeReference("o", BinaryType)(), - AttributeReference("p", BooleanType)()), - child = child, - ioschema = defaultIOSchema - ), - df.select( - 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, - 'j.cast("string"), 'k.cast("string"), - 'l.cast("string"), 'm, 'n.cast("string"), 'o, 'p).collect()) + (new Timestamp(1), new Date(2020, 7, 1)), + (new Timestamp(2), new Date(2020, 7, 2)), + (new Timestamp(3), new Date(2020, 7, 3)) + ).toDF("a", "b") + df.createTempView("v") + + val query = sql( + """ + |SELECT TRANSFORM (a, b) + |USING 'cat' AS (a timestamp, b date) + |FROM v + """.stripMargin) + checkAnswer(query, identity, df.select('a, 'b).collect()) } } } From 7916d725ec133010edc7d2e5ba4f83b097933b8c Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 19:06:08 +0800 Subject: [PATCH 37/42] https://github.com/apache/spark/pull/29085#discussion_r458692902 --- .../sql/catalyst/parser/AstBuilder.scala | 43 +-------- .../spark/sql/execution/SparkSqlParser.scala | 91 +++++++++++++++++++ 2 files changed, 95 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6a9dd44734276..e57a0b666c321 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -43,7 +43,6 @@ import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.random.RandomSampler @@ -755,19 +754,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging outRowFormat: RowFormatContext, recordReader: Token, schemaLess: Boolean): ScriptInputOutputSchema = { - if (recordWriter != null || recordReader != null) { - // TODO: what does this message mean? 
- throw new ParseException( - "Unsupported operation: Used defined record reader/writer classes.", ctx) - } - // Decode and input/output format. type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) - def format( - fmt: RowFormatContext, - configKey: String, - defaultConfigValue: String): Format = fmt match { + def format(fmt: RowFormatContext): Format = fmt match { case c: RowFormatDelimitedContext => // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema // expects a seq of pairs in which the old parsers' token names are used as keys. @@ -785,28 +775,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging (entries, None, Seq.empty, None) - case c: RowFormatContext if !conf.getConf(CATALOG_IMPLEMENTATION).equals("hive") => - throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) - case c: RowFormatSerdeContext => - // Use a serde format. - val CatalogStorageFormat(None, None, None, Some(name), _, props) = visitRowFormatSerde(c) - - // SPARK-10310: Special cases LazySimpleSerDe - val recordHandler = if (name == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { - Option(conf.getConfString(configKey, defaultConfigValue)) - } else { - None - } - (Seq.empty, Option(name), props.toSeq, recordHandler) - - case null if conf.getConf(CATALOG_IMPLEMENTATION).equals("hive") => - // Use default (serde) format. - val name = conf.getConfString("hive.script.serde", - "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") - val props = Seq("field.delim" -> "\t") - val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) - (Nil, Option(name), props, recordHandler) + throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) // SPARK-32106: When there is no definition about format, we return empty result // to use a built-in default Serde in SparkScriptTransformationExec. @@ -814,14 +784,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging (Nil, None, Seq.empty, None) } - val (inFormat, inSerdeClass, inSerdeProps, reader) = - format( - inRowFormat, "hive.script.recordreader", "org.apache.hadoop.hive.ql.exec.TextRecordReader") + val (inFormat, inSerdeClass, inSerdeProps, reader) = format(inRowFormat) - val (outFormat, outSerdeClass, outSerdeProps, writer) = - format( - outRowFormat, "hive.script.recordwriter", - "org.apache.hadoop.hive.ql.exec.TextRecordWriter") + val (outFormat, outSerdeClass, outSerdeProps, writer) = format(outRowFormat) ScriptInputOutputSchema( inFormat, outFormat, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index a813b04700543..c55b90b39c872 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType /** @@ -663,6 +664,96 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + /** + * Create a hive serde [[ScriptInputOutputSchema]]. 
+ */ + override protected def withScriptIOSchema( + ctx: ParserRuleContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { + if (recordWriter != null || recordReader != null) { + // TODO: what does this message mean? + throw new ParseException( + "Unsupported operation: Used defined record reader/writer classes.", ctx) + } + + if (!conf.getConf(CATALOG_IMPLEMENTATION).equals("hive")) { + super.withScriptIOSchema( + ctx, + inRowFormat, + recordWriter, + outRowFormat, + recordReader, + schemaLess) + } else { + + // Decode and input/output format. + type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + + def format( + fmt: RowFormatContext, + configKey: String, + defaultConfigValue: String): Format = fmt match { + case c: RowFormatDelimitedContext => + // TODO we should use visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).map(t => key -> t.getText).toSeq + } + + val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs) + + (entries, None, Seq.empty, None) + + case c: RowFormatSerdeContext => + // Use a serde format. + val CatalogStorageFormat(None, None, None, Some(name), _, props) = visitRowFormatSerde(c) + + // SPARK-10310: Special cases LazySimpleSerDe + val recordHandler = if (name == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { + Option(conf.getConfString(configKey, defaultConfigValue)) + } else { + None + } + (Seq.empty, Option(name), props.toSeq, recordHandler) + + case null => + // Use default (serde) format. + val name = conf.getConfString("hive.script.serde", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + val props = Seq("field.delim" -> "\t") + val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) + (Nil, Option(name), props, recordHandler) + } + + val (inFormat, inSerdeClass, inSerdeProps, reader) = + format( + inRowFormat, "hive.script.recordreader", + "org.apache.hadoop.hive.ql.exec.TextRecordReader") + + val (outFormat, outSerdeClass, outSerdeProps, writer) = + format( + outRowFormat, "hive.script.recordwriter", + "org.apache.hadoop.hive.ql.exec.TextRecordWriter") + + ScriptInputOutputSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) + } + } + /** * Create a clause for DISTRIBUTE BY. 
*/ From a769aa7b86e7c46b90bb3bb5a8b723e4dd8b85f7 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 22:37:43 +0800 Subject: [PATCH 38/42] address comment --- .../sql/catalyst/parser/AstBuilder.scala | 39 ++++---- .../sql/catalyst/parser/PlanParserSuite.scala | 94 ++++++++++++++++++- .../BaseScriptTransformationExec.scala | 11 +-- .../spark/sql/execution/SparkSqlParser.scala | 20 +--- .../resources/sql-tests/inputs/transform.sql | 2 +- .../BaseScriptTransformationSuite.scala | 1 - .../SparkScriptTransformationSuite.scala | 14 +-- .../HiveScriptTransformationSuite.scala | 18 ++-- 8 files changed, 138 insertions(+), 61 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e57a0b666c321..3feea66e62a62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -744,6 +744,27 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging selectClause.hints.asScala.foldRight(withWindow)(withHints) } + // Decode and input/output format. + type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + + protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): Format = { + // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).map(t => key -> t.getText).toSeq + } + + val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATLINES", ctx.linesSeparatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs) + + (entries, None, Seq.empty, None) + } + /** * Create a [[ScriptInputOutputSchema]]. */ @@ -754,26 +775,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging outRowFormat: RowFormatContext, recordReader: Token, schemaLess: Boolean): ScriptInputOutputSchema = { - // Decode and input/output format. - type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) def format(fmt: RowFormatContext): Format = fmt match { case c: RowFormatDelimitedContext => - // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema - // expects a seq of pairs in which the old parsers' token names are used as keys. - // Transforming the result of visitRowFormatDelimited would be quite a bit messier than - // retrieving the key value pairs ourselves. 
- def entry(key: String, value: Token): Seq[(String, String)] = { - Option(value).map(t => key -> t.getText).toSeq - } - - val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++ - entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++ - entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++ - entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++ - entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs) - - (entries, None, Seq.empty, None) + getRowFormatDelimited(c) case c: RowFormatSerdeContext => throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 88afcb10d9c20..8d44a33d882a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, LongType, StringType} /** * Parser test cases for rules defined in [[CatalystSqlParser]] / [[AstBuilder]]. @@ -1031,4 +1031,96 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select a, b from db.c;;;", table("db", "c").select('a, 'b)) assertEqual("select a, b from db.c; ;; ;", table("db", "c").select('a, 'b)) } + + test("SPARK-32106: TRANSFORM without serde") { + // verify schema less + assertEqual( + """ + |SELECT TRANSFORM(a, b, c) + |USING 'cat' + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq('a, 'b, 'c), + "cat", + Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), + UnresolvedRelation(TableIdentifier("testData")), + ScriptInputOutputSchema(List.empty, List.empty, None, None, + List.empty, List.empty, None, None, true)) + ) + + // verify without output schema + assertEqual( + """ + |SELECT TRANSFORM(a, b, c) + |USING 'cat' AS (a, b, c) + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq('a, 'b, 'c), + "cat", + Seq(AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", StringType)()), + UnresolvedRelation(TableIdentifier("testData")), + ScriptInputOutputSchema(List.empty, List.empty, None, None, + List.empty, List.empty, None, None, false))) + + // verify with output schema + assertEqual( + """ + |SELECT TRANSFORM(a, b, c) + |USING 'cat' AS (a int, b string, c long) + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq('a, 'b, 'c), + "cat", + Seq(AttributeReference("a", IntegerType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", LongType)()), + UnresolvedRelation(TableIdentifier("testData")), + ScriptInputOutputSchema(List.empty, List.empty, None, None, + List.empty, List.empty, None, None, false))) + + // verify with ROW FORMAT DELIMETED + assertEqual( + """ + |SELECT TRANSFORM(a, b, c) + |ROW FORMAT DELIMITED + |FIELDS TERMINATED BY '\t' + |COLLECTION ITEMS TERMINATED BY '\u0002' + |MAP KEYS TERMINATED BY '\u0003' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'null' + |USING 'cat' AS (a, b, c) + |ROW FORMAT DELIMITED + |FIELDS TERMINATED BY '\t' + |COLLECTION ITEMS TERMINATED BY '\u0004' + |MAP 
KEYS TERMINATED BY '\u0005' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'NULL' + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq('a, 'b, 'c), + "cat", + Seq(AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", StringType)()), + UnresolvedRelation(TableIdentifier("testData")), + ScriptInputOutputSchema( + Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"), + ("TOK_TABLEROWFORMATCOLLITEMS", "'\u0002'"), + ("TOK_TABLEROWFORMATMAPKEYS", "'\u0003'"), + ("TOK_TABLEROWFORMATLINES", "'\\n'"), + ("TOK_TABLEROWFORMATNULL", "'null'")), + Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"), + ("TOK_TABLEROWFORMATCOLLITEMS", "'\u0004'"), + ("TOK_TABLEROWFORMATMAPKEYS", "'\u0005'"), + ("TOK_TABLEROWFORMATLINES", "'\\n'"), + ("TOK_TABLEROWFORMATNULL", "'NULL'")), None, None, + List.empty, List.empty, None, None, false))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 7eab3fcd8c201..2fc7deb0858af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -108,7 +108,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { prevLine: String => new GenericInternalRow( prevLine.split(outputRowFormat) - .zip(fieldWriters) + .zip(outputFieldWriters) .map { case (data, writer) => writer(data) }) } else { // In schema less mode, hive default serde will choose first two output column as output @@ -182,7 +182,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } - private lazy val fieldWriters: Seq[String => Any] = output.map { attr => + private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr => val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) attr.dataType match { case StringType => wrapperConvertException(data => data, converter) @@ -218,10 +218,9 @@ trait BaseScriptTransformationExec extends UnaryExecNode { converter) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) - case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => - throw new SparkException("TRANSFORM without serde don't support" + - " ArrayType/MapType/StructType as output data type") - case _ => wrapperConvertException(data => data, converter) + case dt => + throw new SparkException("TRANSFORM without serde does not support " + + s"${dt.getClass.getSimpleName} as output data type") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index c55b90b39c872..62babdc6fa094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -689,30 +689,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { recordReader, schemaLess) } else { - - // Decode and input/output format. - type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) - def format( fmt: RowFormatContext, configKey: String, defaultConfigValue: String): Format = fmt match { case c: RowFormatDelimitedContext => - // TODO we should use visitRowFormatDelimited function here. 
However HiveScriptIOSchema - // expects a seq of pairs in which the old parsers' token names are used as keys. - // Transforming the result of visitRowFormatDelimited would be quite a bit messier than - // retrieving the key value pairs ourselves. - def entry(key: String, value: Token): Seq[(String, String)] = { - Option(value).map(t => key -> t.getText).toSeq - } - - val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++ - entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++ - entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++ - entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++ - entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs) - - (entries, None, Seq.empty, None) + getRowFormatDelimited(c) case c: RowFormatSerdeContext => // Use a serde format. diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index d7ad3d4d73a24..586df6c0cba59 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -57,7 +57,7 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( FROM t ) tmp; --- handle schema less +-- SPARK-32388 handle schema less SELECT TRANSFORM(a) USING 'cat' FROM t; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index c1f723d324d7d..d8fd5ea0f5bd1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -268,7 +268,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } - test("SPARK-32106: TRANSFORM should respect DATETIME_JAVA8API_ENABLED (no serde)") { assume(TestUtils.testCommandAvailable("python")) Array(false, true).foreach { java8AapiEnable => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index 0826ecceec931..590c42e4b0d60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -62,7 +62,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with } } - test("TRANSFORM don't support ArrayType/MapType/StructType as output data type (no serde)") { + test("TRANSFORM doesn't support ArrayType/MapType/StructType as output data type (no serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) // check for ArrayType val e1 = intercept[SparkException] { @@ -73,8 +73,8 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) """.stripMargin).collect() }.getMessage - assert(e1.contains("TRANSFORM without serde don't support" + - " ArrayType/MapType/StructType as output data type")) + assert(e1.contains("TRANSFORM without serde does not support" + + " ArrayType as output data type")) // check for MapType val e2 = intercept[SparkException] { @@ -85,8 +85,8 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) """.stripMargin).collect() 
}.getMessage - assert(e2.contains("TRANSFORM without serde don't support" + - " ArrayType/MapType/StructType as output data type")) + assert(e2.contains("TRANSFORM without serde does not support" + + " MapType as output data type")) // check for StructType val e3 = intercept[SparkException] { @@ -97,7 +97,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) """.stripMargin).collect() }.getMessage - assert(e3.contains("TRANSFORM without serde don't support" + - " ArrayType/MapType/StructType as output data type")) + assert(e3.contains("TRANSFORM without serde does not support" + + " StructType as output data type")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 0dddfa39d6378..4c53f17e33b3b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -53,7 +53,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T ) } - private val serdeIOSchema: ScriptTransformationIOSchema = { + private val hiveIOSchema: ScriptTransformationIOSchema = { defaultIOSchema.copy( inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) @@ -71,7 +71,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T script = "cat", output = Seq(AttributeReference("a", StringType)()), child = child, - ioschema = serdeIOSchema + ioschema = hiveIOSchema ), rowsDf.collect()) assert(uncaughtExceptionHandler.exception.isEmpty) @@ -89,7 +89,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T script = "cat", output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), - ioschema = serdeIOSchema + ioschema = hiveIOSchema ), rowsDf.collect()) } @@ -110,7 +110,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), child = rowsDf.queryExecution.sparkPlan, - ioschema = serdeIOSchema) + ioschema = hiveIOSchema) SparkPlanTest.executePlan(plan, hiveContext) } assert(e.getMessage.contains("Subprocess exited with status")) @@ -131,7 +131,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T script = "cat", output = Seq(AttributeReference("name", StringType)()), child = child, - ioschema = serdeIOSchema + ioschema = hiveIOSchema ), rowsDf.select("name").collect()) assert(uncaughtExceptionHandler.exception.isEmpty) @@ -148,7 +148,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), child = rowsDf.queryExecution.sparkPlan, - ioschema = serdeIOSchema) + ioschema = hiveIOSchema) SparkPlanTest.executePlan(plan, hiveContext) } assert(e.getMessage.contains("Subprocess exited with status")) @@ -212,7 +212,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T StructField("col1", IntegerType, false), StructField("col2", StringType, true))))()), child = child, - ioschema = serdeIOSchema + ioschema = hiveIOSchema ), df.select('c, 
'd, 'e).collect()) } @@ -256,7 +256,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T AttributeReference("a", IntegerType)(), AttributeReference("b", CalendarIntervalType)()), child = df.queryExecution.sparkPlan, - ioschema = serdeIOSchema) + ioschema = hiveIOSchema) SparkPlanTest.executePlan(plan, hiveContext) } assert(e1.getMessage.contains("scala.MatchError: CalendarIntervalType")) @@ -269,7 +269,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T AttributeReference("a", IntegerType)(), AttributeReference("c", new TestUDT.MyDenseVectorUDT)()), child = df.queryExecution.sparkPlan, - ioschema = serdeIOSchema) + ioschema = hiveIOSchema) SparkPlanTest.executePlan(plan, hiveContext) } assert(e2.getMessage.contains( From d93f7faf438239db5e1e38476a41d7408170949d Mon Sep 17 00:00:00 2001 From: angerszhu Date: Wed, 22 Jul 2020 23:40:46 +0800 Subject: [PATCH 39/42] add UT of row format and fix UT --- .../BaseScriptTransformationExec.scala | 3 +- .../resources/sql-tests/inputs/transform.sql | 57 +++++++++ .../sql-tests/results/transform.sql.out | 40 ++++++- .../BaseScriptTransformationSuite.scala | 110 +++++++++++------- .../spark/sql/hive/HiveInspectors.scala | 3 + .../HiveScriptTransformationSuite.scala | 24 ++-- 6 files changed, 180 insertions(+), 57 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 2fc7deb0858af..12b6934f58c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -188,7 +188,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { case StringType => wrapperConvertException(data => data, converter) case BooleanType => wrapperConvertException(data => data.toBoolean, converter) case ByteType => wrapperConvertException(data => data.toByte, converter) - case BinaryType => wrapperConvertException(data => data.getBytes, converter) + case BinaryType => + wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) case IntegerType => wrapperConvertException(data => data.toInt, converter) case ShortType => wrapperConvertException(data => data.toShort, converter) case LongType => wrapperConvertException(data => data.toLong, converter) diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 586df6c0cba59..222be1b836ef3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -87,3 +87,60 @@ MAP a, b USING 'cat' AS (a, b) FROM t; -- transform use REDUCE REDUCE a, b USING 'cat' AS (a, b) FROM t; + +-- transform with defined row format delimit +SELECT TRANSFORM(a, b, c, null) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +USING 'cat' AS (a, b, c, d) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +FROM t; + + +SELECT TRANSFORM(a, b, c, null) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +USING 'cat' AS (d) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '||' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +FROM t; + +-- SPARK-31937 transform with defined row format delimit
+--SELECT TRANSFORM(a, b, c, d, e, null) +--ROW FORMAT DELIMITED +--FIELDS TERMINATED BY '|' +--COLLECTION ITEMS TERMINATED BY '&' +--MAP KEYS TERMINATED BY '*' +--LINES TERMINATED BY '\n' +--NULL DEFINED AS 'NULL' +--USING 'cat' AS (a, b, c, d, e, f) +--ROW FORMAT DELIMITED +--FIELDS TERMINATED BY '|' +--COLLECTION ITEMS TERMINATED BY '&' +--MAP KEYS TERMINATED BY '*' +--LINES TERMINATED BY '\n' +--NULL DEFINED AS 'NULL' +--FROM VALUEW (1, 1.23, array(1,, 2, 3), map(1, '1'), struct(1, '1')) t(a, b, c, d, e); +-- +--SELECT TRANSFORM(a, b, c, d, e, null) +--ROW FORMAT DELIMITED +--FIELDS TERMINATED BY '|' +--COLLECTION ITEMS TERMINATED BY '&' +--MAP KEYS TERMINATED BY '*' +--LINES TERMINATED BY '\n' +--NULL DEFINED AS 'NULL' +--USING 'cat' AS (a) +--ROW FORMAT DELIMITED +--FIELDS TERMINATED BY '||' +--LINES TERMINATED BY '\n' +--NULL DEFINED AS 'NULL' +--FROM VALUEW (1, 1.23, array(1,, 2, 3), map(1, '1'), struct(1, '1')) t(a, b, c, d, e); diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 37214b90b86fd..744d6384f9c45 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 15 -- !query @@ -184,3 +184,41 @@ struct 1 true 2 false 3 true + + +-- !query +SELECT TRANSFORM(a, b, c, null) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +USING 'cat' AS (a, b, c, d) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +FROM t +-- !query schema +struct +-- !query output +1 | true | +2 | false | + + +-- !query +SELECT TRANSFORM(a, b, c, null) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +USING 'cat' AS (d) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '||' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +FROM t +-- !query schema +struct +-- !query output +1 +2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index d8fd5ea0f5bd1..101c9a5c899db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -197,23 +197,68 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-32106: TRANSFORM should support all data types as input (no serde)") { + def testBasicInputDataTypesWith(serde: ScriptTransformationIOSchema, testName: String): Unit = { + test(s"SPARK-32106: TRANSFORM should support basic data types as input ($testName)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + (1, "1", 1.0f, 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), + new Date(2020, 7, 1), true), + (2, "2", 2.0f, 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), + new Date(2020, 7, 2), true), + (3, "3", 3.0f, 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), + new Date(2020, 7, 3), false) + ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i") + .withColumn("j", lit("abc").cast("binary")) + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + 
input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + df.col("f").expr, + df.col("g").expr, + df.col("h").expr, + df.col("i").expr, + df.col("j").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", FloatType)(), + AttributeReference("d", DoubleType)(), + AttributeReference("e", ByteType)(), + AttributeReference("f", DecimalType(38, 18))(), + AttributeReference("g", TimestampType)(), + AttributeReference("h", DateType)(), + AttributeReference("i", BooleanType)(), + AttributeReference("j", BinaryType)()), + child = child, + ioschema = serde + ), + df.select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j).collect()) + } + } + } + + testBasicInputDataTypesWith(defaultIOSchema, "no serde") + + test("SPARK-32106: TRANSFORM should support more data types (interval, array, map, struct " + + "and udt) as input (no serde)") { assume(TestUtils.testCommandAvailable("python")) withTempView("v") { val df = Seq( - (1, "1", 1.0f, 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), - new Date(2020, 7, 1), new CalendarInterval(7, 1, 1000), Array(0, 1, 2), - Map("a" -> 1), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), - (2, "2", 2.0f, 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), - new Date(2020, 7, 2), new CalendarInterval(7, 2, 2000), Array(3, 4, 5), - Map("b" -> 2), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)), - (3, "3", 3.0f, 3.0, 33.toByte, BigDecimal(3.0), new Timestamp(3), - new Date(2020, 7, 3), new CalendarInterval(7, 3, 3000), Array(6, 7, 8), - Map("c" -> 3), new TestUDT.MyDenseVector(Array(1, 2, 3)), new SimpleTuple(1, 1L)) - ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m") - .select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, 'm, - struct('a, 'b).as("n"), unhex('a).as("o"), lit(true).as("p") - ) // Note column d's data type is Decimal(38, 18) + (new CalendarInterval(7, 1, 1000), Array(0, 1, 2), Map("a" -> 1), (1, 2), + new SimpleTuple(1, 1L)), + (new CalendarInterval(7, 2, 2000), Array(3, 4, 5), Map("b" -> 2), (3, 4), + new SimpleTuple(1, 1L)), + (new CalendarInterval(7, 3, 3000), Array(6, 7, 8), Map("c" -> 3), (5, 6), + new SimpleTuple(1, 1L)) + ).toDF("a", "b", "c", "d", "e") // Can't support convert script output data to ArrayType/MapType/StructType now, // return these column still as string. 
@@ -228,43 +273,18 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU df.col("b").expr, df.col("c").expr, df.col("d").expr, - df.col("e").expr, - df.col("f").expr, - df.col("g").expr, - df.col("h").expr, - df.col("i").expr, - df.col("j").expr, - df.col("k").expr, - df.col("l").expr, - df.col("m").expr, - df.col("n").expr, - df.col("o").expr, - df.col("p").expr), + df.col("e").expr), script = "cat", output = Seq( - AttributeReference("a", IntegerType)(), + AttributeReference("a", CalendarIntervalType)(), AttributeReference("b", StringType)(), - AttributeReference("c", FloatType)(), - AttributeReference("d", DoubleType)(), - AttributeReference("e", ByteType)(), - AttributeReference("f", DecimalType(38, 18))(), - AttributeReference("g", TimestampType)(), - AttributeReference("h", DateType)(), - AttributeReference("i", CalendarIntervalType)(), - AttributeReference("j", StringType)(), - AttributeReference("k", StringType)(), - AttributeReference("l", StringType)(), - AttributeReference("m", new SimpleTupleUDT)(), - AttributeReference("n", StringType)(), - AttributeReference("o", BinaryType)(), - AttributeReference("p", BooleanType)()), + AttributeReference("c", StringType)(), + AttributeReference("d", StringType)(), + AttributeReference("e", new SimpleTupleUDT)()), child = child, ioschema = defaultIOSchema ), - df.select( - 'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, - 'j.cast("string"), 'k.cast("string"), - 'l.cast("string"), 'm, 'n.cast("string"), 'o, 'p).collect()) + df.select('a, 'b.cast("string"), 'c.cast("string"), 'd.cast("string"), 'e).collect()) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 16e9014340244..060ab6a71af9e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -1063,6 +1063,9 @@ private[hive] trait HiveInspectors { case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo + case dt => + throw new AnalysisException("TRANSFORM with hive serde does not support " + + s"${dt.getClass.getSimpleName.replace("$", "")} as input data type") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 4c53f17e33b3b..078ff70cb3150 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -185,6 +185,8 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T } } + testBasicInputDataTypesWith(hiveIOSchema, "hive serde") + test("SPARK-32106: TRANSFORM supports complex data types type (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) withTempView("v") { @@ -258,8 +260,9 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T child = df.queryExecution.sparkPlan, ioschema = hiveIOSchema) SparkPlanTest.executePlan(plan, hiveContext) - } - assert(e1.getMessage.contains("scala.MatchError: CalendarIntervalType")) + }.getMessage + assert(e1.contains( + "TRANSFORM with hive serde does not support CalendarIntervalType as input data type")) val e2 = intercept[SparkException] { val plan = createScriptTransformationExec( @@ 
-271,9 +274,9 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T child = df.queryExecution.sparkPlan, ioschema = hiveIOSchema) SparkPlanTest.executePlan(plan, hiveContext) - } - assert(e2.getMessage.contains( - "scala.MatchError: org.apache.spark.sql.types.TestUDT$MyDenseVectorUDT")) + }.getMessage + assert(e2.contains( + "TRANSFORM with hive serde does not support MyDenseVectorUDT as input data type")) } } @@ -293,8 +296,9 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T |SELECT TRANSFORM(a, b) USING 'cat' AS (a, b) |FROM v """.stripMargin).collect() - } - assert(e1.getMessage.contains("scala.MatchError: CalendarIntervalType")) + }.getMessage + assert(e1.contains( + "TRANSFORM with hive serde does not support CalendarIntervalType as input data type")) val e2 = intercept[SparkException] { sql( @@ -302,9 +306,9 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T |SELECT TRANSFORM(a, c) USING 'cat' AS (a, c) |FROM v """.stripMargin).collect() - } - assert(e2.getMessage.contains( - "scala.MatchError: org.apache.spark.sql.types.TestUDT$MyDenseVectorUDT")) + }.getMessage + assert(e2.contains( + "TRANSFORM with hive serde does not support MyDenseVectorUDT as input data type")) } } } From be80c27557f614074ac50e2674180a778c0208ab Mon Sep 17 00:00:00 2001 From: angerszhu Date: Thu, 23 Jul 2020 10:17:28 +0800 Subject: [PATCH 40/42] address comment --- .../sql/catalyst/parser/AstBuilder.scala | 9 +- .../sql/catalyst/parser/PlanParserSuite.scala | 21 ++- .../spark/sql/execution/SparkSqlParser.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../resources/sql-tests/inputs/transform.sql | 32 ---- .../SparkScriptTransformationSuite.scala | 3 +- .../spark/sql/hive/HiveInspectors.scala | 4 +- .../HiveScriptTransformationExec.scala | 178 +++++++++--------- .../HiveScriptTransformationSuite.scala | 10 +- 9 files changed, 129 insertions(+), 133 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3feea66e62a62..891e9860c73cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -744,10 +744,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging selectClause.hints.asScala.foldRight(withWindow)(withHints) } - // Decode and input/output format. - type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + // Script Transform's input/output format. + type ScriptIOFormat = + (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) - protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): Format = { + protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): ScriptIOFormat = { // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema // expects a seq of pairs in which the old parsers' token names are used as keys. 
// Transforming the result of visitRowFormatDelimited would be quite a bit messier than @@ -776,7 +777,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging recordReader: Token, schemaLess: Boolean): ScriptInputOutputSchema = { - def format(fmt: RowFormatContext): Format = fmt match { + def format(fmt: RowFormatContext): ScriptIOFormat = fmt match { case c: RowFormatDelimitedContext => getRowFormatDelimited(c) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 8d44a33d882a6..db665446631a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -1032,7 +1032,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select a, b from db.c; ;; ;", table("db", "c").select('a, 'b)) } - test("SPARK-32106: TRANSFORM without serde") { + test("SPARK-32106: TRANSFORM plan") { // verify schema less assertEqual( """ @@ -1122,5 +1122,24 @@ class PlanParserSuite extends AnalysisTest { ("TOK_TABLEROWFORMATLINES", "'\\n'"), ("TOK_TABLEROWFORMATNULL", "'NULL'")), None, None, List.empty, List.empty, None, None, false))) + + // verify ROW FORMAT SERDE + intercept( + """ + |SELECT TRANSFORM(a, b, c) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + |WITH SERDEPROPERTIES( + | "separatorChar" = "\t", + | "quoteChar" = "'", + | "escapeChar" = "\\") + |USING 'cat' AS (a, b, c) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + |WITH SERDEPROPERTIES( + | "separatorChar" = "\t", + | "quoteChar" = "'", + | "escapeChar" = "\\") + |FROM testData + """.stripMargin, + "TRANSFORM with serde is only supported in hive mode") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 62babdc6fa094..7ef46c949db6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -692,7 +692,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { def format( fmt: RowFormatContext, configKey: String, - defaultConfigValue: String): Format = fmt match { + defaultConfigValue: String): ScriptIOFormat = fmt match { case c: RowFormatDelimitedContext => getRowFormatDelimited(c) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 21ddea51df4a6..1e0d0c346731a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -534,8 +534,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object SparkScripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.ScriptTransformation(input, script, output, child, ioschema) - if ioschema.inputSerdeClass.isEmpty && ioschema.outputSerdeClass.isEmpty => + case logical.ScriptTransformation(input, script, output, child, ioschema) => SparkScriptTransformationExec( input, script, diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index 
222be1b836ef3..8610e384d6fab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -101,7 +101,6 @@ LINES TERMINATED BY '\n' NULL DEFINED AS 'NULL' FROM t; - SELECT TRANSFORM(a, b, c, null) ROW FORMAT DELIMITED FIELDS TERMINATED BY '|' @@ -113,34 +112,3 @@ FIELDS TERMINATED BY '||' LINES TERMINATED BY '\n' NULL DEFINED AS 'NULL' FROM t; - --- SPARK-31937 transform with defined row format delimit ---SELECT TRANSFORM(a, b, c, d, e, null) ---ROW FORMAT DELIMITED ---FIELDS TERMINATED BY '|' ---COLLECTION ITEMS TERMINATED BY '&' ---MAP KEYS TERMINATED BY '*' ---LINES TERMINATED BY '\n' ---NULL DEFINED AS 'NULL' ---USING 'cat' AS (a, b, c, d, e, f) ---ROW FORMAT DELIMITED ---FIELDS TERMINATED BY '|' ---COLLECTION ITEMS TERMINATED BY '&' ---MAP KEYS TERMINATED BY '*' ---LINES TERMINATED BY '\n' ---NULL DEFINED AS 'NULL' ---FROM VALUEW (1, 1.23, array(1,, 2, 3), map(1, '1'), struct(1, '1')) t(a, b, c, d, e); --- ---SELECT TRANSFORM(a, b, c, d, e, null) ---ROW FORMAT DELIMITED ---FIELDS TERMINATED BY '|' ---COLLECTION ITEMS TERMINATED BY '&' ---MAP KEYS TERMINATED BY '*' ---LINES TERMINATED BY '\n' ---NULL DEFINED AS 'NULL' ---USING 'cat' AS (a) ---ROW FORMAT DELIMITED ---FIELDS TERMINATED BY '||' ---LINES TERMINATED BY '\n' ---NULL DEFINED AS 'NULL' ---FROM VALUEW (1, 1.23, array(1,, 2, 3), map(1, '1'), struct(1, '1')) t(a, b, c, d, e); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index 590c42e4b0d60..d85aa6cbe3a17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -62,7 +62,8 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with } } - test("TRANSFORM doesn't support ArrayType/MapType/StructType as output data type (no serde)") { + test("SPARK-32106: TRANSFORM doesn't support ArrayType/MapType/StructType " + + "as output data type (no serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) // check for ArrayType val e1 = intercept[SparkException] { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 060ab6a71af9e..c09e0ce095e3b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -1064,8 +1064,8 @@ private[hive] trait HiveInspectors { case TimestampType => timestampTypeInfo case NullType => voidTypeInfo case dt => - throw new AnalysisException("TRANSFORM with hive serde does not support " + - s"${dt.getClass.getSimpleName.replace("$", "")} as input data type") + throw new AnalysisException("HiveInspectors does not support convert " + + s"${dt.getClass.getSimpleName.replace("$", "")} to Hive TypeInfo") } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 69b5b493394be..535eae5e47adb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -37,7 
+37,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.{CircularBuffer, Utils} /** @@ -53,84 +53,8 @@ case class HiveScriptTransformationExec( output: Seq[Attribute], child: SparkPlan, ioschema: ScriptTransformationIOSchema) - extends BaseScriptTransformationExec with HiveInspectors { - - private def initInputSerDe( - input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { - ioschema.inputSerdeClass.map { serdeClass => - val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.inputSerdeProps) - val fieldObjectInspectors = columnTypes.map(toInspector) - val objectInspector = ObjectInspectorFactory - .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) - (serde, objectInspector) - } - } - - private def initOutputSerDe( - output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { - ioschema.outputSerdeClass.map { serdeClass => - val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.outputSerdeProps) - val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] - (serde, structObjectInspector) - } - } - - private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") - val columnTypes = attrs.map(_.dataType) - (columns, columnTypes) - } - - private def initSerDe( - serdeClassName: String, - columns: Seq[String], - columnTypes: Seq[DataType], - serdeProps: Seq[(String, String)]): AbstractSerDe = { - - val serde = Utils.classForName[AbstractSerDe](serdeClassName).getConstructor(). - newInstance() - - val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - - var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) - propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - - val properties = new Properties() - // Can not use properties.putAll(propsMap.asJava) in scala-2.12 - // See https://github.com/scala/bug/issues/10418 - propsMap.foreach { case (k, v) => properties.put(k, v) } - serde.initialize(null, properties) - - serde - } - - private def recordReader( - inputStream: InputStream, - conf: Configuration): Option[RecordReader] = { - ioschema.recordReaderClass.map { klass => - val instance = Utils.classForName[RecordReader](klass).getConstructor(). - newInstance() - val props = new Properties() - // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 - // See https://github.com/scala/bug/issues/10418 - ioschema.outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } - instance.initialize(inputStream, conf, props) - instance - } - } - - private def recordWriter( - outputStream: OutputStream, - conf: Configuration): Option[RecordWriter] = { - ioschema.recordWriterClass.map { klass => - val instance = Utils.classForName[RecordWriter](klass).getConstructor(). 
- newInstance() - instance.initialize(outputStream, conf) - instance - } - } + extends BaseScriptTransformationExec { + import HiveScriptIOSchema._ private def createOutputIteratorWithSerde( writerThread: BaseScriptTransformationWriterThread, @@ -144,7 +68,8 @@ case class HiveScriptTransformationExec( var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) - @Nullable val scriptOutputReader = recordReader(scriptOutputStream, hadoopConf).orNull + @Nullable val scriptOutputReader = + recordReader(ioschema, scriptOutputStream, hadoopConf).orNull var scriptOutputWritable: Writable = null val reusedWritableObject = outputSerde.getSerializedClass.getConstructor().newInstance() @@ -218,7 +143,7 @@ case class HiveScriptTransformationExec( // This nullability is a performance optimization in order to avoid an Option.foreach() call // inside of a loop - @Nullable val (inputSerde, inputSoi) = initInputSerDe(input).getOrElse((null, null)) + @Nullable val (inputSerde, inputSoi) = initInputSerDe(ioschema, input).getOrElse((null, null)) // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null // We will use StringBuffer to pass data, in this case, we should cast data as string too. @@ -239,7 +164,6 @@ case class HiveScriptTransformationExec( inputSoi, ioschema, outputStream, - recordWriter, proc, stderrBuffer, TaskContext.get(), @@ -249,7 +173,7 @@ case class HiveScriptTransformationExec( // This nullability is a performance optimization in order to avoid an Option.foreach() call // inside of a loop @Nullable val (outputSerde, outputSoi) = { - initOutputSerDe(output).getOrElse((null, null)) + initOutputSerDe(ioschema, output).getOrElse((null, null)) } val outputIterator = if (outputSerde == null) { @@ -272,16 +196,16 @@ case class HiveScriptTransformationWriterThread( @Nullable inputSoi: StructObjectInspector, ioSchema: ScriptTransformationIOSchema, outputStream: OutputStream, - recordWriter: (OutputStream, Configuration) => Option[RecordWriter], proc: Process, stderrBuffer: CircularBuffer, taskContext: TaskContext, conf: Configuration) extends BaseScriptTransformationWriterThread with HiveInspectors { + import HiveScriptIOSchema._ override def processRows(): Unit = { val dataOutputStream = new DataOutputStream(outputStream) - @Nullable val scriptInputWriter = recordWriter(dataOutputStream, conf).orNull + @Nullable val scriptInputWriter = recordWriter(ioSchema, dataOutputStream, conf).orNull if (inputSerde == null) { processRowsWithoutSerde() @@ -308,3 +232,87 @@ case class HiveScriptTransformationWriterThread( } } } + +object HiveScriptIOSchema extends HiveInspectors { + + def initInputSerDe( + ioschema: ScriptTransformationIOSchema, + input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { + ioschema.inputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(input) + val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.inputSerdeProps) + val fieldObjectInspectors = columnTypes.map(toInspector) + val objectInspector = ObjectInspectorFactory + .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) + (serde, objectInspector) + } + } + + def initOutputSerDe( + ioschema: ScriptTransformationIOSchema, + output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + ioschema.outputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.outputSerdeProps) + val 
structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] + (serde, structObjectInspector) + } + } + + private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { + val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}") + val columnTypes = attrs.map(_.dataType) + (columns, columnTypes) + } + + def initSerDe( + serdeClassName: String, + columns: Seq[String], + columnTypes: Seq[DataType], + serdeProps: Seq[(String, String)]): AbstractSerDe = { + + val serde = Utils.classForName[AbstractSerDe](serdeClassName).getConstructor(). + newInstance() + + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") + + var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + + val properties = new Properties() + // Can not use properties.putAll(propsMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + propsMap.foreach { case (k, v) => properties.put(k, v) } + serde.initialize(null, properties) + + serde + } + + def recordReader( + ioschema: ScriptTransformationIOSchema, + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + ioschema.recordReaderClass.map { klass => + val instance = Utils.classForName[RecordReader](klass).getConstructor(). + newInstance() + val props = new Properties() + // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 + // See https://github.com/scala/bug/issues/10418 + ioschema.outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter( + ioschema: ScriptTransformationIOSchema, + outputStream: OutputStream, + conf: Configuration): Option[RecordWriter] = { + ioschema.recordWriterClass.map { klass => + val instance = Utils.classForName[RecordWriter](klass).getConstructor(). 
+ newInstance() + instance.initialize(outputStream, conf) + instance + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 078ff70cb3150..5f08398a99450 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -220,7 +220,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T } } - test("SPARK-32106: TRANSFORM supports complex data types end to end (hive serde) ") { + test("SPARK-32106: TRANSFORM supports complex data types end to end (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) withTempView("v") { val df = Seq( @@ -262,7 +262,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T SparkPlanTest.executePlan(plan, hiveContext) }.getMessage assert(e1.contains( - "TRANSFORM with hive serde does not support CalendarIntervalType as input data type")) + "HiveInspectors does not support convert CalendarIntervalType to Hive TypeInfo")) val e2 = intercept[SparkException] { val plan = createScriptTransformationExec( @@ -276,7 +276,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T SparkPlanTest.executePlan(plan, hiveContext) }.getMessage assert(e2.contains( - "TRANSFORM with hive serde does not support MyDenseVectorUDT as input data type")) + "HiveInspectors does not support convert MyDenseVectorUDT to Hive TypeInfo")) } } @@ -298,7 +298,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin).collect() }.getMessage assert(e1.contains( - "TRANSFORM with hive serde does not support CalendarIntervalType as input data type")) + "HiveInspectors does not support convert CalendarIntervalType to Hive TypeInfo")) val e2 = intercept[SparkException] { sql( @@ -308,7 +308,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin).collect() }.getMessage assert(e2.contains( - "TRANSFORM with hive serde does not support MyDenseVectorUDT as input data type")) + "HiveInspectors does not support convert MyDenseVectorUDT to Hive TypeInfo")) } } } From 7f3cff81d6238d8de6ea2c57dc94eed5abe1ff75 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Thu, 23 Jul 2020 10:22:08 +0800 Subject: [PATCH 41/42] Update PlanParserSuite.scala --- .../org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index db665446631a3..e4790e2dfa634 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -1123,7 +1123,7 @@ class PlanParserSuite extends AnalysisTest { ("TOK_TABLEROWFORMATNULL", "'NULL'")), None, None, List.empty, List.empty, None, None, false))) - // verify ROW FORMAT SERDE + // verify with ROW FORMAT SERDE intercept( """ |SELECT TRANSFORM(a, b, c) From 03d3409f6ca641a4f10fdc2ac71479445220f676 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Thu, 23 Jul 2020 11:08:27 +0800 Subject: [PATCH 42/42] address comment --- 
.../sql/execution/BaseScriptTransformationExec.scala | 2 +- .../sql/execution/SparkScriptTransformationSuite.scala | 6 +++--- .../scala/org/apache/spark/sql/hive/HiveInspectors.scala | 4 ++-- .../hive/execution/HiveScriptTransformationSuite.scala | 8 ++++---- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 12b6934f58c8f..7760a3797eb49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -220,7 +220,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) case dt => - throw new SparkException("TRANSFORM without serde does not support " + + throw new SparkException(s"${nodeName} without serde does not support " + s"${dt.getClass.getSimpleName} as output data type") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index d85aa6cbe3a17..6b20f4cf88645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -74,7 +74,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) """.stripMargin).collect() }.getMessage - assert(e1.contains("TRANSFORM without serde does not support" + + assert(e1.contains("SparkScriptTransformation without serde does not support" + " ArrayType as output data type")) // check for MapType @@ -86,7 +86,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) """.stripMargin).collect() }.getMessage - assert(e2.contains("TRANSFORM without serde does not support" + + assert(e2.contains("SparkScriptTransformation without serde does not support" + " MapType as output data type")) // check for StructType @@ -98,7 +98,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) """.stripMargin).collect() }.getMessage - assert(e3.contains("TRANSFORM without serde does not support" + + assert(e3.contains("SparkScriptTransformation without serde does not support" + " StructType as output data type")) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c09e0ce095e3b..d075b69d976cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -1064,8 +1064,8 @@ private[hive] trait HiveInspectors { case TimestampType => timestampTypeInfo case NullType => voidTypeInfo case dt => - throw new AnalysisException("HiveInspectors does not support convert " + - s"${dt.getClass.getSimpleName.replace("$", "")} to Hive TypeInfo") + throw new AnalysisException( + s"${dt.getClass.getSimpleName.replace("$", "")} cannot be converted to Hive TypeInfo") } } } diff 
--git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 5f08398a99450..e89e20c2c723e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -262,7 +262,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T SparkPlanTest.executePlan(plan, hiveContext) }.getMessage assert(e1.contains( - "HiveInspectors does not support convert CalendarIntervalType to Hive TypeInfo")) + "CalendarIntervalType cannot be converted to Hive TypeInfo")) val e2 = intercept[SparkException] { val plan = createScriptTransformationExec( @@ -276,7 +276,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T SparkPlanTest.executePlan(plan, hiveContext) }.getMessage assert(e2.contains( - "HiveInspectors does not support convert MyDenseVectorUDT to Hive TypeInfo")) + "MyDenseVectorUDT cannot be converted to Hive TypeInfo")) } } @@ -298,7 +298,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin).collect() }.getMessage assert(e1.contains( - "HiveInspectors does not support convert CalendarIntervalType to Hive TypeInfo")) + "CalendarIntervalType cannot be converted to Hive TypeInfo")) val e2 = intercept[SparkException] { sql( @@ -308,7 +308,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin).collect() }.getMessage assert(e2.contains( - "HiveInspectors does not support convert MyDenseVectorUDT to Hive TypeInfo")) + "MyDenseVectorUDT cannot be converted to Hive TypeInfo")) } } }
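
A minimal usage sketch of the no-serde TRANSFORM path this series enables, using the same query shape as the transform.sql tests added in PATCH 39/42. Everything below is illustrative rather than part of the patches: the local SparkSession, the temp view t(a, b), and the column names are assumptions, and 'cat' must be available wherever the script process is launched.

import org.apache.spark.sql.SparkSession

// Hypothetical driver program; only the query shape is taken from the new transform.sql test.
val spark = SparkSession.builder()
  .appName("transform-no-serde-sketch")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Small stand-in for the suite's test table `t`.
Seq((1, true), (2, false)).toDF("a", "b").createOrReplaceTempView("t")

// TRANSFORM without ROW FORMAT SERDE is planned as SparkScriptTransformationExec,
// so the delimiters below are applied by Spark itself rather than by a Hive serde.
val transformed = spark.sql(
  """
    |SELECT TRANSFORM(a, b, null)
    |ROW FORMAT DELIMITED
    |  FIELDS TERMINATED BY '|'
    |  NULL DEFINED AS 'NULL'
    |USING 'cat' AS (a, b, c)
    |ROW FORMAT DELIMITED
    |  FIELDS TERMINATED BY '|'
    |  NULL DEFINED AS 'NULL'
    |FROM t
  """.stripMargin)

// Without an output serde every result column is read back as StringType; asking for
// ArrayType/MapType/StructType output fails with the "without serde does not support
// ... as output data type" error whose wording is finalized in PATCH 42/42.
transformed.show()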