Commit

Add pipe to Dataset.
viirya committed Jan 22, 2021
1 parent 116f4ca commit d4f9457
Showing 7 changed files with 107 additions and 4 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -230,7 +230,7 @@ private[spark] class PipedRDD[T: ClassTag](
}
}

private object PipedRDD {
object PipedRDD {
  // Split a string into words using a standard StringTokenizer
  def tokenize(command: String): Seq[String] = {
    val buf = new ArrayBuffer[String]
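With the companion object no longer private, the SQL execution layer can reuse this tokenizer directly (see `PipeElementsExec` below). A quick illustration of its behavior (illustrative only; `StringTokenizer` splits on whitespace and does not handle quoting):

import org.apache.spark.rdd.PipedRDD

// A shell-like command string becomes one token per whitespace-separated argument.
val tokens: Seq[String] = PipedRDD.tokenize("wc -l")
// tokens == Seq("wc", "-l")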
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
@@ -585,3 +585,29 @@ case class CoGroup(
    outputObjAttr: Attribute,
    left: LogicalPlan,
    right: LogicalPlan) extends BinaryNode with ObjectProducer

object PipeElements {
  def apply[T : Encoder](
      command: String,
      child: LogicalPlan): LogicalPlan = {
    val deserialized = CatalystSerde.deserialize[T](child)
    implicit val encoder = Encoders.STRING
    val piped = PipeElements(
      implicitly[Encoder[T]].clsTag.runtimeClass,
      implicitly[Encoder[T]].schema,
      CatalystSerde.generateObjAttr[String],
      command,
      deserialized)
    CatalystSerde.serialize[String](piped)
  }
}

/**
 * A relation produced by piping elements to a forked external process.
 */
case class PipeElements[T](
    argumentClass: Class[_],
    argumentSchema: StructType,
    outputObjAttr: Attribute,
    command: String,
    child: LogicalPlan) extends ObjectConsumer with ObjectProducer
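As a rough sketch (shape only, not verbatim analyzer output), the plan returned by `PipeElements.apply[T](command, child)` deserializes the typed input to objects, pipes them, and serializes the resulting strings back into rows:

// SerializeFromObject            -- String results wrapped back into rows
// +- PipeElements(command)       -- ObjectConsumer / ObjectProducer
//    +- DeserializeToObject      -- child rows decoded into objects of type T
//       +- child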
18 changes: 18 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2889,6 +2889,24 @@ class Dataset[T] private[sql](
    flatMap(func)(encoder)
  }

  /**
   * Returns a new Dataset of strings created by piping elements to a forked external process.
   * The resulting Dataset is computed by executing the given process once per partition.
   * All elements of each input partition are written to the process's stdin as lines of input
   * separated by a newline. The resulting partition consists of the process's stdout output,
   * with each line of stdout becoming one element of the output partition. A process is
   * invoked even for empty partitions.
   *
   * @param command the command to run in the forked process.
   *
   * @group typedrel
   * @since 3.2.0
   */
  def pipe(command: String): Dataset[String] = {
    implicit val stringEncoder = Encoders.STRING
    withTypedPlan[String](PipeElements[T](command, logicalPlan))
  }

  /**
   * Applies a function `f` to all rows.
   *
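A minimal usage sketch of the new API, mirroring the tests added below (illustrative only; assumes a local SparkSession named `spark` and that `cat` is available on the executors):

import org.apache.spark.sql.Dataset

// Each partition's elements are written to the forked process's stdin, one per line;
// each line the process writes to stdout becomes one element of the result.
val nums = spark.range(4)                      // Dataset[java.lang.Long]
val piped: Dataset[String] = nums.pipe("cat")  // echoes every element back
piped.show()                                   // rows "0", "1", "2", "3"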
@@ -666,6 +666,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
      case logical.MapElements(f, _, _, objAttr, child) =>
        execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
      case logical.PipeElements(_, _, objAttr, command, child) =>
        execution.PipeElementsExec(objAttr, command, planLater(child)) :: Nil
      case logical.AppendColumns(f, _, _, in, out, child) =>
        execution.AppendColumnsExec(f, in, out, planLater(child)) :: Nil
      case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
@@ -25,7 +25,7 @@ import scala.language.existentials
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.api.r._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.{PipedRDD, RDD}
import org.apache.spark.sql.Row
import org.apache.spark.sql.api.r.SQLUtils._
import org.apache.spark.sql.catalyst.InternalRow
@@ -624,3 +624,32 @@ case class CoGroupExec(
}
}
}

/**
 * Piping elements to a forked external process.
 * The output of its child must be a single-field row containing the input object.
 */
case class PipeElementsExec(
    outputObjAttr: Attribute,
    command: String,
    child: SparkPlan)
  extends ObjectConsumerExec with ObjectProducerExec {

  override protected def doExecute(): RDD[InternalRow] = {
    val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
    val printRDDElement: (InternalRow, String => Unit) => Unit = (row, printFunc) => {
      printFunc(getObject(row).toString)
    }

    child.execute()
      .pipe(command = PipedRDD.tokenize(command), printRDDElement = printRDDElement)
      .mapPartitionsInternal { iter =>
        val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
        iter.map(ele => outputObject(ele))
      }
  }

  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

  override def outputPartitioning: Partitioning = child.outputPartitioning
}
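For intuition, the operator delegates to the same machinery as `RDD.pipe`; ignoring the object (de)serialization steps, `ds.pipe(cmd)` behaves roughly like this RDD-level sketch (illustrative only; assumes a SparkContext named `sc`):

import org.apache.spark.rdd.PipedRDD

// Tokenize the command string, then fork one process per partition and pipe
// each element's string form through it.
val input = sc.parallelize(Seq(0, 1, 2, 3), numSlices = 2)
val output = input.pipe(PipedRDD.tokenize("cat"))  // RDD[String]
output.collect()                                   // Array("0", "1", "2", "3")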
16 changes: 15 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.Assertions._
import org.scalatest.exceptions.TestFailedException
import org.scalatest.prop.TableDrivenPropertyChecks._

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.{SparkException, TaskContext, TestUtils}
import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
@@ -2007,6 +2007,20 @@ class DatasetSuite extends QueryTest

    checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil)
  }

  test("Pipe Dataset") {
    assume(TestUtils.testCommandAvailable("cat"))

    val nums = spark.range(4)
    val piped = nums.pipe("cat").toDF

    checkAnswer(piped, Row("0") :: Row("1") :: Row("2") :: Row("3") :: Nil)

    val piped2 = nums.pipe("wc -l").toDF.collect()
    assert(piped2.size == 2)
    assert(piped2(0).getString(0).trim == "2")
    assert(piped2(1).getString(0).trim == "2")
  }
}

case class Bar(a: Int)
@@ -1264,6 +1264,20 @@ class StreamSuite extends StreamTest {
}
}
}

test("Pipe Streaming Dataset") {
assume(TestUtils.testCommandAvailable("cat"))

val inputData = MemoryStream[Int]
val piped = inputData.toDS()
.pipe("cat").toDF

testStream(piped)(
AddData(inputData, 1, 2, 3),
CheckAnswer(Row("1"), Row("2"), Row("3")),
AddData(inputData, 4),
CheckAnswer(Row("1"), Row("2"), Row("3"), Row("4")))
}
}

abstract class FakeSource extends StreamSourceProvider {
Expand Down
