diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 895eeedd86b8b..b96a861196897 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -46,6 +46,7 @@ class SparkPlanner(
       Window ::
       JoinSelection ::
       InMemoryScans ::
+      SparkScripts ::
       BasicOperators :: Nil)
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
new file mode 100644
index 0000000000000..c6bbbd140c4ea
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io._
+import java.nio.charset.StandardCharsets
+
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.util.{CircularBuffer, RedirectThread}
+
+/**
+ * Transforms the input by forking and running the specified script.
+ *
+ * @param input the set of expressions that should be passed to the script.
+ * @param script the command that should be executed.
+ * @param output the attributes that are produced by the script.
+ */
+case class SparkScriptTransformationExec(
+    input: Seq[Expression],
+    script: String,
+    output: Seq[Attribute],
+    child: SparkPlan,
+    ioschema: SparkScriptIOSchema)
+  extends BaseScriptTransformationExec {
+
+  override def processIterator(inputIterator: Iterator[InternalRow], hadoopConf: Configuration)
+    : Iterator[InternalRow] = {
+    val cmd = List("/bin/bash", "-c", script)
+    val builder = new ProcessBuilder(cmd.asJava)
+
+    val proc = builder.start()
+    val inputStream = proc.getInputStream
+    val outputStream = proc.getOutputStream
+    val errorStream = proc.getErrorStream
+
+    // In order to avoid deadlocks, we need to consume the error output of the child process.
+    // To avoid issues caused by large error output, we use a circular buffer to limit the amount
+    // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
+    // that motivates this.
+    val stderrBuffer = new CircularBuffer(2048)
+    new RedirectThread(
+      errorStream,
+      stderrBuffer,
+      "Thread-ScriptTransformation-STDERR-Consumer").start()
+
+    val outputProjection = new InterpretedProjection(input, child.output)
+
+    // This new thread will consume the ScriptTransformation's input rows and write them to the
+    // external process. That process's output will be read by this current thread.
+    val writerThread = new ScriptTransformationWriterThread(
+      inputIterator.map(outputProjection),
+      input.map(_.dataType),
+      ioschema,
+      outputStream,
+      proc,
+      stderrBuffer,
+      TaskContext.get(),
+      hadoopConf
+    )
+
+    val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
+    val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] {
+      var curLine: String = null
+      val mutableRow = new SpecificInternalRow(output.map(_.dataType))
+
+      override def hasNext: Boolean = {
+        try {
+          if (curLine == null) {
+            curLine = reader.readLine()
+            if (curLine == null) {
+              checkFailureAndPropagate(writerThread, null, proc, stderrBuffer)
+              return false
+            }
+          }
+          true
+        } catch {
+          case NonFatal(e) =>
+            // If this exception is due to abrupt / unclean termination of `proc`,
+            // then detect it and propagate a better exception message for end users
+            checkFailureAndPropagate(writerThread, e, proc, stderrBuffer)
+
+            throw e
+        }
+      }
+
+      override def next(): InternalRow = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        val prevLine = curLine
+        curLine = reader.readLine()
+        if (!ioschema.schemaLess) {
+          new GenericInternalRow(
+            prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
+              .map(CatalystTypeConverters.convertToCatalyst))
+        } else {
+          new GenericInternalRow(
+            prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
+              .map(CatalystTypeConverters.convertToCatalyst))
+        }
+      }
+    }
+
+    writerThread.start()
+
+    outputIterator
+  }
+}
+
+private class ScriptTransformationWriterThread(
+    iter: Iterator[InternalRow],
+    inputSchema: Seq[DataType],
+    ioSchema: SparkScriptIOSchema,
+    outputStream: OutputStream,
+    proc: Process,
+    stderrBuffer: CircularBuffer,
+    taskContext: TaskContext,
+    conf: Configuration)
+  extends BaseScriptTransformationWriterThread(
+    iter,
+    inputSchema,
+    ioSchema,
+    outputStream,
+    proc,
+    stderrBuffer,
+    taskContext,
+    conf) {
+
+  setDaemon(true)
+
+  override def processRows(): Unit = {
+    processRowsWithoutSerde()
+  }
+}
+
+object SparkScriptIOSchema {
+  def apply(input: ScriptInputOutputSchema): SparkScriptIOSchema = {
+    SparkScriptIOSchema(
+      input.inputRowFormat,
+      input.outputRowFormat,
+      input.inputSerdeClass,
+      input.outputSerdeClass,
+      input.inputSerdeProps,
+      input.outputSerdeProps,
+      input.recordReaderClass,
+      input.recordWriterClass,
+      input.schemaLess)
+  }
+}
+
+/**
+ * The wrapper class of Spark script transformation input and output schema properties
+ */
+case class SparkScriptIOSchema (
+    inputRowFormat: Seq[(String, String)],
+    outputRowFormat: Seq[(String, String)],
+    inputSerdeClass: Option[String],
+    outputSerdeClass: Option[String],
+    inputSerdeProps: Seq[(String, String)],
+    outputSerdeProps: Seq[(String, String)],
+    recordReaderClass: Option[String],
+    recordWriterClass: Option[String],
+    schemaLess: Boolean) extends BaseScriptTransformIOSchema
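The no-SerDe read path above treats each line the script writes to stdout as one row: the line is split on the field delimiter recorded under TOK_TABLEROWFORMATFIELD (typically a tab when no ROW FORMAT is given), and in schema-less mode the split is capped at two columns. A minimal standalone sketch of just that parsing step; the object name, delimiter, and sample line are assumptions for illustration, not values taken from this patch:

    // Standalone illustration only; not part of the patch.
    object NoSerdeLineParsingSketch {
      def parseLine(line: String, fieldDelim: String, schemaLess: Boolean): Array[String] =
        // With a declared schema every delimiter starts a new column; in schema-less
        // mode only the first delimiter splits, yielding at most two columns.
        if (schemaLess) line.split(fieldDelim, 2) else line.split(fieldDelim)

      def main(args: Array[String]): Unit = {
        val line = "1\tfoo\tbar"                                   // hypothetical script output
        println(parseLine(line, "\t", schemaLess = false).length)  // 3 columns
        println(parseLine(line, "\t", schemaLess = true).length)   // 2 columns: "1" and "foo\tbar"
      }
    }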
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 583e5a2c5c57e..d5366de2ea704 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -532,6 +532,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  object SparkScripts extends Strategy {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case logical.ScriptTransformation(input, script, output, child, ioschema)
+          if ioschema.inputSerdeClass.isEmpty && ioschema.outputSerdeClass.isEmpty =>
+        SparkScriptTransformationExec(
+          input,
+          script,
+          output,
+          planLater(child),
+          SparkScriptIOSchema(ioschema)
+        ) :: Nil
+      case _ => Nil
+    }
+  }
+
   /**
    * This strategy is just for explaining `Dataset/DataFrame` created by `spark.readStream`.
    * It won't affect the execution, because `StreamingRelation` will be replaced with
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
similarity index 99%
rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
index 96fe646d39fde..098ffd3b7d75a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
@@ -275,7 +275,7 @@ object HiveScriptIOSchema {
 }
 
 /**
- * The wrapper class of Hive input and output schema properties
+ * The wrapper class of Hive script transformation input and output schema properties
  */
 case class HiveScriptIOSchema (
     inputRowFormat: Seq[(String, String)],
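With the SparkScripts strategy in place, which physical operator a TRANSFORM query gets depends on whether the clause names a SerDe: plain TRANSFORM clauses are planned as SparkScriptTransformationExec in sql/core, while clauses with ROW FORMAT SERDE continue to be planned as the Hive-side operator renamed above. A rough sketch of the two query shapes, assuming a hypothetical temp view t with a single string column a:

    // No SerDe specified: matched by the new SparkScripts strategy.
    spark.sql("SELECT TRANSFORM(a) USING 'cat' AS (a) FROM t")

    // SerDe specified for input and output: still planned as HiveScriptTransformationExec.
    spark.sql(
      """SELECT TRANSFORM(a)
        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
        |USING 'cat' AS (a)
        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
        |FROM t""".stripMargin)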
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala
new file mode 100644
index 0000000000000..4869872273832
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.sql.Timestamp
+import java.util.Locale
+
+import org.scalatest.Assertions._
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.exceptions.TestFailedException
+
+import org.apache.spark.{SparkException, TaskContext, TestUtils}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.hive.HiveUtils
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.StringType
+
+abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils
+  with TestHiveSingleton with BeforeAndAfterEach {
+
+  def scriptType: String
+
+  import spark.implicits._
+
+  var noSerdeIOSchema: BaseScriptTransformIOSchema = _
+
+  private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _
+
+  protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    defaultUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler
+    Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
+  }
+
+  protected override def afterAll(): Unit = {
+    super.afterAll()
+    Thread.setDefaultUncaughtExceptionHandler(defaultUncaughtExceptionHandler)
+  }
+
+  override protected def afterEach(): Unit = {
+    super.afterEach()
+    uncaughtExceptionHandler.cleanStatus()
+  }
+
+  def createScriptTransformationExec(
+      input: Seq[Expression],
+      script: String,
+      output: Seq[Attribute],
+      child: SparkPlan,
+      ioschema: BaseScriptTransformIOSchema): BaseScriptTransformationExec = {
+    scriptType.toUpperCase(Locale.ROOT) match {
+      case "SPARK" => new SparkScriptTransformationExec(
+        input = input,
+        script = script,
+        output = output,
+        child = child,
+        ioschema = ioschema.asInstanceOf[SparkScriptIOSchema]
+      )
+      case "HIVE" => new HiveScriptTransformationExec(
+        input = input,
+        script = script,
+        output = output,
+        child = child,
+        ioschema = ioschema.asInstanceOf[HiveScriptIOSchema]
+      )
+      case _ => throw new TestFailedException(
+        "Test classes extending BaseScriptTransformationSuite" +
+          " must override `scriptType` to return either SPARK or HIVE", 0)
+    }
+  }
+
+  test("cat without SerDe") {
+    assume(TestUtils.testCommandAvailable("/bin/bash"))
+
+    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+    checkAnswer(
+      rowsDf,
+      (child: SparkPlan) => createScriptTransformationExec(
+        input = Seq(rowsDf.col("a").expr),
+        script = "cat",
+        output = Seq(AttributeReference("a", StringType)()),
+        child = child,
+        ioschema = noSerdeIOSchema
+      ),
+      rowsDf.collect())
+    assert(uncaughtExceptionHandler.exception.isEmpty)
+  }
+
+  test("script transformation should not swallow errors from upstream operators (no serde)") {
+    assume(TestUtils.testCommandAvailable("/bin/bash"))
+
+    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+    val e = intercept[TestFailedException] {
+      checkAnswer(
+        rowsDf,
+        (child: SparkPlan) => createScriptTransformationExec(
+          input = Seq(rowsDf.col("a").expr),
+          script = "cat",
+          output = Seq(AttributeReference("a", StringType)()),
+          child = ExceptionInjectingOperator(child),
+          ioschema = noSerdeIOSchema
+        ),
+        rowsDf.collect())
+    }
+    assert(e.getMessage().contains("intentional exception"))
+    // Before SPARK-25158, uncaughtExceptionHandler will catch IllegalArgumentException
+    assert(uncaughtExceptionHandler.exception.isEmpty)
+  }
+
+  test("SPARK-25990: TRANSFORM should handle different data types correctly") {
+    assume(TestUtils.testCommandAvailable("python"))
+    val scriptFilePath = getTestResourcePath("test_script.py")
+
+    withTempView("v") {
+      val df = Seq(
+        (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)),
+        (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)),
+        (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3))
+      ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18)
+      df.createTempView("v")
+
+      val query = sql(
+        s"""
+          |SELECT
+          |TRANSFORM(a, b, c, d, e)
+          |USING 'python $scriptFilePath' AS (a, b, c, d, e)
+          |FROM v
+        """.stripMargin)
+
+      // In Hive 1.2, the string representation of a decimal omits trailing zeroes.
+      // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary.
+      val decimalToString: Column => Column = if (HiveUtils.isHive23) {
+        c => c.cast("string")
+      } else {
+        c => c.cast("decimal(1, 0)").cast("string")
+      }
+      checkAnswer(query, identity, df.select(
+        'a.cast("string"),
+        'b.cast("string"),
+        'c.cast("string"),
+        decimalToString('d),
+        'e.cast("string")).collect())
+    }
+  }
+
+  test("SPARK-30973: TRANSFORM should wait for the termination of the script (no serde)") {
+    assume(TestUtils.testCommandAvailable("/bin/bash"))
+
+    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+    val e = intercept[SparkException] {
+      val plan =
+        createScriptTransformationExec(
+          input = Seq(rowsDf.col("a").expr),
+          script = "some_non_existent_command",
+          output = Seq(AttributeReference("a", StringType)()),
+          child = rowsDf.queryExecution.sparkPlan,
+          ioschema = noSerdeIOSchema)
+      SparkPlanTest.executePlan(plan, hiveContext)
+    }
+    assert(e.getMessage.contains("Subprocess exited with status"))
+    assert(uncaughtExceptionHandler.exception.isEmpty)
+  }
+}
+
+case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().map { x =>
+      assert(TaskContext.get() != null) // Make sure that TaskContext is defined.
+      Thread.sleep(1000) // This sleep gives the external process time to start.
+      throw new IllegalArgumentException("intentional exception")
+    }
+  }
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
index 35252fc47f49f..3caeca4a0eb30 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
@@ -17,30 +17,20 @@
 
 package org.apache.spark.sql.hive.execution
 
-import java.sql.Timestamp
-
 import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
-import org.scalatest.Assertions._
-import org.scalatest.BeforeAndAfterEach
 import org.scalatest.exceptions.TestFailedException
 
-import org.apache.spark.{SparkException, TaskContext, TestUtils}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode}
-import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.{SparkException, TestUtils}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.execution.{BaseScriptTransformIOSchema, SparkPlan, SparkPlanTest}
 import org.apache.spark.sql.types.StringType
 
-class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton
-  with BeforeAndAfterEach {
+class HiveScriptTransformationSuite extends BaseScriptTransformationSuite {
+  override def scriptType: String = "HIVE"
+
   import spark.implicits._
 
-  private val noSerdeIOSchema = HiveScriptIOSchema(
+  noSerdeIOSchema = HiveScriptIOSchema(
     inputRowFormat = Seq.empty,
     outputRowFormat = Seq.empty,
     inputSerdeClass = None,
@@ -52,46 +42,11 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with
     schemaLess = false
   )
 
-  private val serdeIOSchema = noSerdeIOSchema.copy(
-    inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
-    outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
-  )
-
-  private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _
-
-  private val uncaughtExceptionHandler = new TestUncaughtExceptionHandler
-
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-    defaultUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler
-    Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
-  }
-
-  protected override def afterAll(): Unit = {
-    super.afterAll()
-    Thread.setDefaultUncaughtExceptionHandler(defaultUncaughtExceptionHandler)
-  }
-
-  override protected def afterEach(): Unit = {
-    super.afterEach()
-    uncaughtExceptionHandler.cleanStatus()
-  }
-
-  test("cat without SerDe") {
-    assume(TestUtils.testCommandAvailable("/bin/bash"))
-
-    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
-    checkAnswer(
-      rowsDf,
-      (child: SparkPlan) => new HiveScriptTransformationExec(
-        input = Seq(rowsDf.col("a").expr),
-        script = "cat",
-        output = Seq(AttributeReference("a", StringType)()),
-        child = child,
-        ioschema = noSerdeIOSchema
-      ),
-      rowsDf.collect())
-    assert(uncaughtExceptionHandler.exception.isEmpty)
+  private val serdeIOSchema: BaseScriptTransformIOSchema = {
+    noSerdeIOSchema.asInstanceOf[HiveScriptIOSchema].copy(
+      inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
+      outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
+    )
   }
 
   test("cat with LazySimpleSerDe") {
@@ -100,7 +55,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with
     val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
     checkAnswer(
       rowsDf,
-      (child: SparkPlan) => new HiveScriptTransformationExec(
+      (child: SparkPlan) => createScriptTransformationExec(
         input = Seq(rowsDf.col("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
@@ -111,27 +66,6 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with
     assert(uncaughtExceptionHandler.exception.isEmpty)
   }
 
-  test("script transformation should not swallow errors from upstream operators (no serde)") {
-    assume(TestUtils.testCommandAvailable("/bin/bash"))
-
-    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
-    val e = intercept[TestFailedException] {
-      checkAnswer(
-        rowsDf,
-        (child: SparkPlan) => new HiveScriptTransformationExec(
-          input = Seq(rowsDf.col("a").expr),
-          script = "cat",
-          output = Seq(AttributeReference("a", StringType)()),
-          child = ExceptionInjectingOperator(child),
-          ioschema = noSerdeIOSchema
-        ),
-        rowsDf.collect())
-    }
-    assert(e.getMessage().contains("intentional exception"))
-    // Before SPARK-25158, uncaughtExceptionHandler will catch IllegalArgumentException
-    assert(uncaughtExceptionHandler.exception.isEmpty)
-  }
-
   test("script transformation should not swallow errors from upstream operators (with serde)") {
     assume(TestUtils.testCommandAvailable("/bin/bash"))
 
@@ -139,7 +73,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with
     val e = intercept[TestFailedException] {
       checkAnswer(
         rowsDf,
-        (child: SparkPlan) => new HiveScriptTransformationExec(
+        (child: SparkPlan) => createScriptTransformationExec(
          input = Seq(rowsDf.col("a").expr),
          script = "cat",
          output = Seq(AttributeReference("a", StringType)()),
@@ -160,7 +94,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with
 
     val e = intercept[SparkException] {
       val plan =
-        new HiveScriptTransformationExec(
+        createScriptTransformationExec(
           input = Seq(rowsDf.col("a").expr),
           script = "some_non_existent_command",
           output = Seq(AttributeReference("a", StringType)()),
@@ -181,7 +115,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with
 
     checkAnswer(
       rowsDf,
-      (child: SparkPlan) => new HiveScriptTransformationExec(
+      (child: SparkPlan) => createScriptTransformationExec(
        input = Seq(rowsDf.col("name").expr),
        script = "cat",
        output = Seq(AttributeReference("name", StringType)()),
@@ -192,67 +126,13 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with
     assert(uncaughtExceptionHandler.exception.isEmpty)
   }
 
-  test("SPARK-25990: TRANSFORM should handle different data types correctly") {
-    assume(TestUtils.testCommandAvailable("python"))
-    val scriptFilePath = getTestResourcePath("test_script.py")
-
-    withTempView("v") {
-      val df = Seq(
-        (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)),
-        (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)),
-        (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3))
-      ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18)
-      df.createTempView("v")
-
-      val query = sql(
-        s"""
-          |SELECT
-          |TRANSFORM(a, b, c, d, e)
-          |USING 'python $scriptFilePath' AS (a, b, c, d, e)
-          |FROM v
-        """.stripMargin)
-
-      // In Hive 1.2, the string representation of a decimal omits trailing zeroes.
-      // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary.
-      val decimalToString: Column => Column = if (HiveUtils.isHive23) {
-        c => c.cast("string")
-      } else {
-        c => c.cast("decimal(1, 0)").cast("string")
-      }
-      checkAnswer(query, identity, df.select(
-        'a.cast("string"),
-        'b.cast("string"),
-        'c.cast("string"),
-        decimalToString('d),
-        'e.cast("string")).collect())
-    }
-  }
-
-  test("SPARK-30973: TRANSFORM should wait for the termination of the script (no serde)") {
-    assume(TestUtils.testCommandAvailable("/bin/bash"))
-
-    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
-    val e = intercept[SparkException] {
-      val plan =
-        new HiveScriptTransformationExec(
-          input = Seq(rowsDf.col("a").expr),
-          script = "some_non_existent_command",
-          output = Seq(AttributeReference("a", StringType)()),
-          child = rowsDf.queryExecution.sparkPlan,
-          ioschema = noSerdeIOSchema)
-      SparkPlanTest.executePlan(plan, hiveContext)
-    }
-    assert(e.getMessage.contains("Subprocess exited with status"))
-    assert(uncaughtExceptionHandler.exception.isEmpty)
-  }
-
   test("SPARK-30973: TRANSFORM should wait for the termination of the script (with serde)") {
     assume(TestUtils.testCommandAvailable("/bin/bash"))
 
     val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
     val e = intercept[SparkException] {
       val plan =
-        new HiveScriptTransformationExec(
+        createScriptTransformationExec(
           input = Seq(rowsDf.col("a").expr),
           script = "some_non_existent_command",
           output = Seq(AttributeReference("a", StringType)()),
@@ -265,16 +145,3 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with
   }
 }
 
-private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
-  override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().map { x =>
-      assert(TaskContext.get() != null) // Make sure that TaskContext is defined.
-      Thread.sleep(1000) // This sleep gives the external process time to start.
-      throw new IllegalArgumentException("intentional exception")
-    }
-  }
-
-  override def output: Seq[Attribute] = child.output
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala
new file mode 100644
index 0000000000000..372e2a8054cda
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.execution.SparkScriptIOSchema
+
+class SparkScriptTransformationSuite extends BaseScriptTransformationSuite {
+
+  override def scriptType: String = "SPARK"
+
+  noSerdeIOSchema = SparkScriptIOSchema(
+    inputRowFormat = Seq.empty,
+    outputRowFormat = Seq.empty,
+    inputSerdeClass = None,
+    outputSerdeClass = None,
+    inputSerdeProps = Seq.empty,
+    outputSerdeProps = Seq.empty,
+    recordReaderClass = None,
+    recordWriterClass = None,
+    schemaLess = false
+  )
+}