From 1974c033ac1de4705031018791d5deb46d893443 Mon Sep 17 00:00:00 2001
From: Zheng RuiFeng
Date: Wed, 1 Feb 2017 14:11:28 -0800
Subject: [PATCH] [SPARK-14352][SQL] approxQuantile should support multi columns

## What changes were proposed in this pull request?

1. Add multi-column support based on the current private API.
2. Add multi-column support to PySpark.

## How was this patch tested?

Unit tests.

Author: Zheng RuiFeng
Author: Ruifeng Zheng

Closes #12135 from zhengruifeng/quantile4multicols.
---
 python/pyspark/sql/dataframe.py               | 37 +++++++++++++++----
 python/pyspark/sql/tests.py                   | 23 +++++++++++-
 .../spark/sql/DataFrameStatFunctions.scala    | 37 +++++++++++++++++--
 .../apache/spark/sql/DataFrameStatSuite.scala | 15 ++++++++
 4 files changed, 101 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 10e42d0f9d322..50373b8585195 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -16,7 +16,6 @@
 #
 
 import sys
-import warnings
 import random
 
 if sys.version >= '3':
@@ -1348,7 +1347,7 @@ def replace(self, to_replace, value, subset=None):
     @since(2.0)
     def approxQuantile(self, col, probabilities, relativeError):
         """
-        Calculates the approximate quantiles of a numerical column of a
+        Calculates the approximate quantiles of numerical columns of a
         DataFrame.
 
         The result of this algorithm has the following deterministic bound:
@@ -1365,7 +1364,10 @@ def approxQuantile(self, col, probabilities, relativeError):
         Space-efficient Online Computation of Quantile Summaries]]
         by Greenwald and Khanna.
 
-        :param col: the name of the numerical column
+        Note that rows containing any null values will be removed before calculation.
+
+        :param col: str, list.
+          Can be a single column name, or a list of names for multiple columns.
         :param probabilities: a list of quantile probabilities
           Each number must belong to [0, 1].
           For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
@@ -1373,10 +1375,30 @@ def approxQuantile(self, col, probabilities, relativeError):
           (>= 0). If set to zero, the exact quantiles are computed, which
           could be very expensive. Note that values greater than 1 are
           accepted but give the same result as 1.
-        :return: the approximate quantiles at the given probabilities
+        :return: the approximate quantiles at the given probabilities. If
+          the input `col` is a string, the output is a list of floats. If the
+          input `col` is a list or tuple of strings, the output is also a
+          list, but each element in it is a list of floats, i.e., the output
+          is a list of lists of floats.
+
+        .. versionchanged:: 2.2
+           Added support for multiple columns.
""" - if not isinstance(col, str): - raise ValueError("col should be a string.") + + if not isinstance(col, (str, list, tuple)): + raise ValueError("col should be a string, list or tuple, but got %r" % type(col)) + + isStr = isinstance(col, str) + + if isinstance(col, tuple): + col = list(col) + elif isinstance(col, str): + col = [col] + + for c in col: + if not isinstance(c, str): + raise ValueError("columns should be strings, but got %r" % type(c)) + col = _to_list(self._sc, col) if not isinstance(probabilities, (list, tuple)): raise ValueError("probabilities should be a list or tuple") @@ -1392,7 +1414,8 @@ def approxQuantile(self, col, probabilities, relativeError): relativeError = float(relativeError) jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError) - return list(jaq) + jaq_list = [list(j) for j in jaq] + return jaq_list[0] if isStr else jaq_list @since(1.4) def corr(self, col1, col2, method=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2fea4ac41f0d3..86cad4b363c4c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -895,11 +895,32 @@ def test_first_last_ignorenulls(self): self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) def test_approxQuantile(self): - df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF() + df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF() aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1) self.assertTrue(isinstance(aq, list)) self.assertEqual(len(aq), 3) self.assertTrue(all(isinstance(q, float) for q in aq)) + aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aqs, list)) + self.assertEqual(len(aqs), 2) + self.assertTrue(isinstance(aqs[0], list)) + self.assertEqual(len(aqs[0]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqs[0])) + self.assertTrue(isinstance(aqs[1], list)) + self.assertEqual(len(aqs[1]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqs[1])) + aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aqt, list)) + self.assertEqual(len(aqt), 2) + self.assertTrue(isinstance(aqt[0], list)) + self.assertEqual(len(aqt[0]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqt[0])) + self.assertTrue(isinstance(aqt[1], list)) + self.assertEqual(len(aqt[1]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqt[1])) + self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1)) + self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1)) + self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1)) def test_corr(self): import math diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 72945320614bf..2b782fd75c6a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.stat._ +import org.apache.spark.sql.functions.col import org.apache.spark.sql.types._ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} @@ -74,14 +75,44 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { Seq(col), 
   }
 
+  /**
+   * Calculates the approximate quantiles of numerical columns of a DataFrame.
+   * @see [[DataFrameStatFunctions.approxQuantile(col:Str* approxQuantile]] for
+   *      detailed description.
+   *
+   * Note that rows containing any null or NaN values will be removed before
+   * calculation.
+   * @param cols the names of the numerical columns
+   * @param probabilities a list of quantile probabilities
+   *   Each number must belong to [0, 1].
+   *   For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
+   * @param relativeError The relative target precision to achieve (>= 0).
+   *   If set to zero, the exact quantiles are computed, which could be very expensive.
+   *   Note that values greater than 1 are accepted but give the same result as 1.
+   * @return the approximate quantiles at the given probabilities of each column
+   *
+   * @note Rows containing any NaN values will be removed before calculation
+   *
+   * @since 2.2.0
+   */
+  def approxQuantile(
+      cols: Array[String],
+      probabilities: Array[Double],
+      relativeError: Double): Array[Array[Double]] = {
+    StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols,
+      probabilities, relativeError).map(_.toArray).toArray
+  }
+
+
   /**
    * Python-friendly version of [[approxQuantile()]]
    */
   private[spark] def approxQuantile(
-      col: String,
+      cols: List[String],
       probabilities: List[Double],
-      relativeError: Double): java.util.List[Double] = {
-    approxQuantile(col, probabilities.toArray, relativeError).toList.asJava
+      relativeError: Double): java.util.List[java.util.List[Double]] = {
+    approxQuantile(cols.toArray, probabilities.toArray, relativeError)
+      .map(_.toList.asJava).toList.asJava
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 1383208874a19..f52b18e27b5f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -149,11 +149,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
       assert(math.abs(s2 - q2 * n) < error_single)
       assert(math.abs(d1 - 2 * q1 * n) < error_double)
       assert(math.abs(d2 - 2 * q2 * n) < error_double)
+
+      // Multiple columns
+      val Array(Array(ms1, ms2), Array(md1, md2)) =
+        df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon)
+
+      assert(math.abs(ms1 - q1 * n) < error_single)
+      assert(math.abs(ms2 - q2 * n) < error_single)
+      assert(math.abs(md1 - 2 * q1 * n) < error_double)
+      assert(math.abs(md2 - 2 * q2 * n) < error_double)
     }
     // test approxQuantile on NaN values
     val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input")
     val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head)
     assert(resNaN.count(_.isNaN) === 0)
+    // test approxQuantile on multi-column NaN values
+    val dfNaN2 = Seq((Double.NaN, 1.0), (1.0, 1.0), (-1.0, Double.NaN), (Double.NaN, Double.NaN))
+      .toDF("input1", "input2")
+    val resNaN2 = dfNaN2.stat.approxQuantile(Array("input1", "input2"),
+      Array(q1, q2), epsilons.head)
+    assert(resNaN2.flatten.count(_.isNaN) === 0)
   }
 
   test("crosstab") {
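
## Usage sketch

For reviewers, a minimal PySpark sketch of how the patched `approxQuantile` is intended to behave once this change is applied. The session setup, the toy DataFrame, the column names `a` and `b`, and the printed values are illustrative assumptions, not output produced by this patch's tests:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("approxQuantile-multicol-demo").getOrCreate()

# Toy DataFrame with two numeric columns (illustrative data only).
df = spark.createDataFrame([(float(i), float(i + 10)) for i in range(10)], ["a", "b"])

# Single column name (unchanged 2.0 behavior): returns a flat list of floats.
single = df.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)

# List (or tuple) of column names (new in this patch): returns a list of
# lists of floats, one inner list of quantiles per requested column.
multi = df.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)

print(single)  # e.g. [0.0, 4.0, 9.0], up to the requested relative error
print(multi)   # e.g. [[0.0, 4.0, 9.0], [10.0, 14.0, 19.0]]

# Non-string names are rejected by the new type checks.
try:
    df.approxQuantile(["a", 123], [0.5], 0.1)
except ValueError as e:
    print(e)  # "columns should be strings, but got ..."

spark.stop()
```

Keeping the flat-list return for a plain string input (via the `isStr` flag in the diff) preserves backward compatibility with the 2.0 single-column API, while passing a list or tuple of names opts into the nested list-of-lists result.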