From 8ac09108fcf3fb62a812333a5b386b566a9d98ec Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 1 Nov 2016 10:46:36 -0700 Subject: [PATCH 1/6] [SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit ## What changes were proposed in this pull request? 1, move cast to `Predictor` 2, and then, remove unnecessary cast ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #15414 from zhengruifeng/move_cast. --- .../scala/org/apache/spark/ml/Predictor.scala | 12 ++- .../spark/ml/classification/Classifier.scala | 4 +- .../ml/classification/GBTClassifier.scala | 2 +- .../classification/LogisticRegression.scala | 2 +- .../spark/ml/classification/NaiveBayes.scala | 2 +- .../GeneralizedLinearRegression.scala | 2 +- .../ml/regression/LinearRegression.scala | 2 +- .../org/apache/spark/ml/PredictorSuite.scala | 82 +++++++++++++++++++ .../LogisticRegressionSuite.scala | 1 - 9 files changed, 98 insertions(+), 11 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e29d7f48a1d6b..aa92edde7acd1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * Abstraction for prediction problems (regression and classification). + * Abstraction for prediction problems (regression and classification). It accepts all NumericType + * labels and will automatically cast it to DoubleType in [[fit()]]. * * @tparam FeaturesType Type of features. * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. @@ -87,7 +88,12 @@ abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). 
transformSchema(dataset.schema, logging = true) - copyValues(train(dataset).setParent(this)) + + // Cast LabelCol to DoubleType and keep the metadata. + val labelMeta = dataset.schema($(labelCol)).metadata + val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + + copyValues(train(casted).setParent(this)) } override def copy(extra: ParamMap): Learner @@ -121,7 +127,7 @@ abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d1b21b16f2342..a3da3067e1b5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -71,7 +71,7 @@ abstract class Classifier[ * and put it in an RDD with strong types. * * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]]) - * and features ([[Vector]]). Labels are cast to [[DoubleType]]. + * and features ([[Vector]]). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). 
* @throws SparkException if any label is not an integer >= 0 @@ -79,7 +79,7 @@ abstract class Classifier[ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + s" $numClasses, but requires numClasses > 0.") - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + s" dataset with invalid label $label. Labels must be integers in range" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 8bffe0cda0327..f8f164e8c14bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") ( // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. val oldDataset: RDD[LabeledPoint] = - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + s" dataset with invalid label $label. 
Labels must be in {0,1}; note that" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8fdaae04c42ec..c4651054fd765 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") ( LogisticRegressionModel = { val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 994ed993c99df..b03a07a6bc1e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") ( // Aggregates term frequencies per label. // TODO: Calling aggregateByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. 
- val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( seqOp = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 33cb25c8c7f66..8656ecf609ea4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 519f3bdec82df..ae876b3839734 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: 
Double, features: Vector) => Instance(label, weight, features) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala new file mode 100644 index 0000000000000..03e0c536a973e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PredictorSuite._ + + test("should support all NumericType labels and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF("label", "features") + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + + val predictor = new MockPredictor() + + types.foreach { t => + predictor.fit(df.select(col("label").cast(t), col("features"))) + } + + intercept[IllegalArgumentException] { + predictor.fit(df.select(col("label").cast(StringType), col("features"))) + } + } +} + +object PredictorSuite { + + class MockPredictor(override val uid: String) + extends Predictor[Vector, MockPredictor, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictor")) + + override def train(dataset: Dataset[_]): MockPredictionModel = { + require(dataset.schema("label").dataType == DoubleType) + new MockPredictionModel(uid) + } + + override def copy(extra: ParamMap): MockPredictor = + throw new NotImplementedError() + } + + class MockPredictionModel(override val uid: String) + extends PredictionModel[Vector, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictormodel")) + + override def predict(features: Vector): Double = + throw new NotImplementedError() + + override def copy(extra: ParamMap): MockPredictionModel = + throw new NotImplementedError() + } +} diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bc631dc6d3149..8771fd2e9d2b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1807,7 +1807,6 @@ class LogisticRegressionSuite .objectiveHistory .sliding(2) .forall(x => x(0) >= x(1))) - } test("binary logistic regression with weighted data") { From 8cdf143f4b1ca5c6bc0256808e6f42d9ef299cbd Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Tue, 1 Nov 2016 11:17:35 -0700 Subject: [PATCH 2/6] [SPARK-18103][FOLLOW-UP][SQL][MINOR] Rename `MetadataLogFileCatalog` to `MetadataLogFileIndex` ## What changes were proposed in this pull request? This is a follow-up to https://github.com/apache/spark/pull/15634. ## How was this patch tested? N/A Author: Liwei Lin Closes #15712 from lw-lin/18103. --- .../{MetadataLogFileCatalog.scala => MetadataLogFileIndex.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{MetadataLogFileCatalog.scala => MetadataLogFileIndex.scala} (100%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala From 8a538c97b556f80f67c80519af0ce879557050d5 Mon Sep 17 00:00:00 2001 From: Ergin Seyfe Date: Tue, 1 Nov 2016 11:18:42 -0700 Subject: [PATCH 3/6] [SPARK-18189][SQL] Fix serialization issue in KeyValueGroupedDataset ## What changes were proposed in this pull request? 
Likewise [DataSet.scala](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L156) KeyValueGroupedDataset should mark the queryExecution as transient. As mentioned in the Jira ticket, without transient we saw serialization issues like ``` Caused by: java.io.NotSerializableException: org.apache.spark.sql.execution.QueryExecution Serialization stack: - object not serializable (class: org.apache.spark.sql.execution.QueryExecution, value: == ``` ## How was this patch tested? Run the query which is specified in the Jira ticket before and after: ``` val a = spark.createDataFrame(sc.parallelize(Seq((1,2),(3,4)))).as[(Int,Int)] val grouped = a.groupByKey( {x:(Int,Int)=>x._1} ) val mappedGroups = grouped.mapGroups((k,x)=> {(k,1)} ) val yyy = sc.broadcast(1) val last = mappedGroups.rdd.map(xx=> { val simpley = yyy.value 1 } ) ``` Author: Ergin Seyfe Closes #15706 from seyfe/keyvaluegrouped_serialization. --- .../scala/org/apache/spark/repl/ReplSuite.scala | 17 +++++++++++++++++ .../spark/sql/KeyValueGroupedDataset.scala | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 9262e938c2a60..96d2dfc2658b9 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -473,4 +473,21 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("AssertionError", output) assertDoesNotContain("Exception", output) } + + test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") { + val resultValue = 12345 + val output = runInterpreter("local", + s""" + |val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1) + |val mapGroups = keyValueGrouped.mapGroups((k, v) => (k, 1)) + |val broadcasted = sc.broadcast($resultValue) + | + |// Using broadcast triggers 
serialization issue in KeyValueGroupedDataset + |val dataset = mapGroups.map(_ => broadcasted.value) + |dataset.collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains(s": Array[Int] = Array($resultValue, $resultValue)", output) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 4cb0313aa9037..31ce8eb25e808 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.expressions.ReduceAggregator class KeyValueGroupedDataset[K, V] private[sql]( kEncoder: Encoder[K], vEncoder: Encoder[V], - val queryExecution: QueryExecution, + @transient val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { From d0272b436512b71f04313e109d3d21a6e9deefca Mon Sep 17 00:00:00 2001 From: jiangxingbo Date: Tue, 1 Nov 2016 11:25:11 -0700 Subject: [PATCH 4/6] [SPARK-18148][SQL] Misleading Error Message for Aggregation Without Window/GroupBy ## What changes were proposed in this pull request? Aggregation Without Window/GroupBy expressions will fail in `checkAnalysis`, the error message is a bit misleading, we should generate a more specific error message for this case. For example, ``` spark.read.load("/some-data") .withColumn("date_dt", to_date($"date")) .withColumn("year", year($"date_dt")) .withColumn("week", weekofyear($"date_dt")) .withColumn("user_count", count($"userId")) .withColumn("daily_max_in_week", max($"user_count").over(weeklyWindow)) ) ``` creates the following output: ``` org.apache.spark.sql.AnalysisException: expression '`randomColumn`' is neither present in the group by, nor is it an aggregate function. 
Add to group by or wrap in first() (or first_value) if you don't care which value you get.; ``` In the error message above, `randomColumn` doesn't appear in the query (actually it's added by function `withColumn`), so the message is not enough for the user to address the problem. ## How was this patch tested? Manually test Before: ``` scala> spark.sql("select col, count(col) from tbl") org.apache.spark.sql.AnalysisException: expression 'tbl.`col`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;; ``` After: ``` scala> spark.sql("select col, count(col) from tbl") org.apache.spark.sql.AnalysisException: grouping expressions sequence is empty, and 'tbl.`col`' is not an aggregate function. Wrap '(count(col#231L) AS count(col)#239L)' in windowing function(s) or wrap 'tbl.`col`' in first() (or first_value) if you don't care which value you get.;; ``` Also add new test sqls in `group-by.sql`. Author: jiangxingbo Closes #15672 from jiangxb1987/groupBy-empty. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++ .../resources/sql-tests/inputs/group-by.sql | 41 +++++-- .../sql-tests/results/group-by.sql.out | 116 +++++++++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 35 ------ 4 files changed, 140 insertions(+), 64 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9a7c2a944b588..3455a567b7786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -214,6 +214,18 @@ trait CheckAnalysis extends PredicateHelper { s"appear in the arguments of an aggregate function.") } } + case e: Attribute if groupingExprs.isEmpty => + // Collect all [[AggregateExpressions]]s. 
+ val aggExprs = aggregateExprs.filter(_.collect { + case a: AggregateExpression => a + }.nonEmpty) + failAnalysis( + s"grouping expressions sequence is empty, " + + s"and '${e.sql}' is not an aggregate function. " + + s"Wrap '${aggExprs.map(_.sql).mkString("(", ", ", ")")}' in windowing " + + s"function(s) or wrap '${e.sql}' in first() (or first_value) " + + s"if you don't care which value you get." + ) case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.sql}' is neither present in the group by, " + diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 6741703d9d82c..d950ec83d98c3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -1,17 +1,34 @@ --- Temporary data. -create temporary view myview as values 128, 256 as v(int_col); +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) +AS testData(a, b); --- group by should produce all input rows, -select int_col, count(*) from myview group by int_col; +-- Aggregate with empty GroupBy expressions. +SELECT a, COUNT(b) FROM testData; +SELECT COUNT(a), COUNT(b) FROM testData; --- group by should produce a single row. -select 'foo', count(*) from myview group by 1; +-- Aggregate with non-empty GroupBy expressions. +SELECT a, COUNT(b) FROM testData GROUP BY a; +SELECT a, COUNT(b) FROM testData GROUP BY b; +SELECT COUNT(a), COUNT(b) FROM testData GROUP BY a; --- group-by should not produce any rows (whole stage code generation). -select 'foo' from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals. +SELECT 'foo', COUNT(a) FROM testData GROUP BY 1; --- group-by should not produce any rows (hash aggregate). 
-select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals (whole stage code generation). +SELECT 'foo' FROM testData WHERE a = 0 GROUP BY 1; --- group-by should not produce any rows (sort aggregate). -select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; +-- Aggregate grouped by literals (hash aggregate). +SELECT 'foo', APPROX_COUNT_DISTINCT(a) FROM testData WHERE a = 0 GROUP BY 1; + +-- Aggregate grouped by literals (sort aggregate). +SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1; + +-- Aggregate with complex GroupBy expressions. +SELECT a + b, COUNT(b) FROM testData GROUP BY a + b; +SELECT a + 2, COUNT(b) FROM testData GROUP BY a + 1; +SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1; + +-- Aggregate with nulls. +SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) +FROM testData; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 9127bd4dd4c6f..a91f04e098b18 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,9 +1,11 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 14 -- !query 0 -create temporary view myview as values 128, 256 as v(int_col) +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (null, 1), (3, null), (null, null) +AS testData(a, b) -- !query 0 schema struct<> -- !query 0 output @@ -11,41 +13,121 @@ struct<> -- !query 1 -select int_col, count(*) from myview group by int_col +SELECT a, COUNT(b) FROM testData -- !query 1 schema -struct +struct<> -- !query 1 output -128 1 -256 1 +org.apache.spark.sql.AnalysisException +grouping expressions sequence is empty, and 'testdata.`a`' is not an 
aggregate function. Wrap '(count(testdata.`b`) AS `count(b)`)' in windowing function(s) or wrap 'testdata.`a`' in first() (or first_value) if you don't care which value you get.; -- !query 2 -select 'foo', count(*) from myview group by 1 +SELECT COUNT(a), COUNT(b) FROM testData -- !query 2 schema -struct +struct -- !query 2 output -foo 2 +7 7 -- !query 3 -select 'foo' from myview where int_col == 0 group by 1 +SELECT a, COUNT(b) FROM testData GROUP BY a -- !query 3 schema -struct +struct -- !query 3 output - +1 2 +2 2 +3 2 +NULL 1 -- !query 4 -select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1 +SELECT a, COUNT(b) FROM testData GROUP BY b -- !query 4 schema -struct +struct<> -- !query 4 output - +org.apache.spark.sql.AnalysisException +expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; -- !query 5 -select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 +SELECT COUNT(a), COUNT(b) FROM testData GROUP BY a -- !query 5 schema -struct> +struct -- !query 5 output +0 1 +2 2 +2 2 +3 2 + + +-- !query 6 +SELECT 'foo', COUNT(a) FROM testData GROUP BY 1 +-- !query 6 schema +struct +-- !query 6 output +foo 7 + + +-- !query 7 +SELECT 'foo' FROM testData WHERE a = 0 GROUP BY 1 +-- !query 7 schema +struct +-- !query 7 output + + +-- !query 8 +SELECT 'foo', APPROX_COUNT_DISTINCT(a) FROM testData WHERE a = 0 GROUP BY 1 +-- !query 8 schema +struct +-- !query 8 output + + + +-- !query 9 +SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1 +-- !query 9 schema +struct> +-- !query 9 output + + + +-- !query 10 +SELECT a + b, COUNT(b) FROM testData GROUP BY a + b +-- !query 10 schema +struct<(a + b):int,count(b):bigint> +-- !query 10 output +2 1 +3 2 +4 2 +5 1 +NULL 1 + + +-- !query 11 +SELECT a + 2, COUNT(b) FROM testData GROUP BY a + 1 +-- !query 11 schema +struct<> +-- !query 11 
output +org.apache.spark.sql.AnalysisException +expression 'testdata.`a`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 12 +SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1 +-- !query 12 schema +struct<((a + 1) + 1):int,count(b):bigint> +-- !query 12 output +3 2 +4 2 +5 2 +NULL 1 + + +-- !query 13 +SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) +FROM testData +-- !query 13 schema +struct +-- !query 13 output +-0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1a43d0b2205ca..9a3d93cf17b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -463,20 +463,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } - test("agg") { - checkAnswer( - sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq(Row(1, 3), Row(2, 3), Row(3, 3))) - } - - test("aggregates with nulls") { - checkAnswer( - sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + - "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), - Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3) - ) - } - test("select *") { checkAnswer( sql("SELECT * FROM testData"), @@ -1178,27 +1164,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1)) } - test("throw errors for non-aggregate attributes with aggregation") { - def checkAggregation(query: String, isInvalidQuery: Boolean = true) { - if (isInvalidQuery) { - val e = intercept[AnalysisException](sql(query).queryExecution.analyzed) - assert(e.getMessage contains "group by") - } else { - // Should not throw - sql(query).queryExecution.analyzed - } - } - 
- checkAggregation("SELECT key, COUNT(*) FROM testData") - checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", isInvalidQuery = false) - - checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") - checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) - - checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") - checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) - } - testQuietly( "SPARK-16748: SparkExceptions during planning should not wrapped in TreeNodeException") { intercept[SparkException] { From cfac17ee1cec414663b957228e469869eb7673c1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 1 Nov 2016 12:35:34 -0700 Subject: [PATCH 5/6] [SPARK-18167] Disable flaky SQLQuerySuite test We now know it's a persistent environmental issue that is causing this test to sometimes fail. One hypothesis is that some configuration is leaked from another suite, and depending on suite ordering this can cause this test to fail. I am planning on mining the jenkins logs to try to narrow down which suite could be causing this. For now, disable the test. Author: Eric Liang Closes #15720 from ericl/disable-flaky-test. 
--- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8b916932ff543..b9353b5b5d2a7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1565,7 +1565,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } - test("SPARK-10562: partition by column with mixed case name") { + ignore("SPARK-10562: partition by column with mixed case name") { def runOnce() { withTable("tbl10562") { val df = Seq(2012 -> "a").toDF("Year", "val") From 01dd0083011741c2bbe5ae1d2a25f2c9a1302b76 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Tue, 1 Nov 2016 12:46:41 -0700 Subject: [PATCH 6/6] [SPARK-17764][SQL] Add `to_json` supporting to convert nested struct column to JSON string ## What changes were proposed in this pull request? This PR proposes to add `to_json` function in contrast with `from_json` in Scala, Java and Python. It'd be useful if we can convert a same column from/to json. Also, some datasources do not support nested types. If we are forced to save a dataframe into those data sources, we might be able to work around by this function. The usage is as below: ``` scala val df = Seq(Tuple1(Tuple1(1))).toDF("a") df.select(to_json($"a").as("json")).show() ``` ``` bash +--------+ | json| +--------+ |{"_1":1}| +--------+ ``` ## How was this patch tested? Unit tests in `JsonFunctionsSuite` and `JsonExpressionsSuite`. Author: hyukjinkwon Closes #15354 from HyukjinKwon/SPARK-17764. 
--- python/pyspark/sql/functions.py | 23 +++++++++ python/pyspark/sql/readwriter.py | 2 +- python/pyspark/sql/streaming.py | 2 +- .../expressions/jsonExpressions.scala | 48 ++++++++++++++++++- .../sql/catalyst}/json/JacksonGenerator.scala | 5 +- .../sql/catalyst/json/JacksonUtils.scala | 26 ++++++++++ .../expressions/JsonExpressionsSuite.scala | 9 ++++ .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../datasources/json/JsonFileFormat.scala | 2 +- .../org/apache/spark/sql/functions.scala | 44 ++++++++++++++++- .../apache/spark/sql/JsonFunctionsSuite.scala | 30 +++++++++--- 11 files changed, 177 insertions(+), 16 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/datasources => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/json/JacksonGenerator.scala (98%) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7fa3fd2de7ddf..45e3c22bfc6a9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1744,6 +1744,29 @@ def from_json(col, schema, options={}): return Column(jc) +@ignore_unicode_prefix +@since(2.1) +def to_json(col, options={}): + """ + Converts a column containing a [[StructType]] into a JSON string. Throws an exception, + in the case of an unsupported type. + + :param col: name of column containing the struct + :param options: options to control converting. 
accepts the same options as the json datasource + + >>> from pyspark.sql import Row + >>> from pyspark.sql.types import * + >>> data = [(1, Row(name='Alice', age=2))] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'{"age":2,"name":"Alice"}')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.to_json(_to_java_column(col), options) + return Column(jc) + + @since(1.5) def size(col): """ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index bc786ef95ed03..b0c51b1e9992e 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -161,7 +161,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ Loads a JSON file (`JSON Lines text format or newline-delimited JSON - <[http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per + `_) or an RDD of Strings storing JSON objects (one object per record) and returns the result as a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 559647bbabf67..1c94413e3c457 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -641,7 +641,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, timestampFormat=None): """ Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON - <[http://jsonlines.org/>`_) and returns a :class`DataFrame`. + `_) and returns a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 65dbd6a4e3f1d..244a5a34f3594 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayOutputStream, StringWriter} +import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter} import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions, SparkSQLJsonProcessingException} +import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.ParseModes import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -494,3 +495,46 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child: override def inputTypes: Seq[AbstractDataType] = StringType :: Nil } + +/** + * Converts a [[StructType]] to a json output string. 
+ */ +case class StructToJson(options: Map[String, String], child: Expression) + extends Expression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + lazy val writer = new CharArrayWriter() + + @transient + lazy val gen = + new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer) + + override def dataType: DataType = StringType + override def children: Seq[Expression] = child :: Nil + + override def checkInputDataTypes(): TypeCheckResult = { + if (StructType.acceptsType(child.dataType)) { + try { + JacksonUtils.verifySchema(child.dataType.asInstanceOf[StructType]) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName requires that the expression is a struct expression.") + } + } + + override def eval(input: InternalRow): Any = { + gen.write(child.eval(input).asInstanceOf[InternalRow]) + gen.flush() + val json = writer.toString + writer.reset() + UTF8String.fromString(json) + } + + override def inputTypes: Seq[AbstractDataType] = StructType :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 5b55b701862b7..4b548e0e7f978 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -15,15 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.json +package org.apache.spark.sql.catalyst.json import java.io.Writer import com.fasterxml.jackson.core._ -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.JSONOptions +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index c4d9abb2c07e8..3b23c6cd2816f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.json import com.fasterxml.jackson.core.{JsonParser, JsonToken} +import org.apache.spark.sql.types._ + object JacksonUtils { /** * Advance the parser until a null or a specific token is found @@ -29,4 +31,28 @@ object JacksonUtils { case x => x != stopOn } } + + /** + * Verify if the schema is supported in JSON generation.
+ */ + def verifySchema(schema: StructType): Unit = { + def verifyType(name: String, dataType: DataType): Unit = dataType match { + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | StringType | TimestampType | DateType | BinaryType | _: DecimalType => + + case st: StructType => st.foreach(field => verifyType(field.name, field.dataType)) + + case at: ArrayType => verifyType(name, at.elementType) + + case mt: MapType => verifyType(name, mt.keyType) + + case udt: UserDefinedType[_] => verifyType(name, udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"Unable to convert column $name of type ${dataType.simpleString} to JSON.") + } + + schema.foreach(field => verifyType(field.name, field.dataType)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 84623934d95d2..f9db649bc2404 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -343,4 +343,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null ) } + + test("to_json") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), schema) + checkEvaluation( + StructToJson(Map.empty, struct), + """{"a":1}""" + ) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6e0a2471e0fb5..eb2b20afc37cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders._ import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.json.JacksonGenerator import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -45,7 +46,6 @@ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 5a409c04c929d..0e38aefecb673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextOutputWriter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5f1efd22d8204..944a476114faf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2883,10 +2883,10 @@ object functions { * (Scala-specific) Parses a column containing a JSON string into a [[StructType]] with the * specified schema. Returns `null`, in the case of an unparseable string. * + * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string * @param options options to control how the json is parsed. accepts the same options and the * json data source. - * @param e a string column containing JSON data. * * @group collection_funcs * @since 2.1.0 @@ -2936,6 +2936,48 @@ object functions { def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options) + + /** + * (Scala-specific) Converts a column containing a [[StructType]] into a JSON string. + * Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * @param options options to control how the struct column is converted into a json string. + * Accepts the same options as the json data source. + * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column, options: Map[String, String]): Column = withExpr { + StructToJson(options, e.expr) + } + + /** + * (Java-specific) Converts a column containing a [[StructType]] into a JSON string. + * Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * @param options options to control how the struct column is converted into a json string. + * Accepts the same options as the json data source.
+ * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column, options: java.util.Map[String, String]): Column = + to_json(e, options.asScala.toMap) + + /** + * Converts a column containing a [[StructType]] into a JSON string. + * Throws an exception, in the case of an unsupported type. + * + * @param e a struct column. + * + * @group collection_funcs + * @since 2.1.0 + */ + def to_json(e: Column): Column = + to_json(e, Map.empty[String, String]) + /** * Returns length of array or map. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 518d6e92b2ff7..59ae889cf3b92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.from_json +import org.apache.spark.sql.functions.{from_json, struct, to_json} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType} class JsonFunctionsSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -31,7 +31,6 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("alice", "5")) } - val tuples: Seq[(String, String)] = ("1", """{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: ("2", """{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: @@ -97,7 +96,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(expr, expected) } - test("json_parser") { + test("from_json") { val df = Seq("""{"a": 1}""").toDS() val schema = new StructType().add("a", IntegerType) @@ -106,7 +105,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(1)) :: Nil) } - test("json_parser
missing columns") { + test("from_json missing columns") { val df = Seq("""{"a": 1}""").toDS() val schema = new StructType().add("b", IntegerType) @@ -115,7 +114,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(Row(null)) :: Nil) } - test("json_parser invalid json") { + test("from_json invalid json") { val df = Seq("""{"a" 1}""").toDS() val schema = new StructType().add("a", IntegerType) @@ -123,4 +122,23 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { df.select(from_json($"value", schema)), Row(null) :: Nil) } + + test("to_json") { + val df = Seq(Tuple1(Tuple1(1))).toDF("a") + + checkAnswer( + df.select(to_json($"a")), + Row("""{"_1":1}""") :: Nil) + } + + test("to_json unsupported type") { + val df = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a") + .select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c")) + val e = intercept[AnalysisException]{ + // Unsupported type throws an exception + df.select(to_json($"c")).collect() + } + assert(e.getMessage.contains( + "Unable to convert column a of type calendarinterval to JSON.")) + } }