From 83670fc9e6fc9c7a6ae68dfdd3f9335ea72f4ab0 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 19:22:38 -0700 Subject: [PATCH 01/43] [SPARK-8176] [SPARK-8197] [SQL] function to_date/ trunc This PR is based on #6988 , thanks to adrian-wang . This brings two SQL functions: to_date() and trunc(). Closes #6988 Author: Daoyuan Wang Author: Davies Liu Closes #7805 from davies/to_date and squashes the following commits: 2c7beba [Davies Liu] Merge branch 'master' of github.com:apache/spark into to_date 310dd55 [Daoyuan Wang] remove dup test in rebase 980b092 [Daoyuan Wang] resolve rebase conflict a476c5a [Daoyuan Wang] address comments from davies d44ea5f [Daoyuan Wang] function to_date, trunc --- python/pyspark/sql/functions.py | 30 +++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/datetimeFunctions.scala | 88 ++++++++++++++++++- .../sql/catalyst/util/DateTimeUtils.scala | 34 +++++++ .../expressions/DateExpressionsSuite.scala | 29 +++++- .../expressions/NonFoldableLiteral.scala | 4 + .../org/apache/spark/sql/functions.scala | 16 ++++ .../apache/spark/sql/DateFunctionsSuite.scala | 44 ++++++++++ 8 files changed, 245 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a7295e25f0aa5..8024a8de07c98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -888,6 +888,36 @@ def months_between(date1, date2): return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) +@since(1.5) +def to_date(col): + """ + Converts the column of StringType or TimestampType into DateType. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_date(df.t).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.to_date(_to_java_column(col))) + + +@since(1.5) +def trunc(date, format): + """ + Returns date truncated to the unit specified by the format. 
+ + :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' + + >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d']) + >>> df.select(trunc(df.d, 'year').alias('year')).collect() + [Row(year=datetime.date(1997, 1, 1))] + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() + [Row(month=datetime.date(1997, 2, 1))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6c7c481fab8db..1bf7204a2515c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -223,6 +223,8 @@ object FunctionRegistry { expression[NextDay]("next_day"), expression[Quarter]("quarter"), expression[Second]("second"), + expression[ToDate]("to_date"), + expression[TruncDate]("trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 9795673ee0664..6e7613340c032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -507,7 +507,6 @@ case class FromUnixTime(sec: Expression, format: Expression) }) } } - } /** @@ -696,3 +695,90 @@ case class MonthsBetween(date1: Expression, date2: Expression) }) } } + +/** + * Returns the date part of a timestamp or string. + */ +case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Implicit casting of spark will accept string in both date and timestamp format, as + // well as TimestampType. + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = DateType + + override def eval(input: InternalRow): Any = child.eval(input) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, d => d) + } +} + +/* + * Returns date truncated to the unit specified by the format. 
+ */ +case class TruncDate(date: Expression, format: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + override def left: Expression = date + override def right: Expression = format + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def dataType: DataType = DateType + override def prettyName: String = "trunc" + + lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + + override def eval(input: InternalRow): Any = { + val minItem = if (format.foldable) { + minItemConst + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } + if (minItem == -1) { + // unknown format + null + } else { + val d = date.eval(input) + if (d == null) { + null + } else { + DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + + if (format.foldable) { + if (minItemConst == -1) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val d = date.gen(ctx) + s""" + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst); + } + """ + } + } else { + nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { + val form = ctx.freshName("form") + s""" + int $form = $dtu.parseTruncLevel($fmt); + if ($form == -1) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $dtu.truncDate($dateVal, $form); + } + """ + }) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 53abdf6618eac..5a7c25b8d508d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -779,4 +779,38 @@ object DateTimeUtils { } date + (lastDayOfMonthInYear - dayInYear) } + + private val TRUNC_TO_YEAR = 1 + private val TRUNC_TO_MONTH = 2 + private val TRUNC_INVALID = -1 + + /** + * Returns the trunc date from original date and trunc level. + * Trunc level should be generated using `parseTruncLevel()`, should only be 1 or 2. + */ + def truncDate(d: Int, level: Int): Int = { + if (level == TRUNC_TO_YEAR) { + d - DateTimeUtils.getDayInYear(d) + 1 + } else if (level == TRUNC_TO_MONTH) { + d - DateTimeUtils.getDayOfMonth(d) + 1 + } else { + throw new Exception(s"Invalid trunc level: $level") + } + } + + /** + * Returns the truncate level, could be TRUNC_YEAR, TRUNC_MONTH, or TRUNC_INVALID, + * TRUNC_INVALID means unsupported truncate level. 
+ */ + def parseTruncLevel(format: UTF8String): Int = { + if (format == null) { + TRUNC_INVALID + } else { + format.toString.toUpperCase match { + case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR + case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH + case _ => TRUNC_INVALID + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 887e43621a941..6c15c05da3094 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -351,6 +351,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) } + test("function to_date") { + checkEvaluation( + ToDate(Literal(Date.valueOf("2015-07-22"))), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22"))) + checkEvaluation(ToDate(Literal.create(null, DateType)), null) + } + + test("function trunc") { + def testTrunc(input: Date, fmt: String, expected: Date): Unit = { + checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)), + expected) + checkEvaluation( + TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)), + expected) + } + val date = Date.valueOf("2015-07-22") + Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt => + testTrunc(date, fmt, Date.valueOf("2015-01-01")) + } + Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => + testTrunc(date, fmt, Date.valueOf("2015-07-01")) + } + testTrunc(date, "DD", null) + testTrunc(date, null, null) + testTrunc(null, "MON", null) + testTrunc(null, null, null) + } + test("from_unixtime") { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" @@ -405,5 +433,4 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 0559fb80e7fce..31ecf4a9e810a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -47,4 +47,8 @@ object NonFoldableLiteral { val lit = Literal(value) NonFoldableLiteral(lit.value, lit.dataType) } + def create(value: Any, dataType: DataType): NonFoldableLiteral = { + val lit = Literal.create(value, dataType) + NonFoldableLiteral(lit.value, lit.dataType) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 168894d66117d..46dc4605a5ccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2181,6 +2181,22 @@ object functions { */ def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + /* + * Converts the column into DateType. 
+ * + * @group datetime_funcs + * @since 1.5.0 + */ + def to_date(e: Column): Column = ToDate(e.expr) + + /** + * Returns date truncated to the unit specified by the format. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index b7267c413165a..8c596fad74ee4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -345,6 +345,50 @@ class DateFunctionsSuite extends QueryTest { Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) } + test("function to_date") { + val d1 = Date.valueOf("2015-07-22") + val d2 = Date.valueOf("2015-07-01") + val t1 = Timestamp.valueOf("2015-07-22 10:00:00") + val t2 = Timestamp.valueOf("2014-12-31 23:59:59") + val s1 = "2015-07-22 10:00:00" + val s2 = "2014-12-31" + val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s") + + checkAnswer( + df.select(to_date(col("t"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.select(to_date(col("d"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.select(to_date(col("s"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + + checkAnswer( + df.selectExpr("to_date(t)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.selectExpr("to_date(d)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.selectExpr("to_date(s)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + } + + test("function trunc") { + val df = Seq( + (1, Timestamp.valueOf("2015-07-22 10:00:00")), + (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select(trunc(col("t"), "YY")), + Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) + + checkAnswer( + df.selectExpr("trunc(t, 'Month')"), + Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) + } + test("from_unixtime") { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" From 4e5919bfb47a58bcbda90ae01c1bed2128ded983 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 30 Jul 2015 23:02:11 -0700 Subject: [PATCH 02/43] [SPARK-7690] [ML] Multiclass classification Evaluator Multiclass Classification Evaluator for ML Pipelines. F1 score, precision, recall, weighted precision and weighted recall are supported as available metrics. 
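For illustration only (not part of this patch): a minimal Scala sketch of how the new evaluator is intended to be used, assuming a SQLContext named sqlContext is in scope and that predictions and labels are available as double-typed columns. When no metric name is set the evaluator defaults to "f1" (the weighted F-measure).

    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import sqlContext.implicits._

    // toy (prediction, label) pairs
    val predictions = Seq((0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0))
      .toDF("prediction", "label")

    val evaluator = new MulticlassClassificationEvaluator()
      .setPredictionCol("prediction")
      .setLabelCol("label")
      .setMetricName("weightedRecall")  // or "f1", "precision", "recall", "weightedPrecision"

    val metric: Double = evaluator.evaluate(predictions)
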
Author: Ram Sriharsha Closes #7475 from harsha2010/SPARK-7690 and squashes the following commits: 9bf4ec7 [Ram Sriharsha] fix indentation 3f09a85 [Ram Sriharsha] cleanup doc 16115ae [Ram Sriharsha] code review fixes 032d2a3 [Ram Sriharsha] fix test eec9865 [Ram Sriharsha] Fix Python Indentation 1dbeffd [Ram Sriharsha] Merge branch 'master' into SPARK-7690 68cea85 [Ram Sriharsha] Merge branch 'master' into SPARK-7690 54c03de [Ram Sriharsha] [SPARK-7690][ml][WIP] Multiclass Evaluator for ML Pipeline --- .../MulticlassClassificationEvaluator.scala | 85 +++++++++++++++++++ ...lticlassClassificationEvaluatorSuite.scala | 28 ++++++ python/pyspark/ml/evaluation.py | 66 ++++++++++++++ 3 files changed, 179 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala new file mode 100644 index 0000000000000..44f779c1908d7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * Evaluator for multiclass classification, which expects two input columns: score and label. 
+ */ +@Experimental +class MulticlassClassificationEvaluator (override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("mcEval")) + + /** + * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, + * `"weightedPrecision"`, `"weightedRecall"`) + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("f1", "precision", + "recall", "weightedPrecision", "weightedRecall")) + new Param(this, "metricName", "metric name in evaluation " + + "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "f1") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new MulticlassMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "f1" => metrics.weightedFMeasure + case "precision" => metrics.precision + case "recall" => metrics.recall + case "weightedPrecision" => metrics.weightedPrecision + case "weightedRecall" => metrics.weightedRecall + } + metric + } + + override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala new file mode 100644 index 0000000000000..6d8412b0b3701 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new MulticlassClassificationEvaluator) + } +} diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 595593a7f2cde..06e809352225b 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -214,6 +214,72 @@ def setParams(self, predictionCol="prediction", labelCol="label", kwargs = self.setParams._input_kwargs return self._set(**kwargs) + +@inherit_doc +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Multiclass Classification, which expects two input + columns: prediction and label. + >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"]) + ... + >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction") + >>> evaluator.evaluate(dataset) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"}) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) + 0.66... + """ + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation " + "(f1|precision|recall|weightedPrecision|weightedRecall)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + """ + super(MulticlassClassificationEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) + # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) + self.metricName = Param(self, "metricName", + "metric name in evaluation" + " (f1|precision|recall|weightedPrecision|weightedRecall)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="f1") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + setParams(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + Sets params for multiclass classification evaluator. 
+ """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 69b62f76fced18efa35a107c9be4bc22eba72878 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 30 Jul 2015 23:03:48 -0700 Subject: [PATCH 03/43] [SPARK-9214] [ML] [PySpark] support ml.NaiveBayes for Python support ml.NaiveBayes for Python Author: Yanbo Liang Closes #7568 from yanboliang/spark-9214 and squashes the following commits: 5ee3fd6 [Yanbo Liang] fix typos 3ecd046 [Yanbo Liang] fix typos f9c94d1 [Yanbo Liang] change lambda_ to smoothing and fix other issues 180452a [Yanbo Liang] fix typos 7dda1f4 [Yanbo Liang] support ml.NaiveBayes for Python --- .../spark/ml/classification/NaiveBayes.scala | 10 +- .../classification/JavaNaiveBayesSuite.java | 4 +- .../ml/classification/NaiveBayesSuite.scala | 6 +- python/pyspark/ml/classification.py | 116 +++++++++++++++++- 4 files changed, 125 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 1f547e4a98af7..5be35fe209291 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -38,11 +38,11 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * (default = 1.0). * @group param */ - final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.", + final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.", ParamValidators.gtEq(0)) /** @group getParam */ - final def getLambda: Double = $(lambda) + final def getSmoothing: Double = $(smoothing) /** * The model type which is a string (case-sensitive). @@ -79,8 +79,8 @@ class NaiveBayes(override val uid: String) * Default is 1.0. * @group setParam */ - def setLambda(value: Double): this.type = set(lambda, value) - setDefault(lambda -> 1.0) + def setSmoothing(value: Double): this.type = set(smoothing, value) + setDefault(smoothing -> 1.0) /** * Set the model type using a string (case-sensitive). 
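// Editorial note (not part of the diff): after this rename the estimator is configured through
// `smoothing`, the additive (Laplace) smoothing parameter formerly exposed as `lambda`, together
// with the existing `modelType` param. A caller now writes, for example:
//   new NaiveBayes().setSmoothing(0.5).setModelType("multinomial")
// Only the param name and its accessors change; the training behavior is unchanged.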
@@ -92,7 +92,7 @@ class NaiveBayes(override val uid: String) override protected def train(dataset: DataFrame): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) + val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 09a9fba0c19cf..a700c9cddb206 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -68,7 +68,7 @@ public void naiveBayesDefaultParams() { assert(nb.getLabelCol() == "label"); assert(nb.getFeaturesCol() == "features"); assert(nb.getPredictionCol() == "prediction"); - assert(nb.getLambda() == 1.0); + assert(nb.getSmoothing() == 1.0); assert(nb.getModelType() == "multinomial"); } @@ -89,7 +89,7 @@ public void testNaiveBayes() { }); DataFrame dataset = jsql.createDataFrame(jrdd, schema); - NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial"); + NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 76381a2741296..264bde3703c5f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -58,7 +58,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(nb.getLabelCol === "label") assert(nb.getFeaturesCol === "features") assert(nb.getPredictionCol === "prediction") - assert(nb.getLambda === 1.0) + assert(nb.getSmoothing === 1.0) assert(nb.getModelType === "multinomial") } @@ -75,7 +75,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 42, "multinomial")) - val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial") + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) @@ -101,7 +101,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 45, "bernoulli")) - val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli") + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5a82bc286d1e8..93ffcd40949b3 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -25,7 +25,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', - 'RandomForestClassifier', 'RandomForestClassificationModel'] + 'RandomForestClassifier', 'RandomForestClassificationModel', 
'NaiveBayes', + 'NaiveBayesModel'] @inherit_doc @@ -576,6 +577,119 @@ class GBTClassificationModel(TreeEnsembleModels): """ +@inherit_doc +class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): + """ + Naive Bayes Classifiers. + + >>> from pyspark.sql import Row + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])), + ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])), + ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))]) + >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + >>> model = nb.fit(df) + >>> model.pi + DenseVector([-0.51..., -0.91...]) + >>> model.theta + DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1) + >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() + >>> model.transform(test0).head().prediction + 1.0 + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + + "default is 1.0") + modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + + "(case-sensitive). Supported options: multinomial (default) and bernoulli.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial") + """ + super(NaiveBayes, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.NaiveBayes", self.uid) + #: param for the smoothing parameter. + self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " + + "default is 1.0") + #: param for the model type. + self.modelType = Param(self, "modelType", "The model type which is a string " + + "(case-sensitive). Supported options: multinomial (default) " + + "and bernoulli.") + self._setDefault(smoothing=1.0, modelType="multinomial") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial") + Sets params for Naive Bayes. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return NaiveBayesModel(java_model) + + def setSmoothing(self, value): + """ + Sets the value of :py:attr:`smoothing`. + """ + self._paramMap[self.smoothing] = value + return self + + def getSmoothing(self): + """ + Gets the value of smoothing or its default value. + """ + return self.getOrDefault(self.smoothing) + + def setModelType(self, value): + """ + Sets the value of :py:attr:`modelType`. + """ + self._paramMap[self.modelType] = value + return self + + def getModelType(self): + """ + Gets the value of modelType or its default value. + """ + return self.getOrDefault(self.modelType) + + +class NaiveBayesModel(JavaModel): + """ + Model fitted by NaiveBayes. + """ + + @property + def pi(self): + """ + log of class priors. 
+ """ + return self._call_java("pi") + + @property + def theta(self): + """ + log of class conditional probabilities. + """ + return self._call_java("theta") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 0244170b66476abc4a39ed609a852f1a6fa455e7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jul 2015 23:05:58 -0700 Subject: [PATCH 04/43] [SPARK-9152][SQL] Implement code generation for Like and RLike JIRA: https://issues.apache.org/jira/browse/SPARK-9152 This PR implements code generation for `Like` and `RLike`. Author: Liang-Chi Hsieh Closes #7561 from viirya/like_rlike_codegen and squashes the following commits: fe5641b [Liang-Chi Hsieh] Add test for NonFoldableLiteral. ccd1b43 [Liang-Chi Hsieh] For comments. 0086723 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen 50df9a8 [Liang-Chi Hsieh] Use nullSafeCodeGen. 8092a68 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen 696d451 [Liang-Chi Hsieh] Check expression foldable. 48e5536 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen aea58e0 [Liang-Chi Hsieh] For comments. 46d946f [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen a0fb76e [Liang-Chi Hsieh] For comments. 6cffe3c [Liang-Chi Hsieh] For comments. 69f0fb6 [Liang-Chi Hsieh] Add code generation for Like and RLike. --- .../expressions/stringOperations.scala | 105 ++++++++++++++---- .../spark/sql/catalyst/util/StringUtils.scala | 47 ++++++++ .../expressions/StringExpressionsSuite.scala | 16 +++ .../sql/catalyst/util/StringUtilsSuite.scala | 34 ++++++ 4 files changed, 180 insertions(+), 22 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 79c0ca56a8e79..99a62343f138d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -21,8 +21,11 @@ import java.text.DecimalFormat import java.util.Locale import java.util.regex.{MatchResult, Pattern} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -160,32 +163,51 @@ trait StringRegexExpression extends ImplicitCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression with CodegenFallback { - // replace the _ with .{1} exactly match 1 time of any character - // replace the % with .*, match 0 or more times with any character - override def escape(v: String): String = - if (!v.isEmpty) { - "(?s)" + (' ' +: v.init).zip(v).flatMap { - case (prev, '\\') => "" - case ('\\', c) => - c match { - case '_' => "_" - case '%' => "%" - case _ => Pattern.quote("\\" + c) - } - case (prev, c) => - c match { - case '_' => "." 
- case '%' => ".*" - case _ => Pattern.quote(Character.toString(c)) - } - }.mkString - } else { - v - } + override def escape(v: String): String = StringUtils.escapeLikeRegex(v) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() override def toString: String = s"$left LIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); + """ + }) + } + } } @@ -195,6 +217,45 @@ case class RLike(left: Expression, right: Expression) override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
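// (Editorial aside, not part of the patch: on this branch `right` is foldable, so its regex was
// already evaluated and compiled into the mutable-state Pattern field above at code-generation
// time; nullSafeCodeGen would emit per-row evaluation of both children, which is exactly the
// re-evaluation of `right` that this comment refers to.)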
+ val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile(rightStr); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); + """ + }) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala new file mode 100644 index 0000000000000..9ddfb3a0d3759 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.regex.Pattern + +object StringUtils { + + // replace the _ with .{1} exactly match 1 time of any character + // replace the % with .*, match 0 or more times with any character + def escapeLikeRegex(v: String): String = { + if (!v.isEmpty) { + "(?s)" + (' ' +: v.init).zip(v).flatMap { + case (prev, '\\') => "" + case ('\\', c) => + c match { + case '_' => "_" + case '%' => "%" + case _ => Pattern.quote("\\" + c) + } + case (prev, c) => + c match { + case '_' => "." 
+ case '%' => ".*" + case _ => Pattern.quote(Character.toString(c)) + } + }.mkString + } else { + v + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 07b952531ec2e..3ecd0d374c46b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -191,6 +191,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create("a", StringType)), true) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create("a", StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation("abdef" like "abdef", true) checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -232,6 +241,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) checkEvaluation("abdef" rlike Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) + checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) + checkEvaluation("abdef" rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create("abdef", StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala new file mode 100644 index 0000000000000..d6f273f9e568a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.StringUtils._ + +class StringUtilsSuite extends SparkFunSuite { + + test("escapeLikeRegex") { + assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E") + assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E") + assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E") + assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E") + assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*") + assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") + assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") + } +} From a3a85d73da053c8e2830759fbc68b734081fa4f3 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Thu, 30 Jul 2015 23:50:06 -0700 Subject: [PATCH 05/43] [SPARK-9496][SQL]do not print the password in config https://issues.apache.org/jira/browse/SPARK-9496 We better do not print the password in log. Author: WangTaoTheTonic Closes #7815 from WangTaoTheTonic/master and squashes the following commits: c7a5145 [WangTaoTheTonic] do not print the password in config --- .../org/apache/spark/sql/hive/client/ClientWrapper.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 8adda54754230..6e0912da5862d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -91,7 +91,11 @@ private[hive] class ClientWrapper( // this action explicit. initialConf.setClassLoader(initClassLoader) config.foreach { case (k, v) => - logDebug(s"Hive Config: $k=$v") + if (k.toLowerCase.contains("password")) { + logDebug(s"Hive Config: $k=xxx") + } else { + logDebug(s"Hive Config: $k=$v") + } initialConf.set(k, v) } val newState = new SessionState(initialConf) From 6bba7509a932aa4d39266df2d15b1370b7aabbec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 31 Jul 2015 08:28:05 -0700 Subject: [PATCH 06/43] [SPARK-9500] add TernaryExpression to simplify ternary expressions There lots of duplicated code in ternary expressions, create a TernaryExpression for them to reduce duplicated code. cc chenghao-intel Author: Davies Liu Closes #7816 from davies/ternary and squashes the following commits: ed2bf76 [Davies Liu] add TernaryExpression --- .../sql/catalyst/expressions/Expression.scala | 85 +++++ .../expressions/codegen/CodeGenerator.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 66 +--- .../expressions/stringOperations.scala | 356 +++++------------- 4 files changed, 183 insertions(+), 326 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 8fc182607ce68..2842b3ec5a0c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -432,3 +432,88 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { private[sql] object BinaryOperator { def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } + +/** + * An expression with three inputs and one output. 
The output is by default evaluated to null + * if any input is evaluated to null. + */ +abstract class TernaryExpression extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + /** + * Default behavior of evaluation according to the default nullability of TernaryExpression. + * If a subclass of TernaryExpression overrides nullable, it should probably also override this. + */ + override def eval(input: InternalRow): Any = { + val exprs = children + val value1 = exprs(0).eval(input) + if (value1 != null) { + val value2 = exprs(1).eval(input) + if (value2 != null) { + val value3 = exprs(2).eval(input) + if (value3 != null) { + return nullSafeEval(value1, value2, value3) + } + } + } + null + } + + /** + * Called by the default [[eval]] implementation. If a subclass of TernaryExpression keeps the default + * nullability, it can override this method to save null-check code. If we need full control + * of the evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = + sys.error(s"TernaryExpressions must override either eval or nullSafeEval") + + /** + * Short hand for generating ternary evaluation code. + * If any of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts three variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { + s"${ev.primitive} = ${f(eval1, eval2, eval3)};" + }) + } + + /** + * Short hand for generating ternary evaluation code. + * If any of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f function that accepts the 3 non-null evaluation result names of children + * and returns Java code to compute the output.
+ */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val evals = children.map(_.gen(ctx)) + val resultCode = f(evals(0).primitive, evals(1).primitive, evals(2).primitive) + s""" + ${evals(0).code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${evals(0).isNull}) { + ${evals(1).code} + if (!${evals(1).isNull}) { + ${evals(2).code} + if (!${evals(2).isNull}) { + ${ev.isNull} = false; // resultCode could change nullability + $resultCode + } + } + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60e2863f7bbb0..e50ec27fc2eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -305,7 +305,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluator.cook(code) } catch { case e: Exception => - val msg = "failed to compile:\n " + CodeFormatter.format(code) + val msg = s"failed to compile: $e\n" + CodeFormatter.format(code) logError(msg, e) throw new Exception(msg, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index e6d807f6d897b..15ceb9193a8c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -165,69 +165,29 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" * @param toBaseExpr to which base */ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends Expression with ImplicitCastInputTypes { - - override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable - - override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) - override def dataType: DataType = StringType - /** Returns the result of evaluating this expression on a given input Row */ - override def eval(input: InternalRow): Any = { - val num = numExpr.eval(input) - if (num != null) { - val fromBase = fromBaseExpr.eval(input) - if (fromBase != null) { - val toBase = toBaseExpr.eval(input) - if (toBase != null) { - NumberConverter.convert( - num.asInstanceOf[UTF8String].getBytes, - fromBase.asInstanceOf[Int], - toBase.asInstanceOf[Int]) - } else { - null - } - } else { - null - } - } else { - null - } + override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { + NumberConverter.convert( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val numGen = numExpr.gen(ctx) - val from = fromBaseExpr.gen(ctx) - val to = toBaseExpr.gen(ctx) - val numconv = NumberConverter.getClass.getName.stripSuffix("$") - s""" - 
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${numGen.code} - boolean ${ev.isNull} = ${numGen.isNull}; - if (!${ev.isNull}) { - ${from.code} - if (!${from.isNull}) { - ${to.code} - if (!${to.isNull}) { - ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(), - ${from.primitive}, ${to.primitive}); - if (${ev.primitive} == null) { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } + nullSafeCodeGen(ctx, ev, (num, from, to) => + s""" + ${ev.primitive} = $numconv.convert($num.getBytes(), $from, $to); + if (${ev.primitive} == null) { + ${ev.isNull} = true; } - """ + """ + ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 99a62343f138d..684eac12bd6f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -426,15 +426,13 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) } override def children: Seq[Expression] = substr :: str :: start :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -467,60 +465,18 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. 
*/ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children.exists(_.nullable) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def eval(input: InternalRow): Any = { - val s = str.eval(input) - if (s == null) { - null - } else { - val l = len.eval(input) - if (l == null) { - null - } else { - val p = pad.eval(input) - if (p == null) { - null - } else { - val len = l.asInstanceOf[Int] - val str = s.asInstanceOf[UTF8String] - val pad = p.asInstanceOf[UTF8String] - - str.lpad(len, pad) - } - } - } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { + str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val lenGen = len.gen(ctx) - val strGen = str.gen(ctx) - val padGen = pad.gen(ctx) - - s""" - ${lenGen.code} - boolean ${ev.isNull} = ${lenGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${strGen.code} - if (!${strGen.isNull}) { - ${padGen.code} - if (!${padGen.isNull}) { - ${ev.primitive} = ${strGen.primitive}.lpad(${lenGen.primitive}, ${padGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } override def prettyName: String = "lpad" @@ -530,60 +486,18 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. 
*/ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children.exists(_.nullable) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def eval(input: InternalRow): Any = { - val s = str.eval(input) - if (s == null) { - null - } else { - val l = len.eval(input) - if (l == null) { - null - } else { - val p = pad.eval(input) - if (p == null) { - null - } else { - val len = l.asInstanceOf[Int] - val str = s.asInstanceOf[UTF8String] - val pad = p.asInstanceOf[UTF8String] - - str.rpad(len, pad) - } - } - } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { + str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val lenGen = len.gen(ctx) - val strGen = str.gen(ctx) - val padGen = pad.gen(ctx) - - s""" - ${lenGen.code} - boolean ${ev.isNull} = ${lenGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${strGen.code} - if (!${strGen.isNull}) { - ${padGen.code} - if (!${padGen.isNull}) { - ${ev.primitive} = ${strGen.primitive}.rpad(${lenGen.primitive}, ${padGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } override def prettyName: String = "rpad" @@ -745,68 +659,24 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. 
*/ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) } - override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - override def nullable: Boolean = str.nullable || pos.nullable || len.nullable - override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil - override def eval(input: InternalRow): Any = { - val stringEval = str.eval(input) - if (stringEval != null) { - val posEval = pos.eval(input) - if (posEval != null) { - val lenEval = len.eval(input) - if (lenEval != null) { - stringEval.asInstanceOf[UTF8String] - .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) - } else { - null - } - } else { - null - } - } else { - null - } + override def nullSafeEval(string: Any, pos: Any, len: Any): Any = { + string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val strGen = str.gen(ctx) - val posGen = pos.gen(ctx) - val lenGen = len.gen(ctx) - - val start = ctx.freshName("start") - val end = ctx.freshName("end") - - s""" - ${strGen.code} - boolean ${ev.isNull} = ${strGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${posGen.code} - if (!${posGen.isNull}) { - ${lenGen.code} - if (!${lenGen.isNull}) { - ${ev.primitive} = ${strGen.primitive} - .substringSQL(${posGen.primitive}, ${lenGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)") } } @@ -986,7 +856,7 @@ case class Encode(value: Expression, charset: Expression) * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { // last regex in string, we will update the pattern iff regexp value changed. 
@transient private var lastRegex: UTF8String = _ @@ -998,40 +868,26 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio // result buffer write by Matcher @transient private val result: StringBuffer = new StringBuffer - override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable - override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable - - override def eval(input: InternalRow): Any = { - val s = subject.eval(input) - if (null != s) { - val p = regexp.eval(input) - if (null != p) { - val r = rep.eval(input) - if (null != r) { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String] - pattern = Pattern.compile(lastRegex.toString) - } - if (!r.equals(lastReplacementInUTF8)) { - // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String] - lastReplacement = lastReplacementInUTF8.toString - } - val m = pattern.matcher(s.toString()) - result.delete(0, result.length()) - - while (m.find) { - m.appendReplacement(result, lastReplacement) - } - m.appendTail(result) + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) - return UTF8String.fromString(result.toString) - } - } + while (m.find) { + m.appendReplacement(result, lastReplacement) } + m.appendTail(result) - null + UTF8String.fromString(result.toString) } override def dataType: DataType = StringType @@ -1048,59 +904,43 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val termResult = ctx.freshName("result") - val classNameUTF8String = classOf[UTF8String].getCanonicalName val classNamePattern = classOf[Pattern].getCanonicalName - val classNameString = classOf[java.lang.String].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName - ctx.addMutableState(classNameUTF8String, + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - ctx.addMutableState(classNameString, + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState(classNameUTF8String, + ctx.addMutableState("UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") ctx.addMutableState(classNameStringBuffer, termResult, s"${termResult} = new $classNameStringBuffer();") - val evalSubject = subject.gen(ctx) - val evalRegexp = regexp.gen(ctx) - val evalRep = rep.gen(ctx) - + nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - ${evalSubject.code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${evalSubject.isNull}) { - ${evalRegexp.code} - if (!${evalRegexp.isNull}) { - ${evalRep.code} - if (!${evalRep.isNull}) { - if (!${evalRegexp.primitive}.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = ${evalRegexp.primitive}; - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) { - // replacement string changed - 
${termLastReplacementInUTF8} = ${evalRep.primitive}; - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); - } - ${termResult}.delete(0, ${termResult}.length()); - ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!$rep.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = $rep; + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); - while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); - } - m.appendTail(${termResult}); - ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString()); - ${ev.isNull} = false; - } - } + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); } + m.appendTail(${termResult}); + ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); + ${ev.isNull} = false; """ + }) } } @@ -1110,7 +950,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) // last regex in string, we will update the pattern iff regexp value changed. @@ -1118,32 +958,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio // last regex pattern, we cache it for performance concern @transient private var pattern: Pattern = _ - override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable - override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable - - override def eval(input: InternalRow): Any = { - val s = subject.eval(input) - if (null != s) { - val p = regexp.eval(input) - if (null != p) { - val r = idx.eval(input) - if (null != r) { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String] - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString()) - if (m.find) { - val mr: MatchResult = m.toMatchResult - return UTF8String.fromString(mr.group(r.asInstanceOf[Int])) - } - return UTF8String.EMPTY_UTF8 - } - } + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString()) + if (m.find) { + val mr: MatchResult = m.toMatchResult + UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } else { + UTF8String.EMPTY_UTF8 } - - null } override def dataType: DataType = StringType @@ -1154,44 +981,29 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") - val classNameUTF8String = classOf[UTF8String].getCanonicalName val classNamePattern = 
classOf[Pattern].getCanonicalName - ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - val evalSubject = subject.gen(ctx) - val evalRegexp = regexp.gen(ctx) - val evalIdx = idx.gen(ctx) - - s""" - ${evalSubject.code} - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - boolean ${ev.isNull} = true; - if (!${evalSubject.isNull}) { - ${evalRegexp.code} - if (!${evalRegexp.isNull}) { - ${evalIdx.code} - if (!${evalIdx.isNull}) { - if (!${evalRegexp.primitive}.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = ${evalRegexp.primitive}; - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); - if (m.find()) { - ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); - ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); - ${ev.isNull} = false; - } else { - ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8; - ${ev.isNull} = false; - } - } - } + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } - """ + java.util.regex.Matcher m = + ${termPattern}.matcher($subject.toString()); + if (m.find()) { + java.util.regex.MatchResult mr = m.toMatchResult(); + ${ev.primitive} = UTF8String.fromString(mr.group($idx)); + ${ev.isNull} = false; + } else { + ${ev.primitive} = UTF8String.EMPTY_UTF8; + ${ev.isNull} = false; + }""" + }) } } From fc0e57e5aba82a3f227fef05a843283e2ec893fc Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 31 Jul 2015 09:33:38 -0700 Subject: [PATCH 07/43] [SPARK-9053] [SPARKR] Fix spaces around parens, infix operators etc. ### JIRA [[SPARK-9053] Fix spaces around parens, infix operators etc. - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9053) ### The Result of `lint-r` [The result of lint-r at the rivision:a4c83cb1e4b066cd60264b6572fd3e51d160d26a](https://gist.github.com/yu-iskw/d253d7f8ef351f86443d) Author: Yu ISHIKAWA Closes #7584 from yu-iskw/SPARK-9053 and squashes the following commits: 613170f [Yu ISHIKAWA] Ignore a warning about a space before a left parentheses ede61e1 [Yu ISHIKAWA] Ignores two warnings about a space before a left parentheses. TODO: After updating `lintr`, we will remove the ignores de3e0db [Yu ISHIKAWA] Add '## nolint start' & '## nolint end' statement to ignore infix space warnings e233ea8 [Yu ISHIKAWA] [SPARK-9053][SparkR] Fix spaces around parens, infix operators etc. 
--- R/pkg/R/DataFrame.R | 4 ++++ R/pkg/R/RDD.R | 7 +++++-- R/pkg/R/column.R | 2 +- R/pkg/R/context.R | 2 +- R/pkg/R/pairRDD.R | 2 +- R/pkg/R/utils.R | 4 ++-- R/pkg/inst/tests/test_binary_function.R | 2 +- R/pkg/inst/tests/test_rdd.R | 6 +++--- R/pkg/inst/tests/test_sparkSQL.R | 4 +++- 9 files changed, 21 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f4c93d3c7dd67..b31ad3729e09b 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1322,9 +1322,11 @@ setMethod("write.df", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { @@ -1384,9 +1386,11 @@ setMethod("saveAsTable", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index d2d096709245d..2a013b3dbb968 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -85,7 +85,9 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) isPipelinable <- function(rdd) { e <- rdd@env + # nolint start !(e$isCached || e$isCheckpointed) + # nolint end } if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { @@ -97,7 +99,8 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) # prev_serializedMode is used during the delayed computation of JRDD in getJRDD } else { pipelinedFunc <- function(partIndex, part) { - func(partIndex, prev@func(partIndex, part)) + f <- prev@func + func(partIndex, f(partIndex, part)) } .Object@func <- cleanClosure(pipelinedFunc) .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline @@ -841,7 +844,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- rpois(1, fraction) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 2892e1416cc65..eeaf9f193b728 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -65,7 +65,7 @@ functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", "expm1", "floor", "log", "log10", "log1p", "rint", "sign", "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions<- c("atan2", "hypot") +binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 43be9c904fdf6..720990e1c6087 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -121,7 +121,7 @@ parallelize <- function(sc, coll, numSlices = 1) { numSlices <- length(coll) sliceLen <- ceiling(length(coll) / numSlices) - slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)]) # Serialize each slice: obtain a list of raws, or a list of lists (slices) of # 2-tuples of raws diff --git a/R/pkg/R/pairRDD.R 
b/R/pkg/R/pairRDD.R index 83801d3209700..199c3fd6ab1b2 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -879,7 +879,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- rpois(1, frac) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 3f45589a50443..4f9f4d9cad2a8 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -32,7 +32,7 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, } results <- if (arrSize > 0) { - lapply(0:(arrSize - 1), + lapply(0 : (arrSize - 1), function(index) { obj <- callJMethod(jList, "get", as.integer(index)) @@ -572,7 +572,7 @@ mergePartitions <- function(rdd, zip) { keys <- list() } if (lengthOfValues > 1) { - values <- part[(lengthOfKeys + 1) : (len - 1)] + values <- part[ (lengthOfKeys + 1) : (len - 1) ] } else { values <- list() } diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index dca0657c57e0d..f054ac9a87d61 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -40,7 +40,7 @@ test_that("union on two RDDs", { expect_equal(actual, c(as.list(nums), mockFile)) expect_equal(getSerializedMode(union.rdd), "byte") - rdd<- map(text.rdd, function(x) {x}) + rdd <- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 6c3aaab8c711e..71aed2bb9d6a8 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -250,7 +250,7 @@ test_that("flatMapValues() on pairwise RDDs", { expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) # Generate x to x+1 for every value - actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) @@ -293,7 +293,7 @@ test_that("sumRDD() on RDDs", { }) test_that("keyBy on RDDs", { - func <- function(x) { x*x } + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collect(keys) expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) @@ -311,7 +311,7 @@ test_that("repartition/coalesce on RDDs", { r2 <- repartition(rdd, 6) expect_equal(numPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) - expect_true(count >=0 && count <= 4) + expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 61c8a7ec7d837..aca41aa6dcf24 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -666,10 +666,12 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + ## nolint start expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) 
expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) + ## nolint end }) test_that("string operators", { @@ -876,7 +878,7 @@ test_that("parquetFile works with multiple input paths", { write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) expect_is(parquetDF, "DataFrame") - expect_equal(count(parquetDF), count(df)*2) + expect_equal(count(parquetDF), count(df) * 2) }) test_that("describe() on a DataFrame", { From 04a49edfdb606c01fa4f8ae6e730ec4f9bd0cb6d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 31 Jul 2015 09:34:10 -0700 Subject: [PATCH 08/43] [SPARK-9497] [SPARK-9509] [CORE] Use ask instead of askWithRetry `RpcEndpointRef.askWithRetry` throws `SparkException` rather than `TimeoutException`. Use ask to replace it because we don't need to retry here. Author: zsxwing Closes #7824 from zsxwing/SPARK-9497 and squashes the following commits: 7bfc2b4 [zsxwing] Use ask instead of askWithRetry --- .../scala/org/apache/spark/deploy/client/AppClient.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 79b251e7e62fe..a659abf70395d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -27,7 +27,7 @@ import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.rpc._ -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -248,7 +248,8 @@ private[spark] class AppClient( def stop() { if (endpoint != null) { try { - endpoint.askWithRetry[Boolean](StopAppClient) + val timeout = RpcUtils.askRpcTimeout(conf) + timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") From 27ae851ce16082775ffbcb5b8fc6bdbe65dc70fc Mon Sep 17 00:00:00 2001 From: tedyu Date: Fri, 31 Jul 2015 18:16:55 +0100 Subject: [PATCH 09/43] [SPARK-9446] Clear Active SparkContext in stop() method In thread 'stopped SparkContext remaining active' on mailing list, Andres observed the following in driver log: ``` 15/07/29 15:17:09 WARN YarnSchedulerBackend$YarnSchedulerEndpoint: ApplicationMaster has disassociated:
15/07/29 15:17:09 INFO YarnClientSchedulerBackend: Shutting down all executors Exception in thread "Yarn application state monitor" org.apache.spark.SparkException: Error asking standalone scheduler to shut down executors at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stopExecutors(CoarseGrainedSchedulerBackend.scala:261) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stop(CoarseGrainedSchedulerBackend.scala:266) at org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend.stop(YarnClientSchedulerBackend.scala:158) at org.apache.spark.scheduler.TaskSchedulerImpl.stop(TaskSchedulerImpl.scala:416) at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:1411) at org.apache.spark.SparkContext.stop(SparkContext.scala:1644) at org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend$$anon$1.run(YarnClientSchedulerBackend.scala:139) Caused by: java.lang.InterruptedException at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1325) at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208) at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:190) at scala.concurrent.BlockContext$DefaultBlockContext$.blockOn(BlockContext.scala:53) at scala.concurrent.Await$.result(package.scala:190)15/07/29 15:17:09 INFO YarnClientSchedulerBackend: Asking each executor to shut down at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:102) at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:78) at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stopExecutors(CoarseGrainedSchedulerBackend.scala:257) ... 6 more ``` Effect of the above exception is that a stopped SparkContext is returned to user since SparkContext.clearActiveContext() is not called. 
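The change guards each cleanup step with `Utils.tryLogNonFatalError` so a failure in one step can no longer abort the rest of stop(). As a hedged illustration only (not the actual SparkContext code), the guard-each-step pattern looks roughly like the Scala sketch below; `tryLogNonFatalError` here is a hypothetical stand-in with the same intent as the Spark utility:

```
import scala.util.control.NonFatal

object ShutdownSketch {
  // Stand-in for Utils.tryLogNonFatalError: run the block, log and swallow non-fatal errors.
  private def tryLogNonFatalError(block: => Unit): Unit = {
    try block catch {
      case NonFatal(e) => println(s"Ignoring non-fatal error during stop: $e")
    }
  }

  // Each shutdown step is guarded individually, so the final bookkeeping
  // (clearing the "active context" flag) runs even if earlier steps throw.
  def stop(steps: Seq[() => Unit], clearActiveContext: () => Unit): Unit = {
    steps.foreach(step => tryLogNonFatalError(step()))
    clearActiveContext()
  }
}
```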
Author: tedyu Closes #7756 from tedyu/master and squashes the following commits: 7339ff2 [tedyu] Move null assignment out of tryLogNonFatalError block 6e02cd9 [tedyu] Use Utils.tryLogNonFatalError to guard resource release f5fb519 [tedyu] Clear Active SparkContext in stop() method using finally --- .../scala/org/apache/spark/SparkContext.scala | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ac6ac6c216767..2d8aa25d81daa 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1689,33 +1689,57 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Utils.removeShutdownHook(_shutdownHookRef) } - postApplicationEnd() - _ui.foreach(_.stop()) + Utils.tryLogNonFatalError { + postApplicationEnd() + } + Utils.tryLogNonFatalError { + _ui.foreach(_.stop()) + } if (env != null) { - env.metricsSystem.report() + Utils.tryLogNonFatalError { + env.metricsSystem.report() + } } if (metadataCleaner != null) { - metadataCleaner.cancel() + Utils.tryLogNonFatalError { + metadataCleaner.cancel() + } + } + Utils.tryLogNonFatalError { + _cleaner.foreach(_.stop()) + } + Utils.tryLogNonFatalError { + _executorAllocationManager.foreach(_.stop()) } - _cleaner.foreach(_.stop()) - _executorAllocationManager.foreach(_.stop()) if (_dagScheduler != null) { - _dagScheduler.stop() + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } _dagScheduler = null } if (_listenerBusStarted) { - listenerBus.stop() - _listenerBusStarted = false + Utils.tryLogNonFatalError { + listenerBus.stop() + _listenerBusStarted = false + } + } + Utils.tryLogNonFatalError { + _eventLogger.foreach(_.stop()) } - _eventLogger.foreach(_.stop()) if (env != null && _heartbeatReceiver != null) { - env.rpcEnv.stop(_heartbeatReceiver) + Utils.tryLogNonFatalError { + env.rpcEnv.stop(_heartbeatReceiver) + } + } + Utils.tryLogNonFatalError { + _progressBar.foreach(_.stop()) } - _progressBar.foreach(_.stop()) _taskScheduler = null // TODO: Cache.stop()? 
if (_env != null) { - _env.stop() + Utils.tryLogNonFatalError { + _env.stop() + } SparkEnv.set(null) } SparkContext.clearActiveContext() From 0024da9157ba12ec84883a78441fa6835c1d0042 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 31 Jul 2015 11:07:34 -0700 Subject: [PATCH 10/43] [SQL] address comments for to_date/trunc This PR address the comments in #7805 cc rxin Author: Davies Liu Closes #7817 from davies/trunc and squashes the following commits: f729d5f [Davies Liu] rollback cb7f7832 [Davies Liu] genCode() is protected 31e52ef [Davies Liu] fix style ed1edc7 [Davies Liu] address comments for #7805 --- .../catalyst/expressions/datetimeFunctions.scala | 15 ++++++++------- .../spark/sql/catalyst/util/DateTimeUtils.scala | 3 ++- .../expressions/ExpressionEvalHelper.scala | 4 +--- .../scala/org/apache/spark/sql/functions.scala | 3 +++ 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 6e7613340c032..07dea5b470b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -726,15 +726,16 @@ case class TruncDate(date: Expression, format: Expression) override def dataType: DataType = DateType override def prettyName: String = "trunc" - lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + private lazy val truncLevel: Int = + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) override def eval(input: InternalRow): Any = { - val minItem = if (format.foldable) { - minItemConst + val level = if (format.foldable) { + truncLevel } else { DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) } - if (minItem == -1) { + if (level == -1) { // unknown format null } else { @@ -742,7 +743,7 @@ case class TruncDate(date: Expression, format: Expression) if (d == null) { null } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem) + DateTimeUtils.truncDate(d.asInstanceOf[Int], level) } } } @@ -751,7 +752,7 @@ case class TruncDate(date: Expression, format: Expression) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { - if (minItemConst == -1) { + if (truncLevel == -1) { s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; @@ -763,7 +764,7 @@ case class TruncDate(date: Expression, format: Expression) boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst); + ${ev.primitive} = $dtu.truncDate(${d.primitive}, $truncLevel); } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 5a7c25b8d508d..032ed8a56a50e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -794,7 +794,8 @@ object DateTimeUtils { } else if (level == TRUNC_TO_MONTH) { d - DateTimeUtils.getDayOfMonth(d) + 1 } else { - throw new Exception(s"Invalid trunc level: $level") + // caller make sure that this should never 
be reached + sys.error(s"Invalid trunc level: $level") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 3c05e5c3b833c..a41185b4d8754 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.scalactic.TripleEqualsSupport.Spread -import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 46dc4605a5ccb..5d82a5eadd94d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2192,6 +2192,9 @@ object functions { /** * Returns date truncated to the unit specified by the format. * + * @param format: 'year', 'yyyy', 'yy' for truncate by year, + * or 'month', 'mon', 'mm' for truncate by month + * * @group datetime_funcs * @since 1.5.0 */ From 6add4eddb39e7748a87da3e921ea3c7881d30a82 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Fri, 31 Jul 2015 11:22:40 -0700 Subject: [PATCH 11/43] [SPARK-9471] [ML] Multilayer Perceptron This pull request contains the following feature for ML: - Multilayer Perceptron classifier This implementation is based on our initial pull request with bgreeven: https://github.com/apache/spark/pull/1290 and inspired by very insightful suggestions from mengxr and witgo (I would like to thank all other people from the mentioned thread for useful discussions). The original code was extensively tested and benchmarked. Since then, I've addressed two main requirements that prevented the code from merging into the main branch: - Extensible interface, so it will be easy to implement new types of networks - Main building blocks are traits `Layer` and `LayerModel`. They are used for constructing layers of ANN. New layers can be added by extending the `Layer` and `LayerModel` traits. These traits are private in this release in order to save path to improve them based on community feedback - Back propagation is implemented in general form, so there is no need to change it (optimization algorithm) when new layers are implemented - Speed and scalability: this implementation has to be comparable in terms of speed to the state of the art single node implementations. - The developed benchmark for large ANN shows that the proposed code is on par with C++ CPU implementation and scales nicely with the number of workers. Details can be found here: https://github.com/avulanov/ann-benchmark - DBN and RBM by witgo https://github.com/witgo/spark/tree/ann-interface-gemm-dbn - Dropout https://github.com/avulanov/spark/tree/ann-interface-gemm mengxr and dbtsai kindly agreed to perform code review. 
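For orientation before the diff: the extension point described above boils down to the two private[ann] traits shown here in simplified form; the authoritative definitions, with full documentation, are in mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala further down in this patch. A new layer type only needs to say how to instantiate its model and how that model evaluates data, propagates deltas, and computes gradients:

```
import breeze.linalg.{DenseMatrix => BDM}
import org.apache.spark.mllib.linalg.Vector

// Describes a layer and knows how to instantiate its model,
// either from trained weights or from a random seed.
trait Layer extends Serializable {
  def getInstance(weights: Vector, position: Int): LayerModel
  def getInstance(seed: Long): LayerModel
}

// Holds the layer's parameters and implements the pieces that the
// general back-propagation loop needs from every layer.
trait LayerModel extends Serializable {
  val size: Int                                                           // number of weights
  def eval(data: BDM[Double]): BDM[Double]                                // forward pass
  def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double]  // delta for the previous layer
  def grad(delta: BDM[Double], input: BDM[Double]): Array[Double]         // gradient for this layer
  def weights(): Vector
}
```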
Author: Alexander Ulanov Author: Bert Greevenbosch Closes #7621 from avulanov/SPARK-2352-ann and squashes the following commits: 4806b6f [Alexander Ulanov] Addressing reviewers comments. a7e7951 [Alexander Ulanov] Default blockSize: 100. Added documentation to blockSize parameter and DataStacker class f69bb3d [Alexander Ulanov] Addressing reviewers comments. 374bea6 [Alexander Ulanov] Moving ANN to ML package. GradientDescent constructor is now spark private. 43b0ae2 [Alexander Ulanov] Addressing reviewers comments. Adding multiclass test. 9d18469 [Alexander Ulanov] Addressing reviewers comments: unnecessary copy of data in predict 35125ab [Alexander Ulanov] Style fix in tests e191301 [Alexander Ulanov] Apache header a226133 [Alexander Ulanov] Multilayer Perceptron regressor and classifier --- .../org/apache/spark/ml/ann/BreezeUtil.scala | 63 ++ .../scala/org/apache/spark/ml/ann/Layer.scala | 882 ++++++++++++++++++ .../MultilayerPerceptronClassifier.scala | 193 ++++ .../org/apache/spark/ml/param/params.scala | 5 + .../mllib/optimization/GradientDescent.scala | 2 +- .../org/apache/spark/ml/ann/ANNSuite.scala | 91 ++ .../MultilayerPerceptronClassifierSuite.scala | 91 ++ 7 files changed, 1326 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala new file mode 100644 index 0000000000000..7429f9d652ac5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} +import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} + +/** + * In-place DGEMM and DGEMV for Breeze + */ +private[ann] object BreezeUtil { + + // TODO: switch to MLlib BLAS interface + private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N" + + /** + * DGEMM: C := alpha * A * B + beta * C + * @param alpha alpha + * @param a A + * @param b B + * @param beta beta + * @param c C + */ + def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = { + // TODO: add code if matrices isTranspose!!! 
+ require(a.cols == b.rows, "A & B Dimension mismatch!") + require(a.rows == c.rows, "A & C Dimension mismatch!") + require(b.cols == c.cols, "A & C Dimension mismatch!") + NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols, + alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride, + beta, c.data, c.offset, c.rows) + } + + /** + * DGEMV: y := alpha * A * x + beta * y + * @param alpha alpha + * @param a A + * @param x x + * @param beta beta + * @param y y + */ + def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = { + require(a.cols == x.length, "A & b Dimension mismatch!") + NativeBLAS.dgemv(transposeString(a), a.rows, a.cols, + alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride, + beta, y.data, y.offset, y.stride) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala new file mode 100644 index 0000000000000..b5258ff348477 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, + sum => Bsum} +import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * Trait that holds Layer properties, that are needed to instantiate it. + * Implements Layer instantiation. + * + */ +private[ann] trait Layer extends Serializable { + /** + * Returns the instance of the layer based on weights provided + * @param weights vector with layer weights + * @param position position of weights in the vector + * @return the layer model + */ + def getInstance(weights: Vector, position: Int): LayerModel + + /** + * Returns the instance of the layer with random generated weights + * @param seed seed + * @return the layer model + */ + def getInstance(seed: Long): LayerModel +} + +/** + * Trait that holds Layer weights (or parameters). + * Implements functions needed for forward propagation, computing delta and gradient. + * Can return weights in Vector format. 
+ */ +private[ann] trait LayerModel extends Serializable { + /** + * number of weights + */ + val size: Int + + /** + * Evaluates the data (process the data through the layer) + * @param data data + * @return processed data + */ + def eval(data: BDM[Double]): BDM[Double] + + /** + * Computes the delta for back propagation + * @param nextDelta delta of the next layer + * @param input input data + * @return delta + */ + def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] + + /** + * Computes the gradient + * @param delta delta for this layer + * @param input input data + * @return gradient + */ + def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] + + /** + * Returns weights for the layer in a single vector + * @return layer weights + */ + def weights(): Vector +} + +/** + * Layer properties of affine transformations, that is y=A*x+b + * @param numIn number of inputs + * @param numOut number of outputs + */ +private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer { + + override def getInstance(weights: Vector, position: Int): LayerModel = { + AffineLayerModel(this, weights, position) + } + + override def getInstance(seed: Long = 11L): LayerModel = { + AffineLayerModel(this, seed) + } +} + +/** + * Model of Affine layer y=A*x+b + * @param w weights (matrix A) + * @param b bias (vector b) + */ +private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel { + val size = w.size + b.length + val gwb = new Array[Double](size) + private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb) + private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size) + private var z: BDM[Double] = null + private var d: BDM[Double] = null + private var ones: BDV[Double] = null + + override def eval(data: BDM[Double]): BDM[Double] = { + if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols) + z(::, *) := b + BreezeUtil.dgemm(1.0, w, data, 1.0, z) + z + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols) + BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d) + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = { + BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw) + if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols) + BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb) + gwb + } + + override def weights(): Vector = AffineLayerModel.roll(w, b) +} + +/** + * Fabric for Affine layer models + */ +private[ann] object AffineLayerModel { + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param weights vector with weights + * @param position position of weights in the vector + * @return model of Affine layer + */ + def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = { + val (w, b) = unroll(weights, position, layer.numIn, layer.numOut) + new AffineLayerModel(w, b) + } + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param seed seed + * @return model of Affine layer + */ + def apply(layer: AffineLayer, seed: Long): AffineLayerModel = { + val (w, b) = randomWeights(layer.numIn, layer.numOut, seed) + new AffineLayerModel(w, b) + } + + /** + * Unrolls the weights from the vector + * @param weights vector with weights + * @param position position of weights for this layer + * @param numIn number of layer inputs + * 
@param numOut number of layer outputs + * @return matrix A and vector b + */ + def unroll( + weights: Vector, + position: Int, + numIn: Int, + numOut: Int): (BDM[Double], BDV[Double]) = { + val weightsCopy = weights.toArray + // TODO: the array is not copied to BDMs, make sure this is OK! + val a = new BDM[Double](numOut, numIn, weightsCopy, position) + val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut) + (a, b) + } + + /** + * Roll the layer weights into a vector + * @param a matrix A + * @param b vector b + * @return vector of weights + */ + def roll(a: BDM[Double], b: BDV[Double]): Vector = { + val result = new Array[Double](a.size + b.length) + // TODO: make sure that we need to copy! + System.arraycopy(a.toArray, 0, result, 0, a.size) + System.arraycopy(b.toArray, 0, result, a.size, b.length) + Vectors.dense(result) + } + + /** + * Generate random weights for the layer + * @param numIn number of inputs + * @param numOut number of outputs + * @param seed seed + * @return (matrix A, vector b) + */ + def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = { + val rand: XORShiftRandom = new XORShiftRandom(seed) + val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn } + val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn } + (weights, bias) + } +} + +/** + * Trait for functions and their derivatives for functional layers + */ +private[ann] trait ActivationFunction extends Serializable { + + /** + * Implements a function + * @param x input data + * @param y output data + */ + def eval(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a derivative of a function (needed for the back propagation) + * @param x input data + * @param y output data + */ + def derivative(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a cross entropy error of a function. + * Needed if the functional layer that contains this function is the output layer + * of the network. 
+ * @param target target output + * @param output computed output + * @param result intermediate result + * @return cross-entropy + */ + def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double + + /** + * Implements a mean squared error of a function + * @param target target output + * @param output computed output + * @param result intermediate result + * @return mean squared error + */ + def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double +} + +/** + * Implements in-place application of functions + */ +private[ann] object ActivationFunction { + + def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = { + var i = 0 + while (i < x.rows) { + var j = 0 + while (j < x.cols) { + y(i, j) = func(x(i, j)) + j += 1 + } + i += 1 + } + } + + def apply( + x1: BDM[Double], + x2: BDM[Double], + y: BDM[Double], + func: (Double, Double) => Double): Unit = { + var i = 0 + while (i < x1.rows) { + var j = 0 + while (j < x1.cols) { + y(i, j) = func(x1(i, j), x2(i, j)) + j += 1 + } + i += 1 + } + } +} + +/** + * Implements SoftMax activation function + */ +private[ann] class SoftmaxFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + var j = 0 + // find max value to make sure later that exponent is computable + while (j < x.cols) { + var i = 0 + var max = Double.MinValue + while (i < x.rows) { + if (x(i, j) > max) { + max = x(i, j) + } + i += 1 + } + var sum = 0.0 + i = 0 + while (i < x.rows) { + val res = Math.exp(x(i, j) - max) + y(i, j) = res + sum += res + i += 1 + } + i = 0 + while (i < x.rows) { + y(i, j) /= sum + i += 1 + } + j += 1 + } + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum( target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.") + } +} + +/** + * Implements Sigmoid activation function + */ +private[ann] class SigmoidFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + def s(z: Double): Double = Bsigmoid(z) + ActivationFunction(x, y, s) + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum(target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + // TODO: make it readable + def m(o: Double, t: Double): Double = (o - t) + ActivationFunction(output, target, result, m) + val e = Bsum(result :* result) / 2 / output.cols + def m2(x: Double, o: Double) = x * (o - o * o) + ActivationFunction(result, output, result, m2) + e + } +} + +/** + * Functional layer properties, y = f(x) + * @param activationFunction activation function + */ +private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer { + override def 
getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L) + + override def getInstance(seed: Long): LayerModel = + FunctionalLayerModel(this) +} + +/** + * Functional layer model. Holds no weights. + * @param activationFunction activation function + */ +private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction) + extends LayerModel { + val size = 0 + // matrices for in-place computations + // outputs + private var f: BDM[Double] = null + // delta + private var d: BDM[Double] = null + // matrix for error computation + private var e: BDM[Double] = null + // delta gradient + private lazy val dg = new Array[Double](0) + + override def eval(data: BDM[Double]): BDM[Double] = { + if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols) + activationFunction.eval(data, f) + f + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols) + activationFunction.derivative(input, d) + d :*= nextDelta + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg + + override def weights(): Vector = Vectors.dense(new Array[Double](0)) + + def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.crossEntropy(output, target, e) + (e, error) + } + + def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.squared(output, target, e) + (e, error) + } + + def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + // TODO: allow user pick error + activationFunction match { + case sigmoid: SigmoidFunction => squared(output, target) + case softmax: SoftmaxFunction => crossEntropy(output, target) + } + } +} + +/** + * Fabric of functional layer models + */ +private[ann] object FunctionalLayerModel { + def apply(layer: FunctionalLayer): FunctionalLayerModel = + new FunctionalLayerModel(layer.activationFunction) +} + +/** + * Trait for the artificial neural network (ANN) topology properties + */ +private[ann] trait Topology extends Serializable{ + def getInstance(weights: Vector): TopologyModel + def getInstance(seed: Long): TopologyModel +} + +/** + * Trait for ANN topology model + */ +private[ann] trait TopologyModel extends Serializable{ + /** + * Forward propagation + * @param data input data + * @return array of outputs for each of the layers + */ + def forward(data: BDM[Double]): Array[BDM[Double]] + + /** + * Prediction of the model + * @param data input data + * @return prediction + */ + def predict(data: Vector): Vector + + /** + * Computes gradient for the network + * @param data input data + * @param target target output + * @param cumGradient cumulative gradient + * @param blockSize block size + * @return error + */ + def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, + blockSize: Int): Double + + /** + * Returns the weights of the ANN + * @return weights + */ + def weights(): Vector +} + +/** + * Feed forward ANN + * @param layers + */ +private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { + override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) + + override def 
getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) +} + +/** + * Factory for some of the frequently-used topologies + */ +private[ml] object FeedForwardTopology { + /** + * Creates a feed forward topology from the array of layers + * @param layers array of layers + * @return feed forward topology + */ + def apply(layers: Array[Layer]): FeedForwardTopology = { + new FeedForwardTopology(layers) + } + + /** + * Creates a multi-layer perceptron + * @param layerSizes sizes of layers including input and output size + * @param softmax wether to use SoftMax or Sigmoid function for an output layer. + * Softmax is default + * @return multilayer perceptron topology + */ + def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = { + val layers = new Array[Layer]((layerSizes.length - 1) * 2) + for(i <- 0 until layerSizes.length - 1){ + layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1)) + layers(i * 2 + 1) = + if (softmax && i == layerSizes.length - 2) { + new FunctionalLayer(new SoftmaxFunction()) + } else { + new FunctionalLayer(new SigmoidFunction()) + } + } + FeedForwardTopology(layers) + } +} + +/** + * Model of Feed Forward Neural Network. + * Implements forward, gradient computation and can return weights in vector format. + * @param layerModels models of layers + * @param topology topology of the network + */ +private[ml] class FeedForwardModel private( + val layerModels: Array[LayerModel], + val topology: FeedForwardTopology) extends TopologyModel { + override def forward(data: BDM[Double]): Array[BDM[Double]] = { + val outputs = new Array[BDM[Double]](layerModels.length) + outputs(0) = layerModels(0).eval(data) + for (i <- 1 until layerModels.length) { + outputs(i) = layerModels(i).eval(outputs(i-1)) + } + outputs + } + + override def computeGradient( + data: BDM[Double], + target: BDM[Double], + cumGradient: Vector, + realBatchSize: Int): Double = { + val outputs = forward(data) + val deltas = new Array[BDM[Double]](layerModels.length) + val L = layerModels.length - 1 + val (newE, newError) = layerModels.last match { + case flm: FunctionalLayerModel => flm.error(outputs.last, target) + case _ => + throw new UnsupportedOperationException("Non-functional layer not supported at the top") + } + deltas(L) = new BDM[Double](0, 0) + deltas(L - 1) = newE + for (i <- (L - 2) to (0, -1)) { + deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1)) + } + val grads = new Array[Array[Double]](layerModels.length) + for (i <- 0 until layerModels.length) { + val input = if (i==0) data else outputs(i - 1) + grads(i) = layerModels(i).grad(deltas(i), input) + } + // update cumGradient + val cumGradientArray = cumGradient.toArray + var offset = 0 + // TODO: extract roll + for (i <- 0 until grads.length) { + val gradArray = grads(i) + var k = 0 + while (k < gradArray.length) { + cumGradientArray(offset + k) += gradArray(k) + k += 1 + } + offset += gradArray.length + } + newError + } + + // TODO: do we really need to copy the weights? 
they should be read-only + override def weights(): Vector = { + // TODO: extract roll + var size = 0 + for (i <- 0 until layerModels.length) { + size += layerModels(i).size + } + val array = new Array[Double](size) + var offset = 0 + for (i <- 0 until layerModels.length) { + val layerWeights = layerModels(i).weights().toArray + System.arraycopy(layerWeights, 0, array, offset, layerWeights.length) + offset += layerWeights.length + } + Vectors.dense(array) + } + + override def predict(data: Vector): Vector = { + val size = data.size + val result = forward(new BDM[Double](size, 1, data.toArray)) + Vectors.dense(result.last.toArray) + } +} + +/** + * Fabric for feed forward ANN models + */ +private[ann] object FeedForwardModel { + + /** + * Creates a model from a topology and weights + * @param topology topology + * @param weights weights + * @return model + */ + def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for (i <- 0 until layers.length) { + layerModels(i) = layers(i).getInstance(weights, offset) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } + + /** + * Creates a model given a topology and seed + * @param topology topology + * @param seed seed for generating the weights + * @return model + */ + def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + var offset = 0 + for(i <- 0 until layers.length){ + layerModels(i) = layers(i).getInstance(seed) + offset += layerModels(i).size + } + new FeedForwardModel(layerModels, topology) + } +} + +/** + * Neural network gradient. Does nothing but calling Model's gradient + * @param topology topology + * @param dataStacker data stacker + */ +private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient { + + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } + + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val (input, target, realBatchSize) = dataStacker.unstack(data) + val model = topology.getInstance(weights) + model.computeGradient(input, target, cumGradient, realBatchSize) + } +} + +/** + * Stacks pairs of training samples (input, output) in one vector allowing them to pass + * through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks + * or matrices of inputs and outputs and then stack them in one vector. + * This can be used for further batch computations after unstacking. 
+ * @param stackSize stack size + * @param inputSize size of the input vectors + * @param outputSize size of the output vectors + */ +private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) + extends Serializable { + + /** + * Stacks the data + * @param data RDD of vector pairs + * @return RDD of double (always zero) and vector that contains the stacked vectors + */ + def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = { + val stackedData = if (stackSize == 1) { + data.map { v => + (0.0, + Vectors.fromBreeze(BDV.vertcat( + v._1.toBreeze.toDenseVector, + v._2.toBreeze.toDenseVector)) + ) } + } else { + data.mapPartitions { it => + it.grouped(stackSize).map { seq => + val size = seq.size + val bigVector = new Array[Double](inputSize * size + outputSize * size) + var i = 0 + seq.foreach { case (in, out) => + System.arraycopy(in.toArray, 0, bigVector, i * inputSize, inputSize) + System.arraycopy(out.toArray, 0, bigVector, + inputSize * size + i * outputSize, outputSize) + i += 1 + } + (0.0, Vectors.dense(bigVector)) + } + } + } + stackedData + } + + /** + * Unstack the stacked vectors into matrices for batch operations + * @param data stacked vector + * @return pair of matrices holding input and output data and the real stack size + */ + def unstack(data: Vector): (BDM[Double], BDM[Double], Int) = { + val arrData = data.toArray + val realStackSize = arrData.length / (inputSize + outputSize) + val input = new BDM(inputSize, realStackSize, arrData) + val target = new BDM(outputSize, realStackSize, arrData, inputSize * realStackSize) + (input, target, realStackSize) + } +} + +/** + * Simple updater + */ +private[ann] class ANNUpdater extends Updater { + + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { + val thisIterStepSize = stepSize + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + (Vectors.fromBreeze(brzWeights), 0) + } +} + +/** + * MLlib-style trainer class that trains a network given the data and topology + * @param topology topology of ANN + * @param inputSize input size + * @param outputSize output size + */ +private[ml] class FeedForwardTrainer( + topology: Topology, + val inputSize: Int, + val outputSize: Int) extends Serializable { + + // TODO: what if we need to pass random seed? 
+ private var _weights = topology.getInstance(11L).weights() + private var _stackSize = 128 + private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) + private var _gradient: Gradient = new ANNGradient(topology, dataStacker) + private var _updater: Updater = new ANNUpdater() + private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100) + + /** + * Returns weights + * @return weights + */ + def getWeights: Vector = _weights + + /** + * Sets weights + * @param value weights + * @return trainer + */ + def setWeights(value: Vector): FeedForwardTrainer = { + _weights = value + this + } + + /** + * Sets the stack size + * @param value stack size + * @return trainer + */ + def setStackSize(value: Int): FeedForwardTrainer = { + _stackSize = value + dataStacker = new DataStacker(value, inputSize, outputSize) + this + } + + /** + * Sets the SGD optimizer + * @return SGD optimizer + */ + def SGDOptimizer: GradientDescent = { + val sgd = new GradientDescent(_gradient, _updater) + optimizer = sgd + sgd + } + + /** + * Sets the LBFGS optimizer + * @return LBGS optimizer + */ + def LBFGSOptimizer: LBFGS = { + val lbfgs = new LBFGS(_gradient, _updater) + optimizer = lbfgs + lbfgs + } + + /** + * Sets the updater + * @param value updater + * @return trainer + */ + def setUpdater(value: Updater): FeedForwardTrainer = { + _updater = value + updateUpdater(value) + this + } + + /** + * Sets the gradient + * @param value gradient + * @return trainer + */ + def setGradient(value: Gradient): FeedForwardTrainer = { + _gradient = value + updateGradient(value) + this + } + + private[this] def updateGradient(gradient: Gradient): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setGradient(gradient) + case sgd: GradientDescent => sgd.setGradient(gradient) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + private[this] def updateUpdater(updater: Updater): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setUpdater(updater) + case sgd: GradientDescent => sgd.setUpdater(updater) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + /** + * Trains the ANN + * @param data RDD of input and output vector pairs + * @return model + */ + def train(data: RDD[(Vector, Vector)]): TopologyModel = { + val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights) + topology.getInstance(newWeights) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala new file mode 100644 index 0000000000000..8cd2103d7d5e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.DataFrame + +/** Params for Multilayer Perceptron. */ +private[ml] trait MultilayerPerceptronParams extends PredictorParams + with HasSeed with HasMaxIter with HasTol { + /** + * Layer sizes including input size and output size. + * @group param + */ + final val layers: IntArrayParam = new IntArrayParam(this, "layers", + "Sizes of layers from input layer to output layer" + + " E.g., Array(780, 100, 10) means 780 inputs, " + + "one hidden layer with 100 neurons and output layer of 10 neurons.", + // TODO: how to check ALSO that all elements are greater than 0? + ParamValidators.arrayLengthGt(1) + ) + + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group getParam */ + final def getLayers: Array[Int] = $(layers) + + /** + * Block size for stacking input data in matrices to speed up the computation. + * Data is stacked within partitions. If block size is more than remaining data in + * a partition then it is adjusted to the size of this data. + * Recommended size is between 10 and 1000. + * @group expertParam + */ + final val blockSize: IntParam = new IntParam(this, "blockSize", + "Block size for stacking input data in matrices. Data is stacked within partitions." + + " If block size is more than remaining data in a partition then " + + "it is adjusted to the size of this data. Recommended size is between 10 and 1000", + ParamValidators.gt(0)) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** @group getParam */ + final def getBlockSize: Int = $(blockSize) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + + setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) +} + +/** Label to vector converter. */ +private object LabelConverter { + // TODO: Use OneHotEncoder instead + /** + * Encodes a label as a vector. + * Returns a vector of given length with zeroes at all positions + * and value 1.0 at the position that corresponds to the label. 
+ * + * @param labeledPoint labeled point + * @param labelCount total number of labels + * @return pair of features and vector encoding of a label + */ + def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { + val output = Array.fill(labelCount)(0.0) + output(labeledPoint.label.toInt) = 1.0 + (labeledPoint.features, Vectors.dense(output)) + } + + /** + * Converts a vector to a label. + * Returns the position of the maximal element of a vector. + * + * @param output label encoded with a vector + * @return label + */ + def decodeLabel(output: Vector): Double = { + output.argmax.toDouble + } +} + +/** + * :: Experimental :: + * Classifier trainer based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * Number of inputs has to be equal to the size of feature vectors. + * Number of outputs has to be equal to the total number of labels. + * + */ +@Experimental +class MultilayerPerceptronClassifier(override val uid: String) + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + with MultilayerPerceptronParams { + + def this() = this(Identifiable.randomUID("mlpc")) + + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) + + /** + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @return Fitted model + */ + override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + val myLayers = $(layers) + val labels = myLayers.last + val lpData = extractLabeledPoints(dataset) + val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) + val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) + val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) + FeedForwardTrainer.setStackSize($(blockSize)) + val mlpModel = FeedForwardTrainer.train(data) + new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + } +} + +/** + * :: Experimental :: + * Classifier model based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * @param uid uid + * @param layers array of layer sizes including input and output layers + * @param weights vector of initial weights for the model that consists of the weights of layers + * @return prediction model + */ +@Experimental +class MultilayerPerceptronClassifierModel private[ml] ( + override val uid: String, + layers: Array[Int], + weights: Vector) + extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + with Serializable { + + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
+ */ + override protected def predict(features: Vector): Double = { + LabelConverter.decodeLabel(mlpModel.predict(features)) + } + + override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { + copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 954aa17e26a02..d68f5ff0053c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -166,6 +166,11 @@ object ParamValidators { def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) => allowed.contains(value) } + + /** Check that the array length is greater than lowerBound. */ + def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => + value.length > lowerBound + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index ab7611fd077ef..8f0d1e4aa010a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * @param gradient Gradient function to be used. * @param updater Updater to be used to update weights after every iteration. */ -class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater) +class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var stepSize: Double = 1.0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala new file mode 100644 index 0000000000000..1292e57d7c01a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.ann + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + + +class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { + + // TODO: test for weights comparison with Weka MLP + test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array(0.0, 1.0, 1.0, 0.0) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 1) + trainer.setWeights(initialWeights) + trainer.LBFGSOptimizer.setNumIterations(20) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input)(0), label(0)) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(math.round(p) === l) + } + } + + test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array( + Array(1.0, 0.0), + Array(0.0, 1.0), + Array(0.0, 1.0), + Array(1.0, 0.0) + ) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 2) + trainer.SGDOptimizer.setNumIterations(2000) + trainer.setWeights(initialWeights) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input), label) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(p ~== l absTol 0.5) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala new file mode 100644 index 0000000000000..ddc948f65df45 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("XOR function learning as binary classification problem with two outputs.") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100) + val model = trainer.fit(dataFrame) + val result = model.transform(dataFrame) + val predictionAndLabels = result.select("prediction", "label").collect() + predictionAndLabels.foreach { case Row(p: Double, l: Double) => + assert(p == l) + } + } + + // TODO: implement a more rigorous test + test("3 class classification with 2 hidden layers") { + val nPoints = 1000 + + // The following weights are taken from OneVsRestSuite.scala + // they represent 3-class iris dataset + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") + val numClasses = 3 + val numIterations = 100 + val layers = Array[Int](4, 5, 4, numClasses) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(numIterations) + val model = trainer.fit(dataFrame) + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") + .map { case Row(p: Double, l: Double) => (p, l) } + // train multinomial logistic regression + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(true) + .setNumClasses(numClasses) + lr.optimizer.setRegParam(0.0) + .setNumIterations(numIterations) + val lrModel = lr.run(rdd) + val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // MLP's predictions should not differ a lot from LR's. 
+ val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) + val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) + assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) + } +} From 4011a947154d97a9ffb5a71f077481a12534d36b Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 31 Jul 2015 11:50:15 -0700 Subject: [PATCH 12/43] [SPARK-9231] [MLLIB] DistributedLDAModel method for top topics per document jira: https://issues.apache.org/jira/browse/SPARK-9231 Helper method in DistributedLDAModel of this form: ``` /** * For each document, return the top k weighted topics for that document. * return RDD of (doc ID, topic indices, topic weights) */ def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] ``` Author: Yuhao Yang Closes #7785 from hhbyyh/topTopicsPerdoc and squashes the following commits: 30ad153 [Yuhao Yang] small fix fd24580 [Yuhao Yang] add topTopics per document to DistributedLDAModel --- .../spark/mllib/clustering/LDAModel.scala | 19 ++++++++++++++++++- .../spark/mllib/clustering/LDASuite.scala | 13 ++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6cfad3fbbdb87..82281a0daf008 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -591,6 +591,23 @@ class DistributedLDAModel private[clustering] ( JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) } + /** + * For each document, return the top k weighted topics for that document and their weights. + * @return RDD of (doc ID, topic indices, topic weights) + */ + def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + val topIndices = argtopk(topicCounts, k) + val sumCounts = sum(topicCounts) + val weights = if (sumCounts != 0) { + topicCounts(topIndices) / sumCounts + } else { + topicCounts(topIndices) + } + (docID.toLong, topIndices.toArray, weights.toArray) + } + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index c43e1e575c09c..695ee3b82efc5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, max, argmax} +import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.Edge @@ -108,6 +108,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5) } + val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3))) + model.topicDistributions.join(top2TopicsPerDoc).collect().foreach { + case (docId, (topicDistribution, (indices, weights))) => + assert(indices.length == 2) + assert(weights.length == 2) + val bdvTopicDist = topicDistribution.toBreeze + val top2Indices = argtopk(bdvTopicDist, 2) + assert(top2Indices.toArray === indices) + assert(bdvTopicDist(top2Indices).toArray === weights) + } + // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) From e8bdcdeabb2df139a656f86686cdb53c891b1f4b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 31 Jul 2015 11:56:52 -0700 Subject: [PATCH 13/43] [SPARK-6885] [ML] decision tree support predict class probabilities Decision tree supports predicting class probabilities. Implement the prediction probabilities function, referring to the old DecisionTree API and the [sklearn API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593). I make DecisionTreeClassificationModel inherit from ProbabilisticClassificationModel, make predictRaw return the raw counts vector, and make raw2probabilityInPlace/predictProbability return the probabilities for each prediction.
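For illustration, a minimal usage sketch of the new behavior (not part of the patch; the DataFrame `train` with "features"/"label" columns and label metadata carrying the number of classes is assumed):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Assumed setup: `train` is a DataFrame with "features" and "label" columns whose
// label column already has class-count metadata attached (e.g. via an indexer).
val dt = new DecisionTreeClassifier()
  .setMaxDepth(4)
  .setImpurity("gini")
val model = dt.fit(train)

// With this change the model is probabilistic, so transform() can emit
// rawPrediction (per-class counts at the predicted leaf) and probability
// (the same counts normalized to sum to 1) alongside prediction.
model.transform(train)
  .select("prediction", "rawPrediction", "probability")
  .show(5)
```

The expected relationship is that `probability` equals `rawPrediction` divided by the sum of its entries, and `prediction` is the argmax of either vector.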
Author: Yanbo Liang Closes #7694 from yanboliang/spark-6885 and squashes the following commits: 08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue 2174278 [Yanbo Liang] solve merge conflicts 7e90ba8 [Yanbo Liang] fix typos 33ae183 [Yanbo Liang] fix annotation ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again 6167fb0 [Yanbo Liang] optimize calculateImpurityStats function fbbe2ec [Yanbo Liang] eliminate duplicated struct and code beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode 99e8943 [Yanbo Liang] code optimization 5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats 227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator d746ffc [Yanbo Liang] decision tree support predict class probabilities --- .../DecisionTreeClassifier.scala | 40 ++++-- .../ml/classification/GBTClassifier.scala | 2 +- .../RandomForestClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../scala/org/apache/spark/ml/tree/Node.scala | 80 ++++++----- .../spark/ml/tree/impl/RandomForest.scala | 126 ++++++++---------- .../spark/mllib/tree/impurity/Entropy.scala | 2 +- .../spark/mllib/tree/impurity/Gini.scala | 2 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 2 +- .../tree/model/InformationGainStats.scala | 61 ++++++++- .../DecisionTreeClassifierSuite.scala | 30 ++++- .../classification/GBTClassifierSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 2 +- 16 files changed, 229 insertions(+), 130 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 36fe1bd40469c..f27cfd0331419 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,12 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} @@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame */ @Experimental final class DecisionTreeClassifier(override val uid: String) - extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("dtc")) @@ -106,8 +105,9 @@ object DecisionTreeClassifier { @Experimental final class DecisionTreeClassificationModel private[ml] ( override val uid: String, - override val 
rootNode: Node) - extends PredictionModel[Vector, DecisionTreeClassificationModel] + override val rootNode: Node, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { require(rootNode != null, @@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode) + def this(rootNode: Node, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numClasses) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction + } + + override protected def predictRaw(features: Vector): Vector = { + Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone()) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val sum = dv.values.sum + while (i < size) { + dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0 + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) } override def toString: String = { @@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel { s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") - new DecisionTreeClassificationModel(uid, rootNode) + new DecisionTreeClassificationModel(uid, rootNode, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index eb0b1a0a405fc..c3891a9599262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -190,7 +190,7 @@ final class GBTClassificationModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bc19bd6df894f..0c7eb4a662fdb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] ( // Ignore the weights since all are 1.0 for now. 
val votes = new Array[Double](numClasses) _trees.view.foreach { tree => - val prediction = tree.rootNode.predict(features).toInt + val prediction = tree.rootNode.predictImpl(features).prediction.toInt votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } Vectors.dense(votes) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6f3340c2f02be..4d30e4b5548aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] ( def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index e38dc73ee0ba7..5633bc320273a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -180,7 +180,7 @@ final class GBTRegressionModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 506a878c2553b..17fb1ad5e15d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. 
- _trees.map(_.rootNode.predict(features)).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees } override def copy(extra: ParamMap): RandomForestRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index bbc2427ca7d3d..8879352a600a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,9 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict} + Node => OldNode, Predict => OldPredict, ImpurityStats} /** * :: DeveloperApi :: @@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable { /** Impurity measure at this node (for training data) */ def impurity: Double + /** + * Statistics aggregated from training data at this node, used to compute prediction, impurity, + * and probabilities. + * For classification, the array of class counts must be normalized to a probability distribution. + */ + private[tree] def impurityStats: ImpurityCalculator + /** Recursive prediction helper method */ - private[ml] def predict(features: Vector): Double = prediction + private[ml] def predictImpl(features: Vector): LeafNode /** * Get the number of nodes in tree below this node, including leaf nodes. @@ -75,7 +83,8 @@ private[ml] object Node { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + new LeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain @@ -85,7 +94,7 @@ private[ml] object Node { new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) } } } @@ -99,11 +108,13 @@ private[ml] object Node { @DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, - override val impurity: Double) extends Node { + override val impurity: Double, + override val impurityStats: ImpurityCalculator) extends Node { - override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + override def toString: String = + s"LeafNode(prediction = $prediction, impurity = $impurity)" - override private[ml] def predict(features: Vector): Double = prediction + override private[ml] def predictImpl(features: Vector): LeafNode = this override private[tree] def numDescendants: Int = 0 @@ -115,9 +126,8 @@ final class LeafNode private[ml] ( override private[tree] def subtreeDepth: Int = 0 override private[ml] def toOld(id: Int): OldNode = { - // NOTE: We do NOT store 'prob' in the new API currently. 
- new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, - None, None, None, None) + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), + impurity, isLeaf = true, None, None, None, None) } } @@ -139,17 +149,18 @@ final class InternalNode private[ml] ( val gain: Double, val leftChild: Node, val rightChild: Node, - val split: Split) extends Node { + val split: Split, + override val impurityStats: ImpurityCalculator) extends Node { override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" } - override private[ml] def predict(features: Vector): Double = { + override private[ml] def predictImpl(features: Vector): LeafNode = { if (split.shouldGoLeft(features)) { - leftChild.predict(features) + leftChild.predictImpl(features) } else { - rightChild.predict(features) + rightChild.predictImpl(features) } } @@ -172,9 +183,8 @@ final class InternalNode private[ml] ( override private[ml] def toOld(id: Int): OldNode = { assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + " since the old API does not support deep trees.") - // NOTE: We do NOT store 'prob' in the new API currently. - new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, - Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, + isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), Some(rightChild.toOld(OldNode.rightChildIndex(id))), Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, new OldPredict(leftChild.prediction, prob = 0.0), @@ -223,36 +233,36 @@ private object InternalNode { * * @param id We currently use the same indexing as the old implementation in * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. - * @param predictionStats Predicted label + class probability (for classification). - * We will later modify this to store aggregate statistics for labels - * to provide all class probabilities (for classification) and maybe a - * distribution (for regression). * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, * so that we do not need to consider splitting it further. - * @param stats Old structure for storing stats about information gain, prediction, etc. - * This is legacy and will be modified in the future. + * @param stats Impurity statistics for this node. */ private[tree] class LearningNode( var id: Int, - var predictionStats: OldPredict, - var impurity: Double, var leftChild: Option[LearningNode], var rightChild: Option[LearningNode], var split: Option[Split], var isLeaf: Boolean, - var stats: Option[OldInformationGainStats]) extends Serializable { + var stats: ImpurityStats) extends Serializable { /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ def toNode: Node = { if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty, + assert(rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. 
Could not convert LearningNode to Node.") - new InternalNode(predictionStats.predict, impurity, stats.get.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get) + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) } else { - new LeafNode(predictionStats.predict, impurity) + if (stats.valid) { + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) + } else { + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + } + } } @@ -263,16 +273,14 @@ private[tree] object LearningNode { /** Create a node with some of its fields set. */ def apply( id: Int, - predictionStats: OldPredict, - impurity: Double, - isLeaf: Boolean): LearningNode = { - new LearningNode(id, predictionStats, impurity, None, None, None, false, None) + isLeaf: Boolean, + stats: ImpurityStats): LearningNode = { + new LearningNode(id, None, None, None, false, stats) } /** Create an empty node with the given node index. Values must be set later on. */ def emptyNode(nodeIndex: Int): LearningNode = { - new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN, - None, None, None, false, None) + new LearningNode(nodeIndex, None, None, None, false, null) } // The below indexing methods were copied from spark.mllib.tree.model.Node diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 15b56bd844bad..a8b90d9d266a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict} +import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging { parentUID match { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) } case None => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) } @@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging { } // find best split for each node - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats, predict)) + (nodeIndex, (split, stats)) }.collectAsMap() 
timer.stop("chooseSplits") @@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging { val nodeIndex = node.id val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = nodeToBestSplits(aggNodeIndex) logDebug("best split = " + split) // Extract info for this node. Create children if not leaf. val isLeaf = (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.predictionStats = predict node.isLeaf = isLeaf - node.stats = Some(stats) - node.impurity = stats.impurity + node.stats = stats logDebug("Node = " + node) if (!isLeaf) { @@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging { val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), - stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), - stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) if (nodeIdCache.nonEmpty) { val nodeIndexUpdater = NodeIndexUpdater( @@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging { } /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * Calculate the impurity statistics for a give (feature, split) based upon left/right aggregates. + * @param stats the recycle impurity statistics for this feature's all splits, + * only 'impurity' and 'impurityCalculator' are valid between each iteration * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for split + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) */ - private def calculateGainForSplit( + private def calculateImpurityStats( + stats: ImpurityStats, leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata, - impurity: Double): InformationGainStats = { + metadata: DecisionTreeMetadata): ImpurityStats = { + + val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { + leftImpurityCalculator.copy.add(rightImpurityCalculator) + } else { + stats.impurityCalculator + } + + val impurity: Double = if (stats == null) { + parentImpurityCalculator.calculate() + } else { + stats.impurity + } + val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count + val totalCount = leftCount + rightCount + // If left child or right child doesn't satisfy minimum instances per node, // then this split is invalid, return invalid information gain stats. 
if ((leftCount < metadata.minInstancesPerNode) || (rightCount < metadata.minInstancesPerNode)) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - val totalCount = leftCount + rightCount - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging { // if information gain doesn't satisfy minimum information gain, // then this split is invalid, return invalid information gain stats. if (gain < metadata.minInfoGain) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - // calculate left and right predict - val leftPredict = calculatePredict(leftImpurityCalculator) - val rightPredict = calculatePredict(rightImpurityCalculator) - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, - leftPredict, rightPredict) - } - - private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { - val predict = impurityCalculator.predict - val prob = impurityCalculator.prob(predict) - new Predict(predict, prob) - } - - /** - * Calculate predict value for current node, given stats of any split. - * Note that this function is called only once for each node. - * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node - */ - private def calculatePredictImpurity( - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - val predict = calculatePredict(parentNodeAgg) - val impurity = parentNodeAgg.calculate() - - (predict, impurity) + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) } /** @@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging { binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], - node: LearningNode): (Split, InformationGainStats, Predict) = { + node: LearningNode): (Split, ImpurityStats) = { - // Calculate prediction and impurity if current node is top node + // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) { - None + var gainAndImpurityStats: ImpurityStats = if (level ==0) { + null } else { - Some((node.predictionStats, node.impurity)) + node.stats } // For each (feature, split), calculate the gain, and select the best (feature, split). 
@@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIdx, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIdx, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { @@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging { val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { @@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) @@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging { } }.maxBy(_._2.gain) - (bestSplit, bestSplitStats, predictionAndImpurity.get._1) + (bestSplit, bestSplitStats) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 5ac10f3fd32dd..0768204c33914 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 19d318203c344..d0077db6832e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 578749d85a4e6..86cee7e430b0a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) { +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 7104a7fa4dd4c..04d0cd24e6632 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -98,7 +98,7 @@ private[tree] class VarianceAggregator() * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). 
 */
-private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
 
   require(stats.size == 3,
     s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index dc9e0f9f51ffb..508bf9c1bdb47 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 
 /**
  * :: DeveloperApi ::
@@ -66,7 +67,6 @@ class InformationGainStats(
   }
 }
 
-
 private[spark] object InformationGainStats {
   /**
    * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
@@ -76,3 +76,62 @@ private[spark] object InformationGainStats {
   val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
     new Predict(0.0, 0.0), new Predict(0.0, 0.0))
 }
+
+/**
+ * :: DeveloperApi ::
+ * Impurity statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param impurityCalculator impurity statistics for current node
+ * @param leftImpurityCalculator impurity statistics for left child node
+ * @param rightImpurityCalculator impurity statistics for right child node
+ * @param valid whether the current split satisfies minimum info gain or
+ *              minimum number of instances per node
+ */
+@DeveloperApi
+private[spark] class ImpurityStats(
+    val gain: Double,
+    val impurity: Double,
+    val impurityCalculator: ImpurityCalculator,
+    val leftImpurityCalculator: ImpurityCalculator,
+    val rightImpurityCalculator: ImpurityCalculator,
+    val valid: Boolean = true) extends Serializable {
+
+  override def toString: String = {
+    s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " +
+      s"right impurity = $rightImpurity"
+  }
+
+  def leftImpurity: Double = if (leftImpurityCalculator != null) {
+    leftImpurityCalculator.calculate()
+  } else {
+    -1.0
+  }
+
+  def rightImpurity: Double = if (rightImpurityCalculator != null) {
+    rightImpurityCalculator.calculate()
+  } else {
+    -1.0
+  }
+}
+
+private[spark] object ImpurityStats {
+
+  /**
+   * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to
+   * denote that the current split doesn't satisfy minimum info gain or
+   * the minimum number of instances per node.
+   */
+  def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
+    new ImpurityStats(Double.MinValue, impurityCalculator.calculate(),
+      impurityCalculator, null, null, false)
+  }
+
+  /**
+   * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object
+   * for which only 'impurity' and 'impurityCalculator' are defined.
+ */ + def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 73b4805c4c597..c7bbf1ce07a23 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2) ParamsSuite.checkParams(model) } @@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) } + test("predictRaw and predictProbability") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + val predictions = newTree.transform(newData) + .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index a7bc77965fefd..d4b5896c12c06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new 
DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))), Array(1.0)) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ab711c8e4b215..dbb2577c6204d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) ParamsSuite.checkParams(model) } From 0a1d2ca42c8b31d6b0e70163795f0185d4622f87 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 31 Jul 2015 12:04:03 -0700 Subject: [PATCH 14/43] [SPARK-8979] Add a PID based rate estimator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on #7600 /cc tdas Author: Iulian Dragos Author: François Garillot Closes #7648 from dragos/topic/streaming-bp/pid and squashes the following commits: aa5b097 [Iulian Dragos] Add more comments, made all PID constant parameters positive, a couple more tests. 93b74f8 [Iulian Dragos] Better explanation of historicalError. 7975b0c [Iulian Dragos] Add configuration for PID. 26cfd78 [Iulian Dragos] A couple of variable renames. d0bdf7c [Iulian Dragos] Update to latest version of the code, various style and name improvements. 
d58b845 [François Garillot] [SPARK-8979][Streaming] Implements a PIDRateEstimator
---
 .../dstream/ReceiverInputDStream.scala        |   2 +-
 .../scheduler/rate/PIDRateEstimator.scala     | 124 ++++++++++++++++
 .../scheduler/rate/RateEstimator.scala        |  18 ++-
 .../rate/PIDRateEstimatorSuite.scala          | 137 ++++++++++++++++++
 4 files changed, 276 insertions(+), 5 deletions(-)
 create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
 create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index 646a8c3530a62..670ef8d296a0b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -46,7 +46,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
    */
   override protected[streaming] val rateController: Option[RateController] = {
     if (RateController.isBackPressureEnabled(ssc.conf)) {
-      RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) }
+      Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration)))
     } else {
       None
     }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
new file mode 100644
index 0000000000000..6ae56a68ad88c
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler.rate
+
+/**
+ * Implements a proportional-integral-derivative (PID) controller which acts on
+ * the speed of ingestion of elements into Spark Streaming. A PID controller works
+ * by calculating an '''error''' between a measured output and a desired value. In the
+ * case of Spark Streaming the error is the difference between the measured processing
+ * rate (number of elements/processing delay) and the previous rate.
+ *
+ * @see https://en.wikipedia.org/wiki/PID_controller
+ *
+ * @param batchIntervalMillis the batch duration, in milliseconds
+ * @param proportional how much the correction should depend on the current
+ *        error. This term usually provides the bulk of correction and should be positive or zero.
+ *        A value too large would make the controller overshoot the setpoint, while a small value
+ *        would make the controller too insensitive. The default value is 1.
+ * @param integral how much the correction should depend on the accumulation + * of past errors. This value should be positive or 0. This term accelerates the movement + * towards the desired value, but a large value may lead to overshooting. The default value + * is 0.2. + * @param derivative how much the correction should depend on a prediction + * of future errors, based on current rate of change. This value should be positive or 0. + * This term is not used very often, as it impacts stability of the system. The default + * value is 0. + */ +private[streaming] class PIDRateEstimator( + batchIntervalMillis: Long, + proportional: Double = 1D, + integral: Double = .2D, + derivative: Double = 0D) + extends RateEstimator { + + private var firstRun: Boolean = true + private var latestTime: Long = -1L + private var latestRate: Double = -1D + private var latestError: Double = -1L + + require( + batchIntervalMillis > 0, + s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.") + require( + proportional >= 0, + s"Proportional term $proportional in PIDRateEstimator should be >= 0.") + require( + integral >= 0, + s"Integral term $integral in PIDRateEstimator should be >= 0.") + require( + derivative >= 0, + s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + + + def compute(time: Long, // in milliseconds + numElements: Long, + processingDelay: Long, // in milliseconds + schedulingDelay: Long // in milliseconds + ): Option[Double] = { + + this.synchronized { + if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) { + + // in seconds, should be close to batchDuration + val delaySinceUpdate = (time - latestTime).toDouble / 1000 + + // in elements/second + val processingRate = numElements.toDouble / processingDelay * 1000 + + // In our system `error` is the difference between the desired rate and the measured rate + // based on the latest batch information. We consider the desired rate to be latest rate, + // which is what this estimator calculated for the previous batch. + // in elements/second + val error = latestRate - processingRate + + // The error integral, based on schedulingDelay as an indicator for accumulated errors. + // A scheduling delay s corresponds to s * processingRate overflowing elements. Those + // are elements that couldn't be processed in previous batches, leading to this delay. + // In the following, we assume the processingRate didn't change too much. + // From the number of overflowing elements we can calculate the rate at which they would be + // processed by dividing it by the batch interval. This rate is our "historical" error, + // or integral part, since if we subtracted this rate from the previous "calculated rate", + // there wouldn't have been any overflowing elements, and the scheduling delay would have + // been zero. 
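+        // Worked illustration with assumed numbers (not part of the algorithm):
+        // suppose latestRate = 100 and processingRate = 80 elements/second, so
+        // error = 20; a schedulingDelay of 100 ms at batchIntervalMillis = 1000
+        // gives historicalError = 100 * 80 / 1000 = 8 elements/second. With the
+        // default weights (proportional = 1, integral = 0.2, derivative = 0) the
+        // update below yields newRate = 100 - 1 * 20 - 0.2 * 8 - 0 = 78.4.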
+        // (in elements/second)
+        val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis
+
+        // in elements/(second ^ 2)
+        val dError = (error - latestError) / delaySinceUpdate
+
+        val newRate = (latestRate - proportional * error -
+                                    integral * historicalError -
+                                    derivative * dError).max(0.0)
+        latestTime = time
+        if (firstRun) {
+          latestRate = processingRate
+          latestError = 0D
+          firstRun = false
+
+          None
+        } else {
+          latestRate = newRate
+          latestError = error
+
+          Some(newRate)
+        }
+      } else None
+    }
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
index a08685119e5d5..17ccebc1ed41b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler.rate
 
 import org.apache.spark.SparkConf
 import org.apache.spark.SparkException
+import org.apache.spark.streaming.Duration
 
 /**
  * A component that estimates the rate at wich an InputDStream should ingest
@@ -48,12 +49,21 @@ object RateEstimator {
   /**
    * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`.
    *
-   * @return None if there is no configured estimator, otherwise an instance of RateEstimator
+   * The only known estimator right now is `pid`.
+   *
+   * @return An instance of RateEstimator
    * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any
    *         known estimators.
    */
-  def create(conf: SparkConf): Option[RateEstimator] =
-    conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator =>
-      throw new IllegalArgumentException(s"Unkown rate estimator: $estimator")
+  def create(conf: SparkConf, batchInterval: Duration): RateEstimator =
+    conf.get("spark.streaming.backpressure.rateEstimator", "pid") match {
+      case "pid" =>
+        val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0)
+        val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2)
+        val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0)
+        new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived)
+
+      case estimator =>
+        throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
     }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
new file mode 100644
index 0000000000000..97c32d8f2d59e
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler.rate
+
+import scala.util.Random
+
+import org.scalatest.Inspectors.forAll
+import org.scalatest.Matchers
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.streaming.Seconds
+
+class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
+
+  test("the right estimator is created") {
+    val conf = new SparkConf
+    conf.set("spark.streaming.backpressure.rateEstimator", "pid")
+    val pid = RateEstimator.create(conf, Seconds(1))
+    pid.getClass should equal(classOf[PIDRateEstimator])
+  }
+
+  test("estimator checks ranges") {
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(0, 1, 2, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, -1, 2, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, 0, -1, 3)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, 0, 0, -1)
+    }
+  }
+
+  private def createDefaultEstimator: PIDRateEstimator = {
+    new PIDRateEstimator(20, 1D, 0D, 0D)
+  }
+
+  test("first bound is None") {
+    val p = createDefaultEstimator
+    p.compute(0, 10, 10, 0) should equal(None)
+  }
+
+  test("second bound is rate") {
+    val p = createDefaultEstimator
+    p.compute(0, 10, 10, 0)
+    // 1000 elements / s
+    p.compute(10, 10, 10, 0) should equal(Some(1000))
+  }
+
+  test("works even with no time between updates") {
+    val p = createDefaultEstimator
+    p.compute(0, 10, 10, 0)
+    p.compute(10, 10, 10, 0)
+    p.compute(10, 10, 10, 0) should equal(None)
+  }
+
+  test("bound is never negative") {
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms, 0 processed elements, 20ms of processing
+    // this might point the estimator to try and decrease the bound, but we test it never
+    // goes below zero, which would be nonsensical.
+    val times = List.tabulate(50)(x => x * 20) // every 20ms
+    val elements = List.fill(50)(0) // no elements processed
+    val proc = List.fill(50)(20) // 20ms of processing
+    val sched = List.fill(50)(100) // strictly positive accumulation
+    val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
+    res.head should equal(None)
+    res.tail should equal(List.fill(49)(Some(0D)))
+  }
+
+  test("with no accumulated or positive error, |I| > 0, follow the processing speed") {
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms with an increasing number of processed
+    // elements in each batch, but constant processing time, and no accumulated error.
Even though
+    // the integral part is non-zero, the estimated rate should follow only the proportional term
+    val times = List.tabulate(50)(x => x * 20) // every 20ms
+    val elements = List.tabulate(50)(x => x * 20) // increasing
+    val proc = List.fill(50)(20) // 20ms of processing
+    val sched = List.fill(50)(0)
+    val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
+    res.head should equal(None)
+    res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail)
+  }
+
+  test("with no accumulated but some positive error, |I| > 0, follow the processing speed") {
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    // prepare a series of batch updates, one every 20ms with a decreasing number of processed
+    // elements in each batch, but constant processing time, and no accumulated error. Even though
+    // the integral part is non-zero, the estimated rate should follow only the proportional term,
+    // asking for fewer and fewer elements
+    val times = List.tabulate(50)(x => x * 20) // every 20ms
+    val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing
+    val proc = List.fill(50)(20) // 20ms of processing
+    val sched = List.fill(50)(0)
+    val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
+    res.head should equal(None)
+    res.tail should equal(List.tabulate(50)(x => Some((50 - x) * 1000D)).tail)
+  }
+
+  test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") {
+    val p = new PIDRateEstimator(20, 1D, .01D, 0D)
+    val times = List.tabulate(50)(x => x * 20) // every 20ms
+    val rng = new Random()
+    val elements = List.tabulate(50)(x => rng.nextInt(1000))
+    val procDelayMs = 20
+    val proc = List.fill(50)(procDelayMs) // 20ms of processing
+    val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait
+    val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000)
+
+    val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
+    res.head should equal(None)
+    forAll(List.range(1, 50)) { (n) =>
+      res(n) should not be None
+      if (res(n).get > 0 && sched(n) > 0) {
+        res(n).get should be < speeds(n)
+      }
+    }
+  }
+}

From 39ab199a3f735b7658ab3331d3e2fb03441aec13 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Fri, 31 Jul 2015 12:07:18 -0700
Subject: [PATCH 15/43] [SPARK-8640] [SQL] Enable Processing of Multiple Window
 Frames in a Single Window Operator

This PR enables the processing of multiple window frames in a single window operator. This should improve the performance of processing multiple window expressions which share partition by/order by clauses, because it will be more efficient with respect to memory use and group processing.

Author: Herman van Hovell

Closes #7515 from hvanhovell/SPARK-8640 and squashes the following commits:

f0e1c21 [Herman van Hovell] Changed Window Logical/Physical plans to use partition by/order by specs directly instead of using WindowSpec.
e1711c2 [Herman van Hovell] Enabled the processing of multiple window frames in a single Window operator.
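As a sketch of the user-visible effect (this mirrors the HivePlanTest case added at the end of this patch; assume the usual org.apache.spark.sql.functions._ and expressions.Window imports):

    val w = Window.partitionBy($"grp").orderBy($"val")
    df.select(
      sum($"val").over(w.rowsBetween(-1, 1)),   // ROWS frame
      sum($"val").over(w.rangeBetween(-1, 1)))  // RANGE frame
    // Both expressions share the same (partitionSpec, orderSpec), so the
    // analyzer now groups them under one logical Window node instead of two.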
--- .../sql/catalyst/analysis/Analyzer.scala | 12 +++++++----- .../plans/logical/basicOperators.scala | 3 ++- .../spark/sql/execution/SparkStrategies.scala | 5 +++-- .../apache/spark/sql/execution/Window.scala | 19 ++++++++++--------- .../sql/hive/execution/HivePlanTest.scala | 18 ++++++++++++++++++ 5 files changed, 40 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 265f3d1e41765..51d910b258647 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -347,7 +347,7 @@ class Analyzer( val newOutput = oldVersion.generatorOutput.map(_.newInstance()) (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - case oldVersion @ Window(_, windowExpressions, _, child) + case oldVersion @ Window(_, windowExpressions, _, _, child) if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) @@ -825,7 +825,7 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - // Second, we group extractedWindowExprBuffer based on their Window Spec. + // Second, we group extractedWindowExprBuffer based on their Partition and Order Specs. val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => val distinctWindowSpec = expr.collect { case window: WindowExpression => window.windowSpec @@ -841,7 +841,8 @@ class Analyzer( failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + s"Please file a bug report with this error message, stack trace, and the query.") } else { - distinctWindowSpec.head + val spec = distinctWindowSpec.head + (spec.partitionSpec, spec.orderSpec) } }.toSeq @@ -850,9 +851,10 @@ class Analyzer( var currentChild = child var i = 0 while (i < groupedWindowExpressions.size) { - val (windowSpec, windowExpressions) = groupedWindowExpressions(i) + val ((partitionSpec, orderSpec), windowExpressions) = groupedWindowExpressions(i) // Set currentChild to the newly created Window operator. - currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild) + currentChild = Window(currentChild.output, windowExpressions, + partitionSpec, orderSpec, currentChild) // Move to next Window Spec. 
i += 1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a67f8de6b733a..aacfc86ab0e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -228,7 +228,8 @@ case class Aggregate( case class Window( projectList: Seq[Attribute], windowExpressions: Seq[NamedExpression], - windowSpec: WindowSpecDefinition, + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 03d24a88d4ecd..4aff52d992e6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -389,8 +389,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } - case logical.Window(projectList, windowExpressions, spec, child) => - execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil + case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + execution.Window( + projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 91c8a02e2b5bc..fe9f2c7028171 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -80,23 +80,24 @@ import scala.collection.mutable case class Window( projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], - windowSpec: WindowSpecDefinition, + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = { - if (windowSpec.partitionSpec.isEmpty) { + if (partitionSpec.isEmpty) { // Only show warning when the number of bytes is larger than 100 MB? logWarning("No Partition Defined for Window operation! Moving all data to a single " + "partition, this can cause serious performance degradation.") AllTuples :: Nil - } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil + } else ClusteredDistribution(partitionSpec) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -115,12 +116,12 @@ case class Window( case RangeFrame => val (exprs, current, bound) = if (offset == 0) { // Use the entire order expression when the offset is 0. 
- val exprs = windowSpec.orderSpec.map(_.child) + val exprs = orderSpec.map(_.child) val projection = newMutableProjection(exprs, child.output) - (windowSpec.orderSpec, projection(), projection()) - } else if (windowSpec.orderSpec.size == 1) { + (orderSpec, projection(), projection()) + } else if (orderSpec.size == 1) { // Use only the first order expression when the offset is non-null. - val sortExpr = windowSpec.orderSpec.head + val sortExpr = orderSpec.head val expr = sortExpr.child // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() @@ -250,7 +251,7 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = newProjection(windowSpec.partitionSpec, child.output) + val grouping = newProjection(partitionSpec, child.output) // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index bdb53ddf59c19..ba56a8a6b689c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { @@ -31,4 +34,19 @@ class HivePlanTest extends QueryTest { comparePlans(optimized, correctAnswer) } + + test("window expressions sharing the same partition by and order by clause") { + val df = Seq.empty[(Int, String, Int, Int)].toDF("id", "grp", "seq", "val") + val window = Window. + partitionBy($"grp"). + orderBy($"val") + val query = df.select( + $"id", + sum($"val").over(window.rowsBetween(-1, 1)), + sum($"val").over(window.rangeBetween(-1, 1)) + ) + val plan = query.queryExecution.analyzed + assert(plan.collect{ case w: logical.Window => w }.size === 1, + "Should have only 1 Window operator.") + } } From 3afc1de89cb4de9f8ea74003dd1e6b5b006d06f0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 31 Jul 2015 12:09:48 -0700 Subject: [PATCH 16/43] [SPARK-8564] [STREAMING] Add the Python API for Kinesis This PR adds the Python API for Kinesis, including a Python example and a simple unit test. 
Author: zsxwing Closes #6955 from zsxwing/kinesis-python and squashes the following commits: e42e471 [zsxwing] Merge branch 'master' into kinesis-python 455f7ea [zsxwing] Remove streaming_kinesis_asl_assembly module and simply add the source folder to streaming_kinesis_asl module 32e6451 [zsxwing] Merge remote-tracking branch 'origin/master' into kinesis-python 5082d28 [zsxwing] Fix the syntax error for Python 2.6 fca416b [zsxwing] Fix wrong comparison 96670ff [zsxwing] Fix the compilation error after merging master 756a128 [zsxwing] Merge branch 'master' into kinesis-python 6c37395 [zsxwing] Print stack trace for debug 7c5cfb0 [zsxwing] RUN_KINESIS_TESTS -> ENABLE_KINESIS_TESTS cc9d071 [zsxwing] Fix the python test errors 466b425 [zsxwing] Add python tests for Kinesis e33d505 [zsxwing] Merge remote-tracking branch 'origin/master' into kinesis-python 3da2601 [zsxwing] Fix the kinesis folder 687446b [zsxwing] Fix the error message and the maven output path add2beb [zsxwing] Merge branch 'master' into kinesis-python 4957c0b [zsxwing] Add the Python API for Kinesis --- dev/run-tests.py | 3 +- dev/sparktestsupport/modules.py | 9 +- docs/streaming-kinesis-integration.md | 19 +++ extras/kinesis-asl-assembly/pom.xml | 103 ++++++++++++++++ .../streaming/kinesis_wordcount_asl.py | 81 +++++++++++++ .../streaming/kinesis/KinesisTestUtils.scala | 19 ++- .../streaming/kinesis/KinesisUtils.scala | 78 +++++++++--- pom.xml | 1 + project/SparkBuild.scala | 6 +- python/pyspark/streaming/kinesis.py | 112 ++++++++++++++++++ python/pyspark/streaming/tests.py | 86 +++++++++++++- 11 files changed, 492 insertions(+), 25 deletions(-) create mode 100644 extras/kinesis-asl-assembly/pom.xml create mode 100644 extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py create mode 100644 python/pyspark/streaming/kinesis.py diff --git a/dev/run-tests.py b/dev/run-tests.py index 29420da9aa956..b6d181418f027 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -301,7 +301,8 @@ def build_spark_sbt(hadoop_version): sbt_goals = ["package", "assembly/assembly", "streaming-kafka-assembly/assembly", - "streaming-flume-assembly/assembly"] + "streaming-flume-assembly/assembly", + "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 44600cb9523c1..956dc81b62e93 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -138,6 +138,7 @@ def contains_file(self, filename): dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", + "extras/kinesis-asl-assembly/", ], build_profile_flags=[ "-Pkinesis-asl", @@ -300,7 +301,13 @@ def contains_file(self, filename): pyspark_streaming = Module( name="pyspark-streaming", - dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly], + dependencies=[ + pyspark_core, + streaming, + streaming_kafka, + streaming_flume_assembly, + streaming_kinesis_asl + ], source_file_regexes=[ "python/pyspark/streaming" ], diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index aa9749afbc867..a7bcaec6fcd84 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -51,6 +51,17 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the 
[example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example.
+
+	<div data-lang="python" markdown="1">
+		from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
+
+		kinesisStream = KinesisUtils.createStream(
+			streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL],
+			[region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2)
+
+	See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils)
+	and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example.
+	</div>
@@ -135,6 +146,14 @@ To run the example,
        bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
+
+	<div data-lang="python" markdown="1">
+
+        bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\
+            spark-streaming-kinesis-asl-assembly_*.jar \
+            extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \
+            [Kinesis app name] [Kinesis stream name] [endpoint URL] [region name]
+	</div>
diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml new file mode 100644 index 0000000000000..70d2c9c58f54e --- /dev/null +++ b/extras/kinesis-asl-assembly/pom.xml @@ -0,0 +1,103 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kinesis-asl-assembly_2.10 + jar + Spark Project Kinesis Assembly + http://spark.apache.org/ + + + streaming-kinesis-asl-assembly + + + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kinesis-asl-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py new file mode 100644 index 0000000000000..f428f64da3c42 --- /dev/null +++ b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Consumes messages from a Amazon Kinesis streams and does wordcount. + + This example spins up 1 Kinesis Receiver per shard for the given stream. + It then starts pulling from the last checkpointed sequence number of the given stream. + + Usage: kinesis_wordcount_asl.py + is the name of the consumer app, used to track the read data in DynamoDB + name of the Kinesis stream (ie. mySparkStream) + endpoint of the Kinesis service + (e.g. https://kinesis.us-east-1.amazonaws.com) + + + Example: + # export AWS keys if necessary + $ export AWS_ACCESS_KEY_ID= + $ export AWS_SECRET_KEY= + + # run the example + $ bin/spark-submit -jar extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com + + There is a companion helper class called KinesisWordProducerASL which puts dummy data + onto the Kinesis stream. 
+ + This code uses the DefaultAWSCredentialsProviderChain to find credentials + in the following order: + Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + Java System Properties - aws.accessKeyId and aws.secretKey + Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + Instance profile credentials - delivered through the Amazon EC2 metadata service + For more information, see + http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + + See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + the Kinesis Spark Streaming integration. +""" +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + +if __name__ == "__main__": + if len(sys.argv) != 5: + print( + "Usage: kinesis_wordcount_asl.py ", + file=sys.stderr) + sys.exit(-1) + + sc = SparkContext(appName="PythonStreamingKinesisWordCountAsl") + ssc = StreamingContext(sc, 1) + appName, streamName, endpointUrl, regionName = sys.argv[1:] + lines = KinesisUtils.createStream( + ssc, appName, streamName, endpointUrl, regionName, InitialPositionInStream.LATEST, 2) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index ca39358b75cb6..255ac27f793ba 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -36,9 +36,15 @@ import org.apache.spark.Logging /** * Shared utility methods for performing Kinesis tests that actually transfer data */ -private class KinesisTestUtils( - val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com", - _regionName: String = "") extends Logging { +private class KinesisTestUtils(val endpointUrl: String, _regionName: String) extends Logging { + + def this() { + this("https://kinesis.us-west-2.amazonaws.com", "") + } + + def this(endpointUrl: String) { + this(endpointUrl, "") + } val regionName = if (_regionName.length == 0) { RegionUtils.getRegionByEndpoint(endpointUrl).getName() @@ -117,6 +123,13 @@ private class KinesisTestUtils( shardIdToSeqNumbers.toMap } + /** + * Expose a Python friendly API. 
+ */ + def pushData(testData: java.util.List[Int]): Unit = { + pushData(scala.collection.JavaConversions.asScalaBuffer(testData)) + } + def deleteStream(): Unit = { try { if (streamCreated) { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index e5acab50181e1..7dab17eba8483 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -86,19 +86,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( ssc: StreamingContext, @@ -130,7 +130,7 @@ object KinesisUtils { * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in * [[org.apache.spark.SparkConf]]. * - * @param ssc Java StreamingContext object + * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Endpoint url of Kinesis service * (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -175,15 +175,15 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. 
+ * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ @@ -206,8 +206,8 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library @@ -216,19 +216,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( jssc: JavaStreamingContext, @@ -297,3 +297,49 @@ object KinesisUtils { } } } + +/** + * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's KinesisUtils. + */ +private class KinesisUtilsPythonHelper { + + def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = { + initialPositionInStream match { + case 0 => InitialPositionInStream.LATEST + case 1 => InitialPositionInStream.TRIM_HORIZON + case _ => throw new IllegalArgumentException( + "Illegal InitialPositionInStream. 
Please use " + + "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON") + } + } + + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: Int, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + if (awsAccessKeyId == null && awsSecretKey != null) { + throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") + } + if (awsAccessKeyId != null && awsSecretKey == null) { + throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") + } + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + +} diff --git a/pom.xml b/pom.xml index 35fc8c44bc1b0..e351c7c19df96 100644 --- a/pom.xml +++ b/pom.xml @@ -1642,6 +1642,7 @@ kinesis-asl extras/kinesis-asl + extras/kinesis-asl-assembly diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 61a05d375d99e..9a33baa7c6ce1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -382,7 +382,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py new file mode 100644 index 0000000000000..bcfe2703fecf9 --- /dev/null +++ b/python/pyspark/streaming/kinesis.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.storagelevel import StorageLevel +from pyspark.streaming import DStream + +__all__ = ['KinesisUtils', 'InitialPositionInStream', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class KinesisUtils(object): + + @staticmethod + def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, + awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder): + """ + Create an input stream that pulls messages from a Kinesis stream. This uses the + Kinesis Client Library (KCL) to pull messages from Kinesis. + + Note: The given AWS credentials will get saved in DStream checkpoints if checkpointing is + enabled. Make sure that your checkpoint directory is secure. + + :param ssc: StreamingContext object + :param kinesisAppName: Kinesis application name used by the Kinesis Client Library (KCL) to + update DynamoDB + :param streamName: Kinesis stream name + :param endpointUrl: Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + :param regionName: Name of region used by the Kinesis Client Library (KCL) to update + DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + :param initialPositionInStream: In the absence of Kinesis checkpoint info, this is the + worker's initial starting position in the stream. The + values are either the beginning of the stream per Kinesis' + limit of 24 hours (InitialPositionInStream.TRIM_HORIZON) or + the tip of the stream (InitialPositionInStream.LATEST). + :param checkpointInterval: Checkpoint interval for Kinesis checkpointing. See the Kinesis + Spark Streaming documentation for more details on the different + types of checkpoints. + :param storageLevel: Storage level to use for storing the received objects (default is + StorageLevel.MEMORY_AND_DISK_2) + :param awsAccessKeyId: AWS AccessKeyId (default is None. If None, will use + DefaultAWSCredentialsProviderChain) + :param awsSecretKey: AWS SecretKey (default is None. 
If None, will use + DefaultAWSCredentialsProviderChain) + :param decoder: A function used to decode value (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + jduration = ssc._jduration(checkpointInterval) + + try: + # Use KinesisUtilsPythonHelper to access Scala's KinesisUtils + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, + regionName, initialPositionInStream, jduration, jlevel, + awsAccessKeyId, awsSecretKey) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KinesisUtils._printErrorMsg(ssc.sparkContext) + raise e + stream = DStream(jstream, ssc, NoOpSerializer()) + return stream.map(lambda v: decoder(v)) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Kinesis libraries not found in class path. Try one of the following. + + 1. Include the Kinesis library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-kinesis-asl:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-kinesis-asl-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) + + +class InitialPositionInStream(object): + LATEST, TRIM_HORIZON = (0, 1) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 4ecae1e4bf282..5cd544b2144ef 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -36,9 +36,11 @@ import unittest from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream class PySparkStreamingTestCase(unittest.TestCase): @@ -891,6 +893,67 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) +class KinesisStreamTests(PySparkStreamingTestCase): + + def test_kinesis_stream_api(self): + # Don't start the StreamingContext because we cannot test it in Jenkins + kinesisStream1 = KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) + kinesisStream2 = KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + + def test_kinesis_stream(self): + if os.environ.get('ENABLE_KINESIS_TESTS') != '1': + print("Skip test_kinesis_stream") + return + + import random + kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) + kinesisTestUtilsClz = \ + 
self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kinesis.KinesisTestUtils") + kinesisTestUtils = kinesisTestUtilsClz.newInstance() + try: + kinesisTestUtils.createStream() + aWSCredentials = kinesisTestUtils.getAWSCredentials() + stream = KinesisUtils.createStream( + self.ssc, kinesisAppName, kinesisTestUtils.streamName(), + kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(), + InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY, + aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey()) + + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + stream.foreachRDD(get_output) + self.ssc.start() + + testData = [i for i in range(1, 11)] + expectedOutput = set([str(i) for i in testData]) + start_time = time.time() + while time.time() - start_time < 120: + kinesisTestUtils.pushData(testData) + if expectedOutput == set(outputBuffer): + break + time.sleep(10) + self.assertEqual(expectedOutput, set(outputBuffer)) + except: + import traceback + traceback.print_exc() + raise + finally: + kinesisTestUtils.deleteStream() + kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) + + def search_kafka_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") @@ -926,10 +989,31 @@ def search_flume_assembly_jar(): else: return jars[0] + +def search_kinesis_asl_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly") + jars = glob.glob( + os.path.join(kinesis_asl_assembly_dir, + "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming Kinesis ASL assembly jar in %s. " % + kinesis_asl_assembly_dir) + "You need to build Spark with " + "'build/sbt -Pkinesis-asl assembly/assembly streaming-kinesis-asl-assembly/assembly' " + "or 'build/mvn -Pkinesis-asl package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please " + "remove all but one") % kinesis_asl_assembly_dir) + else: + return jars[0] + + if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() - jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) + kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar() + jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() From d04634701413410938a133358fe1d9fbc077645e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 31 Jul 2015 12:10:55 -0700 Subject: [PATCH 17/43] [SPARK-9504] [STREAMING] [TESTS] Use eventually to fix the flaky test The previous code uses `ssc.awaitTerminationOrTimeout(500)`. Since nobody will stop it during `awaitTerminationOrTimeout`, it's just like `sleep(500)`. In a super overloaded Jenkins worker, the receiver may be not able to start in 500 milliseconds. Verified this in the log of https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/39149/ There is no log about starting the receiver before this failure. That's why `assert(runningCount > 0)` failed. This PR replaces `awaitTerminationOrTimeout` with `eventually` which should be more reliable. 
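As a minimal sketch of the polling pattern adopted here (the imports and the exact timeout/interval values are illustrative assumptions; the real change appears in the diff below):

    import org.scalatest.concurrent.Eventually._
    import org.scalatest.time.SpanSugar._

    // Before: a fixed wait, so the assertion races against receiver startup.
    // ssc.awaitTerminationOrTimeout(500)
    // assert(runningCount > 0)

    // After: poll the condition until it holds or the timeout expires.
    eventually(timeout(10.seconds), interval(10.millis)) {
      assert(runningCount > 0)
    }
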
Author: zsxwing Closes #7823 from zsxwing/SPARK-9504 and squashes the following commits: 7af66a6 [zsxwing] Remove wrong assertion 5ba2c99 [zsxwing] Use eventually to fix the flaky test --- .../apache/spark/streaming/StreamingContextSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 84a5fbb3d95eb..b7db280f63588 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -261,7 +261,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo for (i <- 1 to 4) { logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) - var runningCount = 0 + @volatile var runningCount = 0 TestReceiver.counter.set(1) val input = ssc.receiverStream(new TestReceiver) input.count().foreachRDD { rdd => @@ -270,14 +270,14 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo logInfo("Count = " + count + ", Running count = " + runningCount) } ssc.start() - ssc.awaitTerminationOrTimeout(500) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(runningCount > 0) + } ssc.stop(stopSparkContext = false, stopGracefully = true) logInfo("Running count = " + runningCount) logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) - assert(runningCount > 0) assert( - (TestReceiver.counter.get() == runningCount + 1) || - (TestReceiver.counter.get() == runningCount + 2), + TestReceiver.counter.get() == runningCount + 1, "Received records = " + TestReceiver.counter.get() + ", " + "processed records = " + runningCount ) From a8340fa7df17e3f0a3658f8b8045ab840845a72a Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 31 Jul 2015 12:12:22 -0700 Subject: [PATCH 18/43] [SPARK-9481] Add logLikelihood to LocalLDAModel jkbradley Exposes `bound` (variational log likelihood bound) through public API as `logLikelihood`. Also adds unit tests, some DRYing of `LDASuite`, and includes unit tests mentioned in #7760 Author: Feynman Liang Closes #7801 from feynmanliang/SPARK-9481-logLikelihood and squashes the following commits: 6d1b2c9 [Feynman Liang] Negate perplexity definition 5f62b20 [Feynman Liang] Add logLikelihood --- .../spark/mllib/clustering/LDAModel.scala | 20 ++- .../spark/mllib/clustering/LDASuite.scala | 129 +++++++++--------- 2 files changed, 78 insertions(+), 71 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 82281a0daf008..ff7035d2246c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -217,22 +217,28 @@ class LocalLDAModel private[clustering] ( LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, gammaShape) } - // TODO - // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + /** + * Calculates a lower bound on the log likelihood of the entire corpus. 
+ * @param documents test corpus to use for calculating log likelihood + * @return variational lower bound on the log likelihood of the entire corpus + */ + def logLikelihood(documents: RDD[(Long, Vector)]): Double = bound(documents, + docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, + vocabSize) /** - * Calculate the log variational bound on perplexity. See Equation (16) in original Online + * Calculate an upper bound bound on perplexity. See Equation (16) in original Online * LDA paper. * @param documents test corpus to use for calculating perplexity - * @return the log perplexity per word + * @return variational upper bound on log perplexity per word */ def logPerplexity(documents: RDD[(Long, Vector)]): Double = { val corpusWords = documents .map { case (_, termCounts) => termCounts.toArray.sum } .sum() - val batchVariationalBound = bound(documents, docConcentration, - topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) - val perWordBound = batchVariationalBound / corpusWords + val perWordBound = -logLikelihood(documents) / corpusWords perWordBound } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 695ee3b82efc5..79d2a1cafd1fa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -210,16 +210,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with toy data") { - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - - val docs = sc.parallelize(toydata) + val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -242,30 +233,45 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } - test("LocalLDAModel logPerplexity") { - val k = 2 - val vocabSize = 6 - val alpha = 0.01 - val eta = 0.01 - val gammaShape = 100 - // obtained from LDA model trained in gensim, see below - val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( - 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, - 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + test("LocalLDAModel logLikelihood") { + val ldaModel: LocalLDAModel = toyModel - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val docs = sc.parallelize(toydata) + val docsSingleWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(1))) + .zipWithIndex + .map { case (wordCounts, docId) => (docId.toLong, wordCounts) }) + val docsRepeatedWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(5))) + .zipWithIndex + .map { 
case (wordCounts, docId) => (docId.toLong, wordCounts) }) + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + docsSingleWord = [[(0, 1.0)]] + docsRepeatedWord = [[(0, 5.0)]] + print(lda.bound(docsSingleWord)) + > -25.9706969833 + print(lda.bound(docsRepeatedWord)) + > -31.4413908227 + */ - val ldaModel: LocalLDAModel = new LocalLDAModel( - topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + assert(ldaModel.logLikelihood(docsSingleWord) ~== -25.971 relTol 1E-3D) + assert(ldaModel.logLikelihood(docsRepeatedWord) ~== -31.441 relTol 1E-3D) + } + + test("LocalLDAModel logPerplexity") { + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: import numpy as np @@ -285,32 +291,13 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { > -3.69051285096 */ - assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) + // Gensim's definition of perplexity is negative our (and Stanford NLP's) definition + assert(ldaModel.logPerplexity(docs) ~== 3.690D relTol 1E-3D) } test("LocalLDAModel predict") { - val k = 2 - val vocabSize = 6 - val alpha = 0.01 - val eta = 0.01 - val gammaShape = 100 - // obtained from LDA model trained in gensim, see below - val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( - 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, - 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) - - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val docs = sc.parallelize(toydata) - - val ldaModel: LocalLDAModel = new LocalLDAModel( - topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: import numpy as np @@ -351,16 +338,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with asymmetric prior") { - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - - val docs = sc.parallelize(toydata) + val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -531,4 +509,27 @@ private[clustering] object LDASuite { def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter { case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0 } + + def toyData: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 
1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + + def toyModel: LocalLDAModel = { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + ldaModel + } } From c0686668ae6a92b6bb4801a55c3b78aedbee816a Mon Sep 17 00:00:00 2001 From: CodingCat Date: Fri, 31 Jul 2015 20:27:00 +0100 Subject: [PATCH 19/43] [SPARK-9202] capping maximum number of executor&driver information kept in Worker https://issues.apache.org/jira/browse/SPARK-9202 Author: CodingCat Closes #7714 from CodingCat/SPARK-9202 and squashes the following commits: 23977fb [CodingCat] add comments about why we don't synchronize finishedExecutors & finishedDrivers dc9772d [CodingCat] addressing the comments e125241 [CodingCat] stylistic fix 80bfe52 [CodingCat] fix JsonProtocolSuite d7d9485 [CodingCat] styistic fix and respect insert ordering 031755f [CodingCat] add license info & stylistic fix c3b5361 [CodingCat] test cases and docs c557b3a [CodingCat] applications are fine 9cac751 [CodingCat] application is fine... ad87ed7 [CodingCat] trimFinishedExecutorsAndDrivers --- .../apache/spark/deploy/worker/Worker.scala | 124 ++++++++++------ .../spark/deploy/worker/ui/WorkerWebUI.scala | 4 +- .../apache/spark/deploy/DeployTestUtils.scala | 89 ++++++++++++ .../spark/deploy/JsonProtocolSuite.scala | 59 ++------ .../spark/deploy/worker/WorkerSuite.scala | 133 +++++++++++++++++- docs/configuration.md | 14 ++ 6 files changed, 329 insertions(+), 94 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 82e9578bbcba5..0276c24f85368 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -25,7 +25,7 @@ import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext import scala.util.Random import scala.util.control.NonFatal @@ -115,13 +115,18 @@ private[worker] class Worker( } var workDir: File = null - val finishedExecutors = new HashMap[String, ExecutorRunner] + val finishedExecutors = new LinkedHashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] val executors = new HashMap[String, ExecutorRunner] - val finishedDrivers = new HashMap[String, DriverRunner] + val finishedDrivers = new LinkedHashMap[String, DriverRunner] val appDirectories = new HashMap[String, Seq[String]] val finishedApps = new HashSet[String] + val retainedExecutors = conf.getInt("spark.worker.ui.retainedExecutors", + WorkerWebUI.DEFAULT_RETAINED_EXECUTORS) + val 
retainedDrivers = conf.getInt("spark.worker.ui.retainedDrivers", + WorkerWebUI.DEFAULT_RETAINED_DRIVERS) + // The shuffle service is not actually started unless configured. private val shuffleService = new ExternalShuffleService(conf, securityMgr) @@ -461,25 +466,7 @@ private[worker] class Worker( } case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => - sendToMaster(executorStateChanged) - val fullId = appId + "/" + execId - if (ExecutorState.isFinished(state)) { - executors.get(fullId) match { - case Some(executor) => - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - executors -= fullId - finishedExecutors(fullId) = executor - coresUsed -= executor.cores - memoryUsed -= executor.memory - case None => - logInfo("Unknown Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - } - maybeCleanupApplication(appId) - } + handleExecutorStateChanged(executorStateChanged) case KillExecutor(masterUrl, appId, execId) => if (masterUrl != activeMasterUrl) { @@ -523,24 +510,8 @@ private[worker] class Worker( } } - case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { - state match { - case DriverState.ERROR => - logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") - case DriverState.FAILED => - logWarning(s"Driver $driverId exited with failure") - case DriverState.FINISHED => - logInfo(s"Driver $driverId exited successfully") - case DriverState.KILLED => - logInfo(s"Driver $driverId was killed by user") - case _ => - logDebug(s"Driver $driverId changed state to $state") - } - sendToMaster(driverStageChanged) - val driver = drivers.remove(driverId).get - finishedDrivers(driverId) = driver - memoryUsed -= driver.driverDesc.mem - coresUsed -= driver.driverDesc.cores + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + handleDriverStateChanged(driverStateChanged) } case ReregisterWithMaster => @@ -614,6 +585,78 @@ private[worker] class Worker( webUi.stop() metricsSystem.stop() } + + private def trimFinishedExecutorsIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedExecutors.size > retainedExecutors) { + finishedExecutors.take(math.max(finishedExecutors.size / 10, 1)).foreach { + case (executorId, _) => finishedExecutors.remove(executorId) + } + } + } + + private def trimFinishedDriversIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedDrivers.size > retainedDrivers) { + finishedDrivers.take(math.max(finishedDrivers.size / 10, 1)).foreach { + case (driverId, _) => finishedDrivers.remove(driverId) + } + } + } + + private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { + val driverId = driverStateChanged.driverId + val exception = driverStateChanged.exception + val state = driverStateChanged.state + state match { + case DriverState.ERROR => + logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") + case DriverState.FAILED => + logWarning(s"Driver $driverId exited with failure") + case DriverState.FINISHED => + logInfo(s"Driver $driverId exited successfully") + case 
DriverState.KILLED => + logInfo(s"Driver $driverId was killed by user") + case _ => + logDebug(s"Driver $driverId changed state to $state") + } + sendToMaster(driverStateChanged) + val driver = drivers.remove(driverId).get + finishedDrivers(driverId) = driver + trimFinishedDriversIfNecessary() + memoryUsed -= driver.driverDesc.mem + coresUsed -= driver.driverDesc.cores + } + + private[worker] def handleExecutorStateChanged(executorStateChanged: ExecutorStateChanged): + Unit = { + sendToMaster(executorStateChanged) + val state = executorStateChanged.state + if (ExecutorState.isFinished(state)) { + val appId = executorStateChanged.appId + val fullId = appId + "/" + executorStateChanged.execId + val message = executorStateChanged.message + val exitStatus = executorStateChanged.exitStatus + executors.get(fullId) match { + case Some(executor) => + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + executors -= fullId + finishedExecutors(fullId) = executor + trimFinishedExecutorsIfNecessary() + coresUsed -= executor.cores + memoryUsed -= executor.memory + case None => + logInfo("Unknown Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + } + maybeCleanupApplication(appId) + } + } } private[deploy] object Worker extends Logging { @@ -669,5 +712,4 @@ private[deploy] object Worker extends Logging { cmd } } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 334a5b10142aa..709a27233598c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -53,6 +53,8 @@ class WorkerWebUI( } } -private[ui] object WorkerWebUI { +private[worker] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR + val DEFAULT_RETAINED_DRIVERS = 1000 + val DEFAULT_RETAINED_EXECUTORS = 1000 } diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala new file mode 100644 index 0000000000000..967aa0976f0ce --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.io.File +import java.util.Date + +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.{SecurityManager, SparkConf} + +private[deploy] object DeployTestUtils { + def createAppDesc(): ApplicationDescription = { + val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) + new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") + } + + def createAppInfo() : ApplicationInfo = { + val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, + "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) + appInfo.endTime = JsonConstants.currTimeInMillis + appInfo + } + + def createDriverCommand(): Command = new Command( + "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") + ) + + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) + + def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", + createDriverDesc(), new Date()) + + def createWorkerInfo(): WorkerInfo = { + val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") + workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis + workerInfo + } + + def createExecutorRunner(execId: Int): ExecutorRunner = { + new ExecutorRunner( + "appId", + execId, + createAppDesc(), + 4, + 1234, + null, + "workerId", + "host", + 123, + "publicAddress", + new File("sparkHome"), + new File("workDir"), + "akka://worker", + new SparkConf, + Seq("localDir"), + ExecutorState.RUNNING) + } + + def createDriverRunner(driverId: String): DriverRunner = { + val conf = new SparkConf() + new DriverRunner( + conf, + driverId, + new File("workDir"), + new File("sparkHome"), + createDriverDesc(), + null, + "akka://worker", + new SecurityManager(conf)) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 08529e0ef2806..0a9f128a3a6b6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy -import java.io.File import java.util.Date import com.fasterxml.jackson.core.JsonParseException @@ -25,12 +24,14 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} -import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState} +import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.{JsonTestUtils, SparkFunSuite} class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { + import org.apache.spark.deploy.DeployTestUtils._ + test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) assertValidJson(output) @@ -50,7 +51,7 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { } test("writeExecutorRunner") { - val output = 
JsonProtocol.writeExecutorRunner(createExecutorRunner()) + val output = JsonProtocol.writeExecutorRunner(createExecutorRunner(123)) assertValidJson(output) assertValidDataInJson(output, JsonMethods.parse(JsonConstants.executorRunnerJsonStr)) } @@ -77,9 +78,10 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeWorkerState") { val executors = List[ExecutorRunner]() - val finishedExecutors = List[ExecutorRunner](createExecutorRunner(), createExecutorRunner()) - val drivers = List(createDriverRunner()) - val finishedDrivers = List(createDriverRunner(), createDriverRunner()) + val finishedExecutors = List[ExecutorRunner](createExecutorRunner(123), + createExecutorRunner(123)) + val drivers = List(createDriverRunner("driverId")) + val finishedDrivers = List(createDriverRunner("driverId"), createDriverRunner("driverId")) val stateResponse = new WorkerStateResponse("host", 8080, "workerId", executors, finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") val output = JsonProtocol.writeWorkerState(stateResponse) @@ -87,47 +89,6 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerStateJsonStr)) } - def createAppDesc(): ApplicationDescription = { - val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) - new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") - } - - def createAppInfo() : ApplicationInfo = { - val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, - "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) - appInfo.endTime = JsonConstants.currTimeInMillis - appInfo - } - - def createDriverCommand(): Command = new Command( - "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), - Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") - ) - - def createDriverDesc(): DriverDescription = - new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) - - def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", - createDriverDesc(), new Date()) - - def createWorkerInfo(): WorkerInfo = { - val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") - workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis - workerInfo - } - - def createExecutorRunner(): ExecutorRunner = { - new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", 123, - "publicAddress", new File("sparkHome"), new File("workDir"), "akka://worker", - new SparkConf, Seq("localDir"), ExecutorState.RUNNING) - } - - def createDriverRunner(): DriverRunner = { - val conf = new SparkConf() - new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), - createDriverDesc(), null, "akka://worker", new SecurityManager(conf)) - } - def assertValidJson(json: JValue) { try { JsonMethods.parse(JsonMethods.compact(json)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 0f4d3b28d09df..faed4bdc68447 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.deploy.worker -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.Command - import org.scalatest.Matchers +import 
org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} + class WorkerSuite extends SparkFunSuite with Matchers { + import org.apache.spark.deploy.DeployTestUtils._ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -56,4 +61,126 @@ class WorkerSuite extends SparkFunSuite with Matchers { "-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=y", "-Dspark.ssl.opt2=z") } + + test("test clearing of finishedExecutors (small number of executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 5) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // initialize ExecutorStateChanged Message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 4) + for (i <- 1 until 5) { + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 2) + if (i > 1) { + assert(!worker.finishedExecutors.contains(s"app1/${i - 2}")) + } + assert(worker.executors.size === 4 - i) + } + } + + test("test clearing of finishedExecutors (more executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 50) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // initialize ExecutorStateChanged Message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedExecutors.size < 30) { + worker.finishedExecutors.size + 1 + } else { + 28 + } + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedExecutors.contains(s"app1/$j")) + } + } + assert(worker.executors.size === 49 - i) + assert(worker.finishedExecutors.size === expectedValue) + } + } + + test("test clearing of finishedDrivers (small number of drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 5) { + val driverId = s"driverId-$i" + 
worker.drivers += driverId -> createDriverRunner(driverId) + } + // initialize DriverStateChanged Message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.drivers.size === 4) + assert(worker.finishedDrivers.size === 1) + for (i <- 1 until 5) { + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (i > 1) { + assert(!worker.finishedDrivers.contains(s"driverId-${i - 2}")) + } + assert(worker.drivers.size === 4 - i) + assert(worker.finishedDrivers.size === 2) + } + } + + test("test clearing of finishedDrivers (more drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize workers + for (i <- 0 until 50) { + val driverId = s"driverId-$i" + worker.drivers += driverId -> createDriverRunner(driverId) + } + // initialize DriverStateChanged Message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.finishedDrivers.size === 1) + assert(worker.drivers.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedDrivers.size < 30) { + worker.finishedDrivers.size + 1 + } else { + 28 + } + } + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedDrivers.contains(s"driverId-$j")) + } + } + assert(worker.drivers.size === 49 - i) + assert(worker.finishedDrivers.size === expectedValue) + } + } } diff --git a/docs/configuration.md b/docs/configuration.md index fd236137cb96e..24b606356a149 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -557,6 +557,20 @@ Apart from these, the following properties are also available, and may be useful collecting. + + spark.worker.ui.retainedExecutors + 1000 + + How many finished executors the Spark UI and status APIs remember before garbage collecting. + + + + spark.worker.ui.retainedDrivers + 1000 + + How many finished drivers the Spark UI and status APIs remember before garbage collecting. + + #### Compression and Serialization From 3c0d2e55210735e0df2f8febb5f63c224af230e3 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Fri, 31 Jul 2015 13:01:10 -0700 Subject: [PATCH 20/43] [SPARK-9246] [MLLIB] DistributedLDAModel predict top docs per topic Add topDocumentsPerTopic to DistributedLDAModel. Add ScalaDoc and unit tests. Author: Meihua Wu Closes #7769 from rotationsymmetry/SPARK-9246 and squashes the following commits: 1029e79c [Meihua Wu] clean up code comments a023b82 [Meihua Wu] Update tests to use Long for doc index. 91e5998 [Meihua Wu] Use Long for doc index. 
b9f70cf [Meihua Wu] Revise topDocumentsPerTopic 26ff3f6 [Meihua Wu] Add topDocumentsPerTopic, scala doc and unit tests --- .../spark/mllib/clustering/LDAModel.scala | 37 +++++++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 22 +++++++++++ 2 files changed, 59 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index ff7035d2246c2..0cdac84eeb591 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -516,6 +516,43 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Return the top documents for each topic + * + * This is approximate; it may not return exactly the top-weighted documents for each topic. + * To get a more precise set of top documents, increase maxDocumentsPerTopic. + * + * @param maxDocumentsPerTopic Maximum number of documents to collect for each topic. + * @return Array over topics. Each element represent as a pair of matching arrays: + * (IDs for the documents, weights of the topic in these documents). + * For each topic, documents are sorted in order of decreasing topic weights. + */ + def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { + val numTopics = k + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = + topicDistributions.mapPartitions { docVertices => + // For this partition, collect the most common docs for each topic in queues: + // queues(topic) = queue of (doc topic, doc ID). + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic)) + for ((docId, docTopics) <- docVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (docTopics(topic) -> docId) + topic += 1 + } + } + Iterator(queues) + }.treeReduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b } + q1 + } + topicsInQueues.map { q => + val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip + (docs.toArray, docTopics.toArray) + } + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? 
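For illustration, a minimal sketch of how a caller might use the new topDocumentsPerTopic API (the trained `model` value and the choice of 10 documents per topic are assumptions for this sketch, not part of the patch):

    // Assumes `model` is an already-trained DistributedLDAModel.
    val topDocs: Array[(Array[Long], Array[Double])] = model.topDocumentsPerTopic(10)
    topDocs.zipWithIndex.foreach { case ((docIds, weights), topic) =>
      // For each topic, document IDs arrive sorted by decreasing topic weight.
      docIds.zip(weights).foreach { case (docId, weight) =>
        println(s"topic $topic: doc $docId has weight $weight")
      }
    }
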
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 79d2a1cafd1fa..f2b94707fd0ff 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -122,6 +122,28 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) + + // Check: topDocumentsPerTopic + // Compare it with top documents per topic derived from topicDistributions + val topDocsByTopicDistributions = { n: Int => + Range(0, k).map { topic => + val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip + (doc.toArray, docWeights.map(_(topic)).toArray) + }.toArray + } + + // Top 3 documents per topic + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } + + // All documents per topic + val q = tinyCorpus.length + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } } test("vertex indexing") { From 060c79aab58efd4ce7353a1b00534de0d9e1de0b Mon Sep 17 00:00:00 2001 From: Sameer Abhyankar Date: Fri, 31 Jul 2015 13:08:55 -0700 Subject: [PATCH 21/43] [SPARK-9056] [STREAMING] Rename configuration `spark.streaming.minRememberDuration` to `spark.streaming.fileStream.minRememberDuration` Rename configuration `spark.streaming.minRememberDuration` to `spark.streaming.fileStream.minRememberDuration` Author: Sameer Abhyankar Author: Sameer Abhyankar Closes #7740 from sabhyankar/spark_branch_9056 and squashes the following commits: d5b2f1f [Sameer Abhyankar] Correct deprecated version to 1.5 1268133 [Sameer Abhyankar] Add {} and indentation ddf9844 [Sameer Abhyankar] Change 4 space indentation to 2 space indentation 1819b5f [Sameer Abhyankar] Use spark.streaming.fileStream.minRememberDuration property in lieu of spark.streaming.minRememberDuration --- core/src/main/scala/org/apache/spark/SparkConf.scala | 4 +++- .../apache/spark/streaming/dstream/FileInputDStream.scala | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 4161792976c7b..08bab4bf2739f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -548,7 +548,9 @@ private[spark] object SparkConf extends Logging { "spark.rpc.askTimeout" -> Seq( AlternateConfig("spark.akka.askTimeout", "1.4")), "spark.rpc.lookupTimeout" -> Seq( - AlternateConfig("spark.akka.lookupTimeout", "1.4")) + AlternateConfig("spark.akka.lookupTimeout", "1.4")), + "spark.streaming.fileStream.minRememberDuration" -> Seq( + AlternateConfig("spark.streaming.minRememberDuration", "1.5")) ) /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index dd4da9d9ca6a2..c358f5b5bd70b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -86,8 +86,10 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * Files with mod times older than this "window" 
of remembering will be ignored. So if new * files are visible within this window, then the file will get selected in the next batch. */ - private val minRememberDurationS = - Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.minRememberDuration", "60s")) + private val minRememberDurationS = { + Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.fileStream.minRememberDuration", + ssc.conf.get("spark.streaming.minRememberDuration", "60s"))) + } // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock From fbef566a107b47e5fddde0ea65b8587d5039062d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 31 Jul 2015 13:11:42 -0700 Subject: [PATCH 22/43] [SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities Make NaiveBayesModel support predicting class probabilities, inherit from ProbabilisticClassificationModel. Author: Yanbo Liang Closes #7672 from yanboliang/spark-9308 and squashes the following commits: 25e224c [Yanbo Liang] raw2probabilityInPlace should operate in-place 3ee56d6 [Yanbo Liang] change predictRaw and raw2probabilityInPlace c07e7a2 [Yanbo Liang] ml.NaiveBayesModel support predicting class probabilities --- .../spark/ml/classification/NaiveBayes.scala | 65 ++++++++++++++----- .../ml/classification/NaiveBayesSuite.scala | 54 ++++++++++++++- 2 files changed, 101 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 5be35fe209291..b46b676204e0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * The input feature values must be nonnegative. 
*/ class NaiveBayes(override val uid: String) - extends Predictor[Vector, NaiveBayes, NaiveBayesModel] + extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams { def this() = this(Identifiable.randomUID("nb")) @@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } - override protected def predict(features: Vector): Double = { + override val numClasses: Int = pi.size + + private def multinomialCalculation(features: Vector) = { + val prob = theta.multiply(features) + BLAS.axpy(1.0, pi, prob) + prob + } + + private def bernoulliCalculation(features: Vector) = { + features.foreachActive((_, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") + } + ) + val prob = thetaMinusNegTheta.get.multiply(features) + BLAS.axpy(1.0, pi, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + override protected def predictRaw(features: Vector): Vector = { $(modelType) match { case Multinomial => - val prob = theta.multiply(features) - BLAS.axpy(1.0, pi, prob) - prob.argmax + multinomialCalculation(features) case Bernoulli => - features.foreachActive{ (index, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features") - } - } - val prob = thetaMinusNegTheta.get.multiply(features) - BLAS.axpy(1.0, pi, prob) - BLAS.axpy(1.0, negThetaSum.get, prob) - prob.argmax + bernoulliCalculation(features) case _ => // This should never happen. 
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } } + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val maxLog = dv.values.max + while (i < size) { + dv.values(i) = math.exp(dv.values(i) - maxLog) + i += 1 + } + val probSum = dv.values.sum + i = 0 + while (i < size) { + dv.values(i) = dv.values(i) / probSum + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in NaiveBayesModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 264bde3703c5f..aea3d9b694490 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.ml.classification +import breeze.linalg.{Vector => BV} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.classification.NaiveBayes import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -28,6 +31,8 @@ import org.apache.spark.sql.Row class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + import NaiveBayes.{Multinomial, Bernoulli} + def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { case Row(prediction: Double, label: Double) => @@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch") } + def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v))) + val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v)) + val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def validateProbabilities( + featureAndProbabilities: DataFrame, + model: NaiveBayesModel, + modelType: String): Unit = { + featureAndProbabilities.collect().foreach { + case Row(features: Vector, probability: Vector) => { + assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) + val expected = modelType match { + case Multinomial => + expectedMultinomialProbabilities(model, features) + case Bernoulli => + expectedBernoulliProbabilities(model, features) + case _ => + throw new UnknownError(s"Invalid modelType: $modelType.") + } + assert(probability ~== expected relTol 1.0e-10) + } + } + } + 
test("params") { ParamsSuite.checkParams(new NaiveBayes) val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), @@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 17, "multinomial")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "multinomial") } test("Naive Bayes Bernoulli") { @@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 20, "bernoulli")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "bernoulli") } } From 815c8245f47e61226a04e2e02f508457b5e9e536 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 31 Jul 2015 13:45:12 -0700 Subject: [PATCH 23/43] [SPARK-9466] [SQL] Increate two timeouts in CliSuite. Hopefully this can resolve the flakiness of this suite. JIRA: https://issues.apache.org/jira/browse/SPARK-9466 Author: Yin Huai Closes #7777 from yhuai/SPARK-9466 and squashes the following commits: e0e3a86 [Yin Huai] Increate the timeout. 
--- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 13b0c5951dddc..df80d04b40801 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -137,7 +137,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { } test("Single command with --database") { - runCliWithin(1.minute)( + runCliWithin(2.minute)( "CREATE DATABASE hive_test_db;" -> "OK", "USE hive_test_db;" @@ -148,7 +148,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { -> "Time taken: " ) - runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( "" -> "OK", "" From 873ab0f9692d8ea6220abdb8d9200041068372a8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 31 Jul 2015 13:45:28 -0700 Subject: [PATCH 24/43] [SPARK-9490] [DOCS] [MLLIB] MLlib evaluation metrics guide example python code uses deprecated print statement Use print(x) not print x for Python 3 in eval examples CC sethah mengxr -- just wanted to close this out before 1.5 Author: Sean Owen Closes #7822 from srowen/SPARK-9490 and squashes the following commits: 01abeba [Sean Owen] Change "print x" to "print(x)" in the rest of the docs too bd7f7fb [Sean Owen] Use print(x) not print x for Python 3 in eval examples --- docs/ml-guide.md | 2 +- docs/mllib-evaluation-metrics.md | 66 ++++++++++++++--------------- docs/mllib-feature-extraction.md | 2 +- docs/mllib-statistics.md | 20 ++++----- docs/quick-start.md | 2 +- docs/sql-programming-guide.md | 6 +-- docs/streaming-programming-guide.md | 2 +- 7 files changed, 50 insertions(+), 50 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 8c46adf256a9a..b6ca50e98db02 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -561,7 +561,7 @@ test = sc.parallelize([(4L, "spark i j k"), prediction = model.transform(test) selected = prediction.select("id", "text", "prediction") for row in selected.collect(): - print row + print(row) sc.stop() {% endhighlight %} diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 4ca0bb06b26a6..7066d5c97418c 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -302,10 +302,10 @@ predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp metrics = BinaryClassificationMetrics(predictionAndLabels) # Area under precision-recall curve -print "Area under PR = %s" % metrics.areaUnderPR +print("Area under PR = %s" % metrics.areaUnderPR) # Area under ROC curve -print "Area under ROC = %s" % metrics.areaUnderROC +print("Area under ROC = %s" % metrics.areaUnderROC) {% endhighlight %} @@ -606,24 +606,24 @@ metrics = MulticlassMetrics(predictionAndLabels) precision = metrics.precision() recall = metrics.recall() f1Score = metrics.fMeasure() -print "Summary Stats" -print "Precision = %s" % precision -print "Recall = %s" % recall -print "F1 Score = %s" % f1Score +print("Summary Stats") +print("Precision = %s" % precision) +print("Recall = %s" % recall) +print("F1 Score = %s" % f1Score) # Statistics by class labels = data.map(lambda lp: 
lp.label).distinct().collect() for label in sorted(labels): - print "Class %s precision = %s" % (label, metrics.precision(label)) - print "Class %s recall = %s" % (label, metrics.recall(label)) - print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) # Weighted stats -print "Weighted recall = %s" % metrics.weightedRecall -print "Weighted precision = %s" % metrics.weightedPrecision -print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() -print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) -print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +print("Weighted recall = %s" % metrics.weightedRecall) +print("Weighted precision = %s" % metrics.weightedPrecision) +print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) +print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) +print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) {% endhighlight %} @@ -881,28 +881,28 @@ scoreAndLabels = sc.parallelize([ metrics = MultilabelMetrics(scoreAndLabels) # Summary stats -print "Recall = %s" % metrics.recall() -print "Precision = %s" % metrics.precision() -print "F1 measure = %s" % metrics.f1Measure() -print "Accuracy = %s" % metrics.accuracy +print("Recall = %s" % metrics.recall()) +print("Precision = %s" % metrics.precision()) +print("F1 measure = %s" % metrics.f1Measure()) +print("Accuracy = %s" % metrics.accuracy) # Individual label stats labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() for label in labels: - print "Class %s precision = %s" % (label, metrics.precision(label)) - print "Class %s recall = %s" % (label, metrics.recall(label)) - print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)) + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) # Micro stats -print "Micro precision = %s" % metrics.microPrecision -print "Micro recall = %s" % metrics.microRecall -print "Micro F1 measure = %s" % metrics.microF1Measure +print("Micro precision = %s" % metrics.microPrecision) +print("Micro recall = %s" % metrics.microRecall) +print("Micro F1 measure = %s" % metrics.microF1Measure) # Hamming loss -print "Hamming loss = %s" % metrics.hammingLoss +print("Hamming loss = %s" % metrics.hammingLoss) # Subset accuracy -print "Subset accuracy = %s" % metrics.subsetAccuracy +print("Subset accuracy = %s" % metrics.subsetAccuracy) {% endhighlight %} @@ -1283,10 +1283,10 @@ scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) metrics = RegressionMetrics(scoreAndLabels) # Root mean sqaured error -print "RMSE = %s" % metrics.rootMeanSquaredError +print("RMSE = %s" % metrics.rootMeanSquaredError) # R-squared -print "R-squared = %s" % metrics.r2 +print("R-squared = %s" % metrics.r2) {% endhighlight %} @@ -1479,17 +1479,17 @@ valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.l metrics = RegressionMetrics(valuesAndPreds) # Squared Error -print "MSE = %s" % metrics.meanSquaredError -print "RMSE = %s" % metrics.rootMeanSquaredError +print("MSE = %s" % metrics.meanSquaredError) +print("RMSE = %s" % metrics.rootMeanSquaredError) # R-squared -print "R-squared = 
%s" % metrics.r2 +print("R-squared = %s" % metrics.r2) # Mean absolute error -print "MAE = %s" % metrics.meanAbsoluteError +print("MAE = %s" % metrics.meanAbsoluteError) # Explained variance -print "Explained variance = %s" % metrics.explainedVariance +print("Explained variance = %s" % metrics.explainedVariance) {% endhighlight %} diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index a69e41e2a1936..de86aba2ae627 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -221,7 +221,7 @@ model = word2vec.fit(inp) synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) {% endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index de5d6485f9b5f..be04d0b4b53a8 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -95,9 +95,9 @@ mat = ... # an RDD of Vectors # Compute column summary statistics. summary = Statistics.colStats(mat) -print summary.mean() -print summary.variance() -print summary.numNonzeros() +print(summary.mean()) +print(summary.variance()) +print(summary.numNonzeros()) {% endhighlight %} @@ -183,12 +183,12 @@ seriesY = ... # must have the same number of partitions and cardinality as serie # Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a # method is not specified, Pearson's method will be used by default. -print Statistics.corr(seriesX, seriesY, method="pearson") +print(Statistics.corr(seriesX, seriesY, method="pearson")) data = ... # an RDD of Vectors # calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. # If a method is not specified, Pearson's method will be used by default. -print Statistics.corr(data, method="pearson") +print(Statistics.corr(data, method="pearson")) {% endhighlight %} @@ -398,14 +398,14 @@ vec = Vectors.dense(...) # a vector composed of the frequencies of events # compute the goodness of fit. If a second vector to test against is not supplied as a parameter, # the test runs against a uniform distribution. goodnessOfFitTestResult = Statistics.chiSqTest(vec) -print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, - # test statistic, the method used, and the null hypothesis. +print(goodnessOfFitTestResult) # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. mat = Matrices.dense(...) # a contingency matrix # conduct Pearson's independence test on the input contingency matrix independenceTestResult = Statistics.chiSqTest(mat) -print independenceTestResult # summary of the test including the p-value, degrees of freedom... +print(independenceTestResult) # summary of the test including the p-value, degrees of freedom... obs = sc.parallelize(...) # LabeledPoint(feature, label) . @@ -415,8 +415,8 @@ obs = sc.parallelize(...) # LabeledPoint(feature, label) . 
 featureTestResults = Statistics.chiSqTest(obs)

 for i, result in enumerate(featureTestResults):
-    print "Column $d:" % (i + 1)
-    print result
+    print("Column %d:" % (i + 1))
+    print(result)
 {% endhighlight %}

diff --git a/docs/quick-start.md b/docs/quick-start.md
index bb39e4111f244..ce2cc9d2169cd 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -406,7 +406,7 @@ logData = sc.textFile(logFile).cache()
 numAs = logData.filter(lambda s: 'a' in s).count()
 numBs = logData.filter(lambda s: 'b' in s).count()

-print "Lines with a: %i, lines with b: %i" % (numAs, numBs)
+print("Lines with a: %i, lines with b: %i" % (numAs, numBs))
 {% endhighlight %}

diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 95945eb7fc8a0..d31baa080cbce 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -570,7 +570,7 @@ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 1
 # The results of SQL queries are RDDs and support all the normal RDD operations.
 teenNames = teenagers.map(lambda p: "Name: " + p.name)
 for teenName in teenNames.collect():
-    print teenName
+    print(teenName)
 {% endhighlight %}
@@ -752,7 +752,7 @@ results = sqlContext.sql("SELECT name FROM people")
 # The results of SQL queries are RDDs and support all the normal RDD operations.
 names = results.map(lambda p: "Name: " + p.name)
 for name in names.collect():
-    print name
+    print(name)
 {% endhighlight %}
@@ -1006,7 +1006,7 @@ parquetFile.registerTempTable("parquetFile");
 teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19")
 teenNames = teenagers.map(lambda p: "Name: " + p.name)
 for teenName in teenNames.collect():
-    print teenName
+    print(teenName)
 {% endhighlight %}

diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 2f3013b533eb0..4663b3f14c527 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -1525,7 +1525,7 @@ def getSqlContextInstance(sparkContext):
 words = ... # DStream of strings

 def process(time, rdd):
-    print "========= %s =========" % str(time)
+    print("========= %s =========" % str(time))
     try:
         # Get the singleton instance of SQLContext
         sqlContext = getSqlContextInstance(rdd.context)
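The migration applied throughout the patch above is mechanical, and worth spelling out once in a runnable form. The following is a minimal, self-contained sketch; `numAs` and `numBs` are placeholder values, not data from the Spark docs.

    # Sketch of the Python 2 -> 3 print migration shown in the patch above.
    # The __future__ import makes print a function on Python 2 as well; on
    # Python 3 it is a no-op. It must appear at the top of the module.
    from __future__ import print_function

    numAs, numBs = 3, 5  # placeholder counts standing in for the doc examples

    # Python 2 only -- the deprecated statement form removed by the patch:
    #     print "Lines with a: %i, lines with b: %i" % (numAs, numBs)

    # Function form, valid on both Python 2 and Python 3:
    print("Lines with a: %i, lines with b: %i" % (numAs, numBs))

Note that the %-interpolation itself is unchanged by the migration, so specifiers still have to match their arguments: %s accepts any value, while %d and %i format numbers, and a stray non-specifier such as $d leaves no placeholder at all and raises "TypeError: not all arguments converted during string formatting" at runtime (hence the %d correction in the chiSqTest loop above).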
From 6e5fd613ea4b9aa0ab485ba681277a51a4367168 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Fri, 31 Jul 2015 21:51:55 +0100
Subject: [PATCH 25/43] [SPARK-9507] [BUILD] Remove dependency reduced POM
 hack now that shade plugin is updated

Update to shade plugin 2.4.1, which removes the need for the
dependency-reduced-POM workaround and the 'release' profile. Fix management
of shade plugin version so children inherit it; bump assembly plugin version
while here

See https://issues.apache.org/jira/browse/SPARK-8819

I verified that `mvn clean package -DskipTests` works with Maven 3.3.3.

pwendell are you up for trying this for the 1.5.0 release?

Author: Sean Owen

Closes #7826 from srowen/SPARK-9507 and squashes the following commits:

e0b0fd2 [Sean Owen] Update to shade plugin 2.4.1, which removes the need for the dependency-reduced-POM workaround and the 'release' profile. Fix management of shade plugin version so children inherit it; bump assembly plugin version while here
---
 dev/create-release/create-release.sh |  4 ++--
 pom.xml                              | 33 +++++-----------------
 2 files changed, 8 insertions(+), 29 deletions(-)

diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 86a7a4068c40e..4311c8c9e4ca6 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -118,13 +118,13 @@ if [[ ! "$@" =~ --skip-publish ]]; then
   rm -rf $SPARK_REPO

-  build/mvn -DskipTests -Pyarn -Phive -Prelease\
+  build/mvn -DskipTests -Pyarn -Phive \
     -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
     clean install

   ./dev/change-scala-version.sh 2.11

-  build/mvn -DskipTests -Pyarn -Phive -Prelease\
+  build/mvn -DskipTests -Pyarn -Phive \
     -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
     clean install

diff --git a/pom.xml b/pom.xml
index e351c7c19df96..1371a1b6bd9f1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -160,9 +160,6 @@
 2.4.4 1.1.1.7 1.1.2 - - false - ${java.home} - ${create.dependency.reduced.pom}
@@ -1836,26 +1835,6 @@
 - - - release - - - true - - -
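The pom.xml hunks above remove the create.dependency.reduced.pom property and the 'release' profile that toggled it, and the commit message notes a second move: managing the shade plugin version centrally so that child modules inherit it. As a reference for that pattern, here is a hedged sketch of a parent-POM pluginManagement section; the 2.4.1 version is the one named in the commit message, while the surrounding layout is illustrative rather than a reproduction of Spark's actual parent POM.

    <!-- Illustrative sketch only: pinning the shade plugin once in the parent
         POM's pluginManagement section. Child modules that declare the plugin
         then inherit this version instead of each choosing their own. -->
    <build>
      <pluginManagement>
        <plugins>
          <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-shade-plugin</artifactId>
            <version>2.4.1</version>
          </plugin>
        </plugins>
      </pluginManagement>
    </build>

Declaring the version in pluginManagement rather than in each child module is the standard Maven way to keep plugin versions consistent across a multi-module build, which is what lets the per-module version declarations (and the dependency-reduced-POM workaround tied to the old plugin release) be deleted here.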