[SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs

This PR contains the following updates:

- Created a new private variable `boundTEncoder` that can be shared by multiple functions (`rdd`, `select`, and `collect`); see the sketch after this list.
- Replaced all occurrences of `queryExecution.analyzed` with the function call `logicalPlan`.
- Fixed a few API comments that used wrong class names (e.g., `DataFrame`) or parameter names (e.g., `n`).
- Corrected a few API descriptions that were wrong (e.g., `mapPartitions`); see the usage sketch below.
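
For context on the first two items, below is a minimal sketch of the sharing pattern. It is not the actual Spark source: it assumes the code sits in the `org.apache.spark.sql` package so that the internal Catalyst/execution helpers it uses (`encoderFor`, `OuterScopes`, `QueryExecution`, some of which are `private[sql]`) are visible. The encoder is resolved and bound against `logicalPlan.output` once, and `rdd` and `collect` reuse that shared `boundTEncoder` instead of re-binding their own copies.

```scala
package org.apache.spark.sql  // assumption: placed here so private[sql] helpers resolve

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
import org.apache.spark.sql.execution.QueryExecution

// Hypothetical, trimmed-down Dataset keeping only the members this PR touches.
class DatasetSketch[T](
    @transient val queryExecution: QueryExecution,
    tEncoder: Encoder[T]) extends Serializable {

  // Shorthand used throughout instead of repeating `queryExecution.analyzed`.
  private def logicalPlan = queryExecution.analyzed

  private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)

  private val resolvedTEncoder: ExpressionEncoder[T] =
    unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)

  // Resolved and bound exactly once; rdd, select and collect all reuse this copy.
  private val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)

  def rdd: RDD[T] =
    queryExecution.toRdd.mapPartitions(_.map(boundTEncoder.fromRow))

  def collect(): Array[T] =
    queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
}
```

This is the same `boundTEncoder`/`logicalPlan` pair that the diff below substitutes for the repeated `tEnc.bind(...)` and `queryExecution.analyzed` calls.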

marmbrus rxin cloud-fan Could you take a look and check if they are appropriate? Thank you!
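
On the `mapPartitions` description fix specifically, here is a small, hedged usage sketch (hypothetical object name, Spark 1.6-style `SQLContext` setup) showing that the supplied function runs once per partition over that partition's iterator, not once per element:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Dataset, SQLContext}

// Self-contained illustration; all names here are hypothetical.
object MapPartitionsSketch extends App {
  val sc = new SparkContext(
    new SparkConf().setMaster("local[2]").setAppName("mapPartitions-sketch"))
  val sqlContext = new SQLContext(sc)
  import sqlContext.implicits._

  val ds: Dataset[Int] = Seq(1, 2, 3, 4).toDS()

  // The function receives an Iterator over a whole partition, so this yields
  // one partial sum per partition rather than one transformed value per element.
  val perPartitionSums: Dataset[Int] = ds.mapPartitions(iter => Iterator(iter.sum))
  perPartitionSums.show()

  sc.stop()
}
```

Depending on how the local data is split into partitions, the result contains one row per partition instead of one row per input element, which is what the corrected comment now says.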

Author: gatorsmile <[email protected]>

Closes apache#10184 from gatorsmile/datasetClean.
gatorsmile authored and marmbrus committed Dec 8, 2015
1 parent c0b13d5 commit 5d96a71
Showing 1 changed file with 40 additions and 40 deletions.
80 changes: 40 additions & 40 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -67,15 +67,21 @@ class Dataset[T] private[sql](
tEncoder: Encoder[T]) extends Queryable with Serializable {

/**
* An unresolved version of the internal encoder for the type of this dataset. This one is marked
* implicit so that we can use it when constructing new [[Dataset]] objects that have the same
* object type (that will be possibly resolved to a different schema).
* An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
* marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
* same object type (that will be possibly resolved to a different schema).
*/
private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)

/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)

/**
* The encoder where the expressions used to construct an object from an input row have been
* bound to the ordinals of the given schema.
*/
private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)

private implicit def classTag = resolvedTEncoder.clsTag

@@ -89,7 +95,7 @@ class Dataset[T] private[sql](
override def schema: StructType = resolvedTEncoder.schema

/**
* Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format.
* Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
* @since 1.6.0
*/
override def printSchema(): Unit = toDF().printSchema()
@@ -111,7 +117,7 @@ class Dataset[T] private[sql](
* ************* */

/**
* Returns a new `Dataset` where each record has been mapped on to the specified type. The
* Returns a new [[Dataset]] where each record has been mapped on to the specified type. The
* method used to map columns depend on the type of `U`:
* - When `U` is a class, fields for the class will be mapped to columns of the same name
* (case sensitivity is determined by `spark.sql.caseSensitive`)
@@ -145,23 +151,20 @@ class Dataset[T] private[sql](
def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)

/**
* Returns this Dataset.
* Returns this [[Dataset]].
* @since 1.6.0
*/
// This is declared with parentheses to prevent the Scala compiler from treating
// `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset.
def toDS(): Dataset[T] = this

/**
* Converts this Dataset to an RDD.
* Converts this [[Dataset]] to an [[RDD]].
* @since 1.6.0
*/
def rdd: RDD[T] = {
val tEnc = resolvedTEncoder
val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
val bound = tEnc.bind(input)
iter.map(bound.fromRow)
iter.map(boundTEncoder.fromRow)
}
}

@@ -189,15 +192,15 @@ class Dataset[T] private[sql](
def show(numRows: Int): Unit = show(numRows, truncate = true)

/**
* Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters
* Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
* will be truncated, and all cells will be aligned right.
*
* @since 1.6.0
*/
def show(): Unit = show(20)

/**
* Displays the top 20 rows of [[DataFrame]] in a tabular form.
* Displays the top 20 rows of [[Dataset]] in a tabular form.
*
* @param truncate Whether truncate long strings. If true, strings more than 20 characters will
* be truncated and all cells will be aligned right
@@ -207,7 +210,7 @@ class Dataset[T] private[sql](
def show(truncate: Boolean): Unit = show(20, truncate)

/**
* Displays the [[DataFrame]] in a tabular form. For example:
* Displays the [[Dataset]] in a tabular form. For example:
* {{{
* year month AVG('Adj Close) MAX('Adj Close)
* 1980 12 0.503218 0.595103
@@ -291,7 +294,7 @@ class Dataset[T] private[sql](

/**
* (Scala-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
* @since 1.6.0
*/
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
@@ -307,7 +310,7 @@ class Dataset[T] private[sql](

/**
* (Java-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
* @since 1.6.0
*/
def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -341,28 +344,28 @@ class Dataset[T] private[sql](

/**
* (Scala-specific)
* Runs `func` on each element of this Dataset.
* Runs `func` on each element of this [[Dataset]].
* @since 1.6.0
*/
def foreach(func: T => Unit): Unit = rdd.foreach(func)

/**
* (Java-specific)
* Runs `func` on each element of this Dataset.
* Runs `func` on each element of this [[Dataset]].
* @since 1.6.0
*/
def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))

/**
* (Scala-specific)
* Runs `func` on each partition of this Dataset.
* Runs `func` on each partition of this [[Dataset]].
* @since 1.6.0
*/
def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)

/**
* (Java-specific)
* Runs `func` on each partition of this Dataset.
* Runs `func` on each partition of this [[Dataset]].
* @since 1.6.0
*/
def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
@@ -374,27 +377,27 @@ class Dataset[T] private[sql](

/**
* (Scala-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
def reduce(func: (T, T) => T): T = rdd.reduce(func)

/**
* (Java-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* Reduces the elements of this Dataset using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))

/**
* (Scala-specific)
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
* @since 1.6.0
*/
def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
val inputPlan = queryExecution.analyzed
val inputPlan = logicalPlan
val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
val executed = sqlContext.executePlan(withGroupingKey)

@@ -429,18 +432,18 @@ class Dataset[T] private[sql](

/**
* (Java-specific)
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
* @since 1.6.0
*/
def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
groupBy(f.call(_))(encoder)
def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
groupBy(func.call(_))(encoder)

/* ****************** *
* Typed Relational *
* ****************** */

/**
* Selects a set of column based expressions.
* Returns a new [[DataFrame]] by selecting a set of column based expressions.
* {{{
* df.select($"colA", $"colB" + 1)
* }}}
@@ -464,8 +467,8 @@ class Dataset[T] private[sql](
sqlContext,
Project(
c1.withInputType(
resolvedTEncoder.bind(queryExecution.analyzed.output),
queryExecution.analyzed.output).named :: Nil,
boundTEncoder,
logicalPlan.output).named :: Nil,
logicalPlan))
}

@@ -477,7 +480,7 @@ class Dataset[T] private[sql](
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))

new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
@@ -654,25 +657,22 @@ class Dataset[T] private[sql](
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
* doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
*/
def collect(): Array[T] = {
// This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
// to convert the rows into objects of type T.
val tEnc = resolvedTEncoder
val input = queryExecution.analyzed.output
val bound = tEnc.bind(input)
queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
}

/**
* Returns an array that contains all the elements in this [[Dataset]].
*
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
* doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
* @since 1.6.0
@@ -683,7 +683,7 @@ class Dataset[T] private[sql](
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
* a very large `num` can crash the driver process with OutOfMemoryError.
* @since 1.6.0
*/
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
@@ -692,7 +692,7 @@ class Dataset[T] private[sql](
* Returns the first `num` elements of this [[Dataset]] as an array.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
* a very large `num` can crash the driver process with OutOfMemoryError.
* @since 1.6.0
*/
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
