Merge remote-tracking branch 'apache-github/master' into SPARK-16839-2
# Conflicts:
#	sql/core/src/test/resources/sql-tests/inputs/group-by.sql
#	sql/core/src/test/resources/sql-tests/results/group-by.sql.out
hvanhovell committed Nov 1, 2016
2 parents 29ccf4e + 01dd008 commit c0263d7
Showing 28 changed files with 434 additions and 93 deletions.
12 changes: 9 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params

/**
* :: DeveloperApi ::
* Abstraction for prediction problems (regression and classification).
* Abstraction for prediction problems (regression and classification). It accepts all NumericType
* labels and will automatically cast them to DoubleType in [[fit()]].
*
* @tparam FeaturesType Type of features.
* E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
@@ -87,7 +88,12 @@ abstract class Predictor[
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
copyValues(train(dataset).setParent(this))

// Cast LabelCol to DoubleType and keep the metadata.
val labelMeta = dataset.schema($(labelCol)).metadata
val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)

copyValues(train(casted).setParent(this))
}

override def copy(extra: ParamMap): Learner
@@ -121,7 +127,7 @@ abstract class Predictor[
* and put it in an RDD with strong types.
*/
protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
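A minimal sketch of the user-facing effect of the Predictor change above, assuming a spark-shell style SparkSession named spark; LinearRegression and the column names are illustrative, not part of this diff:

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

// Integer labels: before this patch the caller had to cast them to Double by hand.
val training = spark.createDataFrame(Seq(
  (0, Vectors.dense(0.0, 1.1, 0.1)),
  (1, Vectors.dense(2.0, 1.0, -1.0)),
  (0, Vectors.dense(2.0, 1.3, 1.0))
)).toDF("label", "features")

// fit() now casts the label column to DoubleType (preserving its metadata) before train() runs.
val model = new LinearRegression().fit(training)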
@@ -71,15 +71,15 @@ abstract class Classifier[
* and put it in an RDD with strong types.
*
* @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]])
* and features ([[Vector]]). Labels are cast to [[DoubleType]].
* and features ([[Vector]]).
* @param numClasses Number of classes label can take. Labels must be integers in the range
* [0, numClasses).
* @throws SparkException if any label is not an integer >= 0
*/
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")
dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
@@ -128,7 +128,7 @@ class GBTClassifier @Since("1.4.0") (
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
// 2 classes now. This lets us provide a more precise error message.
val oldDataset: RDD[LabeledPoint] =
dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label == 0 || label == 1, s"GBTClassifier was given" +
s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
@@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") (
LogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
@@ -171,7 +171,7 @@ class NaiveBayes @Since("1.5.0") (
// Aggregates term frequencies per label.
// TODO: Calling aggregateByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd
val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
.map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
}.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(
seqOp = {
@@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val

val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
@@ -190,7 +190,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))

val instances: RDD[Instance] = dataset.select(
col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
82 changes: 82 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {

import PredictorSuite._

test("should support all NumericType labels and not support other types") {
val df = spark.createDataFrame(Seq(
(0, Vectors.dense(0, 2, 3)),
(1, Vectors.dense(0, 3, 9)),
(0, Vectors.dense(0, 2, 6))
)).toDF("label", "features")

val types =
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))

val predictor = new MockPredictor()

types.foreach { t =>
predictor.fit(df.select(col("label").cast(t), col("features")))
}

intercept[IllegalArgumentException] {
predictor.fit(df.select(col("label").cast(StringType), col("features")))
}
}
}

object PredictorSuite {

class MockPredictor(override val uid: String)
extends Predictor[Vector, MockPredictor, MockPredictionModel] {

def this() = this(Identifiable.randomUID("mockpredictor"))

override def train(dataset: Dataset[_]): MockPredictionModel = {
require(dataset.schema("label").dataType == DoubleType)
new MockPredictionModel(uid)
}

override def copy(extra: ParamMap): MockPredictor =
throw new NotImplementedError()
}

class MockPredictionModel(override val uid: String)
extends PredictionModel[Vector, MockPredictionModel] {

def this() = this(Identifiable.randomUID("mockpredictormodel"))

override def predict(features: Vector): Double =
throw new NotImplementedError()

override def copy(extra: ParamMap): MockPredictionModel =
throw new NotImplementedError()
}
}
@@ -1807,7 +1807,6 @@ class LogisticRegressionSuite
.objectiveHistory
.sliding(2)
.forall(x => x(0) >= x(1)))

}

test("binary logistic regression with weighted data") {
23 changes: 23 additions & 0 deletions python/pyspark/sql/functions.py
@@ -1744,6 +1744,29 @@ def from_json(col, schema, options={}):
return Column(jc)


@ignore_unicode_prefix
@since(2.1)
def to_json(col, options={}):
"""
Converts a column containing a [[StructType]] into a JSON string. Throws an exception
in the case of an unsupported type.

:param col: name of column containing the struct
:param options: options to control converting. Accepts the same options as the JSON datasource.

>>> from pyspark.sql import Row
>>> from pyspark.sql.types import *
>>> data = [(1, Row(name='Alice', age=2))]
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(to_json(df.value).alias("json")).collect()
[Row(json=u'{"age":2,"name":"Alice"}')]
"""

sc = SparkContext._active_spark_context
jc = sc._jvm.functions.to_json(_to_java_column(col), options)
return Column(jc)


@since(1.5)
def size(col):
"""
2 changes: 1 addition & 1 deletion python/pyspark/sql/readwriter.py
@@ -161,7 +161,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None):
"""
Loads a JSON file (`JSON Lines text format or newline-delimited JSON
<[http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per
<http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per
record) and returns the result as a :class:`DataFrame`.
If the ``schema`` parameter is not specified, this function goes
2 changes: 1 addition & 1 deletion python/pyspark/sql/streaming.py
@@ -641,7 +641,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
timestampFormat=None):
"""
Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON
<[http://jsonlines.org/>`_) and returns a :class`DataFrame`.
<http://jsonlines.org/>`_) and returns a :class:`DataFrame`.
If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
@@ -473,4 +473,21 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("AssertionError", output)
assertDoesNotContain("Exception", output)
}

test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") {
val resultValue = 12345
val output = runInterpreter("local",
s"""
|val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1)
|val mapGroups = keyValueGrouped.mapGroups((k, v) => (k, 1))
|val broadcasted = sc.broadcast($resultValue)
|
|// Using broadcast triggers serialization issue in KeyValueGroupedDataset
|val dataset = mapGroups.map(_ => broadcasted.value)
|dataset.collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains(s": Array[Int] = Array($resultValue, $resultValue)", output)
}
}
@@ -214,6 +214,18 @@ trait CheckAnalysis extends PredicateHelper {
s"appear in the arguments of an aggregate function.")
}
}
case e: Attribute if groupingExprs.isEmpty =>
// Collect all [[AggregateExpressions]]s.
val aggExprs = aggregateExprs.filter(_.collect {
case a: AggregateExpression => a
}.nonEmpty)
failAnalysis(
s"grouping expressions sequence is empty, " +
s"and '${e.sql}' is not an aggregate function. " +
s"Wrap '${aggExprs.map(_.sql).mkString("(", ", ", ")")}' in windowing " +
s"function(s) or wrap '${e.sql}' in first() (or first_value) " +
s"if you don't care which value you get."
)
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
s"expression '${e.sql}' is neither present in the group by, " +
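For illustration, a sketch of the query shape the new CheckAnalysis branch above rejects; the table and column names are hypothetical and a spark-shell style session is assumed:

val df = spark.createDataFrame(Seq((1, 2), (1, 3))).toDF("a", "b")
df.createOrReplaceTempView("t")

// Fails analysis with (roughly): "grouping expressions sequence is empty, and 'a' is not an
// aggregate function. Wrap '(count(b))' in windowing function(s) or wrap 'a' in first() ..."
spark.sql("SELECT a, count(b) FROM t")

// Either of these analyzes fine:
spark.sql("SELECT first(a), count(b) FROM t")
spark.sql("SELECT a, count(b) FROM t GROUP BY a")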
@@ -17,16 +17,17 @@

package org.apache.spark.sql.catalyst.expressions

import java.io.{ByteArrayOutputStream, StringWriter}
import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter}

import scala.util.parsing.combinator.RegexParsers

import com.fasterxml.jackson.core._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions, SparkSQLJsonProcessingException}
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.util.ParseModes
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -494,3 +495,46 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:

override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
}

/**
* Converts a [[StructType]] to a json output string.
*/
case class StructToJson(options: Map[String, String], child: Expression)
extends Expression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

@transient
lazy val writer = new CharArrayWriter()

@transient
lazy val gen =
new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer)

override def dataType: DataType = StringType
override def children: Seq[Expression] = child :: Nil

override def checkInputDataTypes(): TypeCheckResult = {
if (StructType.acceptsType(child.dataType)) {
try {
JacksonUtils.verifySchema(child.dataType.asInstanceOf[StructType])
TypeCheckResult.TypeCheckSuccess
} catch {
case e: UnsupportedOperationException =>
TypeCheckResult.TypeCheckFailure(e.getMessage)
}
} else {
TypeCheckResult.TypeCheckFailure(
s"$prettyName requires that the expression is a struct expression.")
}
}

override def eval(input: InternalRow): Any = {
gen.write(child.eval(input).asInstanceOf[InternalRow])
gen.flush()
val json = writer.toString
writer.reset()
UTF8String.fromString(json)
}

override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
}
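A sketch of how this expression is expected to surface through the DataFrame API, assuming the Scala functions.to_json helper that the Python wrapper above binds to; the data and column names are illustrative:

import org.apache.spark.sql.functions.{col, struct, to_json}

val df = spark.createDataFrame(Seq(("Alice", 2))).toDF("name", "age")

// Serialize a struct column to a JSON string; options take the same keys as the JSON data source.
df.select(to_json(struct(col("name"), col("age"))).alias("json")).collect()
// Expected to yield something like: Array([{"name":"Alice","age":2}])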
@@ -15,15 +15,14 @@
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.json
package org.apache.spark.sql.catalyst.json

import java.io.Writer

import com.fasterxml.jackson.core._

import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
import org.apache.spark.sql.types._

@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.json

import com.fasterxml.jackson.core.{JsonParser, JsonToken}

import org.apache.spark.sql.types._

object JacksonUtils {
/**
* Advance the parser until a null or a specific token is found
@@ -29,4 +31,28 @@ object JacksonUtils {
case x => x != stopOn
}
}

/**
* Verify if the schema is supported in JSON parsing.
*/
def verifySchema(schema: StructType): Unit = {
def verifyType(name: String, dataType: DataType): Unit = dataType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType |
DoubleType | StringType | TimestampType | DateType | BinaryType | _: DecimalType =>

case st: StructType => st.foreach(field => verifyType(field.name, field.dataType))

case at: ArrayType => verifyType(name, at.elementType)

case mt: MapType => verifyType(name, mt.keyType)

case udt: UserDefinedType[_] => verifyType(name, udt.sqlType)

case _ =>
throw new UnsupportedOperationException(
s"Unable to convert column $name of type ${dataType.simpleString} to JSON.")
}

schema.foreach(field => verifyType(field.name, field.dataType))
}
}
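A small sketch of the new guard, assuming the calling code compiles against the catalyst internals; column names and types are chosen only for illustration:

import org.apache.spark.sql.catalyst.json.JacksonUtils
import org.apache.spark.sql.types._

// Atomic types, nested structs, and arrays/maps of supported types pass the check.
JacksonUtils.verifySchema(new StructType().add("a", IntegerType).add("b", ArrayType(StringType)))

// An unsupported leaf type throws UnsupportedOperationException, roughly:
// "Unable to convert column i of type calendarinterval to JSON."
JacksonUtils.verifySchema(new StructType().add("i", CalendarIntervalType))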