diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c421006c8fd2d..a86de847dfd79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -887,6 +887,20 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] without duplicates. + * @group dfops + */ + def dropDuplicates(subset: Seq[String] = this.columns): DataFrame = { + import org.apache.spark.sql.functions.{first => columnFirst} + if (subset.length == 0) { + sqlContext.emptyDataFrame + } else { + val columnFirsts = columns.map(columnFirst) + groupBy(subset.head, subset.tail: _*).agg(columnFirsts.head, columnFirsts.tail: _*) + } + } + /** * Computes statistics for numeric columns, including count, mean, stddev, min, and max. * If no columns are given, this function computes statistics for all numerical columns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e286fef23caa4..06e4983ba49ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -613,4 +613,39 @@ class DataFrameSuite extends QueryTest { Row(new java.math.BigDecimal(2.0))) TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + + test("SPARK-7324 dropDuplicates") { + val testData = TestSQLContext.sparkContext.parallelize( + (2, 1, 2) :: (1, 1, 1) :: + (1, 2, 1) :: (2, 1, 2) :: + (2, 2, 2) :: (2, 2, 1) :: + (2, 1, 1) :: (1, 1, 2) :: + (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2") + + checkAnswer( + testData.dropDuplicates(), + Seq(Row(2, 1, 2), Row(1, 1, 1), Row(1, 2, 1), + Row(2, 2, 2), Row(2, 1, 1), Row(2, 2, 1), + Row(1, 1, 2), Row(1, 2, 2))) + + checkAnswer( + testData.dropDuplicates(Seq("key", "value1")), + Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) + + checkAnswer( + testData.dropDuplicates(Seq("value1", "value2")), + Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) + + checkAnswer( + testData.dropDuplicates(Seq("key")), + Seq(Row(2, 1, 2), Row(1, 1, 1))) + + checkAnswer( + testData.dropDuplicates(Seq("value1")), + Seq(Row(2, 1, 2), Row(1, 2, 1))) + + checkAnswer( + testData.dropDuplicates(Seq("value2")), + Seq(Row(2, 1, 2), Row(1, 1, 1))) + } }