[SPARK-7324][SQL] DataFrame.dropDuplicates

rxin committed May 11, 2015
1 parent 91dc3df commit 130692f

Showing 3 changed files with 105 additions and 4 deletions.
36 changes: 34 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -755,8 +755,6 @@ def groupBy(self, *cols):
         jdf = self._jdf.groupBy(self._jcols(*cols))
         return GroupedData(jdf, self.sql_ctx)
 
-    groupby = groupBy
-
     def agg(self, *exprs):
         """ Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy.agg()``).
@@ -793,6 +791,36 @@ def subtract(self, other):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
 
+    def dropDuplicates(self, subset=None):
+        """Return a new :class:`DataFrame` with duplicate rows removed,
+        optionally only considering certain columns.
+
+        >>> from pyspark.sql import Row
+        >>> df = sc.parallelize([ \
+            Row(name='Alice', age=5, height=80), \
+            Row(name='Alice', age=5, height=80), \
+            Row(name='Alice', age=10, height=80)]).toDF()
+        >>> df.dropDuplicates().show()
+        +---+------+-----+
+        |age|height| name|
+        +---+------+-----+
+        |  5|    80|Alice|
+        | 10|    80|Alice|
+        +---+------+-----+
+
+        >>> df.dropDuplicates(['name', 'height']).show()
+        +---+------+-----+
+        |age|height| name|
+        +---+------+-----+
+        |  5|    80|Alice|
+        +---+------+-----+
+        """
+        if subset is None:
+            jdf = self._jdf.dropDuplicates()
+        else:
+            jdf = self._jdf.dropDuplicates(self._jseq(subset))
+        return DataFrame(jdf, self.sql_ctx)
+
     def dropna(self, how='any', thresh=None, subset=None):
         """Returns a new :class:`DataFrame` omitting rows with null values.
@@ -1012,6 +1040,10 @@ def toPandas(self):
         import pandas as pd
         return pd.DataFrame.from_records(self.collect(), columns=self.columns)
 
+    # Pandas compatibility
+    groupby = groupBy
+    drop_duplicates = dropDuplicates
+
 
 # Having SchemaRDD for backward compatibility (for docs)
 class SchemaRDD(DataFrame):
38 changes: 36 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
 import java.io.CharArrayWriter
 import java.sql.DriverManager
 
-
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
@@ -42,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
-import org.apache.spark.sql.json.{JacksonGenerator, JsonRDD}
+import org.apache.spark.sql.json.JacksonGenerator
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
 import org.apache.spark.util.Utils
@@ -932,6 +931,40 @@ class DataFrame private[sql](
     }
   }
 
+  /**
+   * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
+   * This is an alias for `distinct`.
+   * @group dfops
+   */
+  def dropDuplicates(): DataFrame = dropDuplicates(this.columns)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only
+   * the subset of columns.
+   *
+   * @group dfops
+   */
+  def dropDuplicates(colNames: Seq[String]): DataFrame = {
+    val groupCols = colNames.map(resolve)
+    val groupColExprIds = groupCols.map(_.exprId)
+    val aggCols = logicalPlan.output.map { attr =>
+      if (groupColExprIds.contains(attr.exprId)) {
+        attr
+      } else {
+        Alias(First(attr), attr.name)()
+      }
+    }
+    Aggregate(groupCols, aggCols, logicalPlan)
+  }
+
+  /**
+   * Returns a new [[DataFrame]] with duplicate rows removed, considering only
+   * the subset of columns.
+   *
+   * @group dfops
+   */
+  def dropDuplicates(colNames: Array[String]): DataFrame = dropDuplicates(colNames.toSeq)
+
   /**
    * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
    * If no columns are given, this function computes statistics for all numerical columns.
@@ -1089,6 +1122,7 @@ class DataFrame private[sql](
 
   /**
    * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
+   * This is an alias for `dropDuplicates`.
    * @group dfops
    */
   override def distinct: DataFrame = Distinct(logicalPlan)
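Note on the implementation above: dropDuplicates(colNames) resolves the subset columns, groups on them, and keeps First(...) of every other column, so which row survives within a duplicate group is not guaranteed. A minimal sketch of the same effect at the public API level (a spark-shell style session is assumed, with sqlContext and import sqlContext.implicits._ in scope; the data and variable names are illustrative, not from this commit):

    import org.apache.spark.sql.functions.first

    val df = sc.parallelize(Seq(
      ("Alice", 5, 80),
      ("Alice", 5, 80),
      ("Alice", 10, 80))).toDF("name", "age", "height")

    // The new API: keep one row per distinct (name, height) pair.
    df.dropDuplicates(Seq("name", "height"))

    // Roughly the Aggregate it plans: group on the subset columns, take
    // first(...) of each remaining column, then restore the column order.
    df.groupBy(df("name"), df("height"))
      .agg(first(df("age")).as("age"))
      .select("name", "age", "height")
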
35 changes: 35 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -457,4 +457,39 @@ class DataFrameSuite extends QueryTest {
     assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
     assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
   }
+
+  test("SPARK-7324 dropDuplicates") {
+    val testData = TestSQLContext.sparkContext.parallelize(
+      (2, 1, 2) :: (1, 1, 1) ::
+      (1, 2, 1) :: (2, 1, 2) ::
+      (2, 2, 2) :: (2, 2, 1) ::
+      (2, 1, 1) :: (1, 1, 2) ::
+      (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2")
+
+    checkAnswer(
+      testData.dropDuplicates(),
+      Seq(Row(2, 1, 2), Row(1, 1, 1), Row(1, 2, 1),
+        Row(2, 2, 2), Row(2, 1, 1), Row(2, 2, 1),
+        Row(1, 1, 2), Row(1, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("key", "value1")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value1", "value2")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("key")),
+      Seq(Row(2, 1, 2), Row(1, 1, 1)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value1")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value2")),
+      Seq(Row(2, 1, 2), Row(1, 1, 1)))
+  }
 }
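As a usage note on the overloads exercised above (a sketch assuming a session with sqlContext implicits imported; the data and names are illustrative): all three variants build the same Aggregate-based plan, with the Array overload provided so Java callers can pass a String[].

    val people = sc.parallelize(Seq(
      ("Alice", 5, 80),
      ("Alice", 5, 80),
      ("Alice", 10, 80))).toDF("name", "age", "height")

    people.dropDuplicates()                        // full-row distinct: 2 rows remain
    people.dropDuplicates(Seq("name", "height"))   // Scala Seq overload: 1 row remains
    people.dropDuplicates(Array("name", "height")) // Java-friendly Array overload: 1 row remains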
