From 3508ce8a5a05d6cb122ad59ba33c3cc18e17e2a4 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Wed, 1 Oct 2014 15:44:41 -0700 Subject: [PATCH 01/36] [SPARK-3708][SQL] Backticks aren't handled correctly is aliases The below query gives error sql("SELECT k FROM (SELECT \`key\` AS \`k\` FROM src) a") It gives error because the aliases are not cleaned so it could not be resolved in further processing. Author: ravipesala Closes #2594 from ravipesala/SPARK-3708 and squashes the following commits: d55db54 [ravipesala] Fixed SPARK-3708 (Backticks aren't handled correctly is aliases) --- .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0aa6292c0184e..4f3f808c93dc8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -855,7 +855,7 @@ private[hive] object HiveQl { case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), alias)()) + Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) /* Hints are ignored */ case Token("TOK_HINTLIST", _) => None diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 679efe082f2a0..3647bb1c4ce7d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -63,4 +63,10 @@ class SQLQuerySuite extends QueryTest { sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) } + + test("SPARK-3708 Backticks aren't handled correctly is aliases") { + 
checkAnswer( + sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), + sql("SELECT `key` FROM src").collect().toSeq) + } } From f315fb7efc95afb2cc1208159b48359ba56a010d Mon Sep 17 00:00:00 2001 From: scwf Date: Wed, 1 Oct 2014 15:55:09 -0700 Subject: [PATCH 02/36] [SPARK-3705][SQL] Add case for VoidObjectInspector to cover NullType add case for VoidObjectInspector in ```inspectorToDataType``` Author: scwf Closes #2552 from scwf/inspectorToDataType and squashes the following commits: 453d892 [scwf] add case for VoidObjectInspector --- .../main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index fa889ec104c6e..d633c42c6bd67 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -213,6 +213,8 @@ private[hive] trait HiveInspectors { case _: JavaHiveDecimalObjectInspector => DecimalType case _: WritableTimestampObjectInspector => TimestampType case _: JavaTimestampObjectInspector => TimestampType + case _: WritableVoidObjectInspector => NullType + case _: JavaVoidObjectInspector => NullType } implicit class typeInfoConversions(dt: DataType) { From f84b228c4002073ee4ff53be50462a63f48bd508 Mon Sep 17 00:00:00 2001 From: Venkata Ramana Gollamudi Date: Wed, 1 Oct 2014 15:57:06 -0700 Subject: [PATCH 03/36] [SPARK-3593][SQL] Add support for sorting BinaryType BinaryType is derived from NativeType and added Ordering support. 
Author: Venkata Ramana G Author: Venkata Ramana Gollamudi Closes #2617 from gvramana/binarytype_sort and squashes the following commits: 1cf26f3 [Venkata Ramana Gollamudi] Supported Sorting of BinaryType --- .../apache/spark/sql/catalyst/types/dataTypes.scala | 12 +++++++++++- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++ .../test/scala/org/apache/spark/sql/TestData.scala | 10 ++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index c7d73d3990c3a..ac043d4dd8eb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -157,8 +157,18 @@ case object StringType extends NativeType with PrimitiveType { def simpleString: String = "string" } -case object BinaryType extends DataType with PrimitiveType { +case object BinaryType extends NativeType with PrimitiveType { private[sql] type JvmType = Array[Byte] + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] val ordering = new Ordering[JvmType] { + def compare(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + } + return x.length - y.length + } + } def simpleString: String = "binary" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 08376eb5e5c4e..fdf3a229a796e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -190,6 +190,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), 
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + checkAnswer( + sql("SELECT b FROM binaryData ORDER BY a ASC"), + (1 to 5).map(Row(_)).toSeq) + + checkAnswer( + sql("SELECT b FROM binaryData ORDER BY a DESC"), + (1 to 5).map(Row(_)).toSeq.reverse) + checkAnswer( sql("SELECT * FROM arrayData ORDER BY data[0] ASC"), arrayData.collect().sortBy(_.data(0)).toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index eb33a61c6e811..10b7979df7375 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -54,6 +54,16 @@ object TestData { TestData2(3, 2) :: Nil) testData2.registerTempTable("testData2") + case class BinaryData(a: Array[Byte], b: Int) + val binaryData: SchemaRDD = + TestSQLContext.sparkContext.parallelize( + BinaryData("12".getBytes(), 1) :: + BinaryData("22".getBytes(), 5) :: + BinaryData("122".getBytes(), 3) :: + BinaryData("121".getBytes(), 2) :: + BinaryData("123".getBytes(), 4) :: Nil) + binaryData.registerTempTable("binaryData") + // TODO: There is no way to express null primitives as case classes currently... val testData3 = logical.LocalRelation('a.int, 'b.int).loadData( From a31f4ff22f98c01f0d9b7d1240080ff2689c1270 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 1 Oct 2014 16:00:29 -0700 Subject: [PATCH 04/36] [SQL] Made Command.sideEffectResult protected Considering `Command.executeCollect()` simply delegates to `Command.sideEffectResult`, we no longer need to leave the latter `protected[sql]`. 
Author: Cheng Lian Closes #2431 from liancheng/narrow-scope and squashes the following commits: 1bfc16a [Cheng Lian] Made Command.sideEffectResult protected --- .../apache/spark/sql/execution/commands.scala | 10 +++++----- .../org/apache/spark/sql/hive/HiveContext.scala | 2 +- .../sql/hive/execution/CreateTableAsSelect.scala | 16 ++++++++-------- .../execution/DescribeHiveTableCommand.scala | 2 +- .../spark/sql/hive/execution/NativeCommand.scala | 2 +- .../spark/sql/hive/execution/commands.scala | 6 +++--- 6 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index c2f48a902a3e9..f88099ec0761e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -37,7 +37,7 @@ trait Command { * The `execute()` method of all the physical command classes should reference `sideEffectResult` * so that the command can be executed eagerly right after the command query is created. */ - protected[sql] lazy val sideEffectResult: Seq[Row] = Seq.empty[Row] + protected lazy val sideEffectResult: Seq[Row] = Seq.empty[Row] override def executeCollect(): Array[Row] = sideEffectResult.toArray @@ -53,7 +53,7 @@ case class SetCommand( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[Row] = (key, value) match { + override protected lazy val sideEffectResult: Seq[Row] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { @@ -121,7 +121,7 @@ case class ExplainCommand( extends LeafNode with Command { // Run through the optimizer to generate the physical plan. 
- override protected[sql] lazy val sideEffectResult: Seq[Row] = try { + override protected lazy val sideEffectResult: Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. val queryExecution = context.executePlan(logicalPlan) val outputString = if (extended) queryExecution.toString else queryExecution.simpleString @@ -141,7 +141,7 @@ case class ExplainCommand( case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: SQLContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult = { + override protected lazy val sideEffectResult = { if (doCache) { context.cacheTable(tableName) } else { @@ -161,7 +161,7 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { Row("# Registered as a temporary table", null, null) +: child.output.map(field => Row(field.name, field.dataType.toString, null)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 3e1a7b71528e0..20ebe4996c207 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -404,7 +404,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // be similar with Hive. 
describeHiveTableCommand.hiveString case command: PhysicalCommand => - command.sideEffectResult.map(_.head.toString) + command.executeCollect().map(_.head.toString) case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 1017fe6d5396d..3625708d03175 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -30,23 +30,23 @@ import org.apache.spark.sql.hive.MetastoreRelation * Create table and insert the query result into it. * @param database the database name of the new relation * @param tableName the table name of the new relation - * @param insertIntoRelation function of creating the `InsertIntoHiveTable` + * @param insertIntoRelation function of creating the `InsertIntoHiveTable` * by specifying the `MetaStoreRelation`, the data will be inserted into that table. * TODO Add more table creating properties, e.g. SerDe, StorageHandler, in-memory cache etc. 
*/ @Experimental case class CreateTableAsSelect( - database: String, - tableName: String, - query: SparkPlan, - insertIntoRelation: MetastoreRelation => InsertIntoHiveTable) - extends LeafNode with Command { + database: String, + tableName: String, + query: SparkPlan, + insertIntoRelation: MetastoreRelation => InsertIntoHiveTable) + extends LeafNode with Command { def output = Seq.empty // A lazy computing of the metastoreRelation private[this] lazy val metastoreRelation: MetastoreRelation = { - // Create the table + // Create the table val sc = sqlContext.asInstanceOf[HiveContext] sc.catalog.createTable(database, tableName, query.output, false) // Get the Metastore Relation @@ -55,7 +55,7 @@ case class CreateTableAsSelect( } } - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { insertIntoRelation(metastoreRelation).execute Seq.empty[Row] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 317801001c7a4..106cede9788ec 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -48,7 +48,7 @@ case class DescribeHiveTableCommand( .mkString("\t") } - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. 
var results: Seq[(String, String, String)] = Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala index 8f10e1ba7f426..6930c2babd117 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala @@ -32,7 +32,7 @@ case class NativeCommand( @transient context: HiveContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_)) + override protected lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_)) override def otherCopyArgs = context :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index d61c5e274a596..0fc674af31885 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -37,7 +37,7 @@ case class AnalyzeTable(tableName: String) extends LeafNode with Command { def output = Seq.empty - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { hiveContext.analyze(tableName) Seq.empty[Row] } @@ -53,7 +53,7 @@ case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with def output = Seq.empty - override protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { val ifExistsClause = if (ifExists) "IF EXISTS " else "" hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") hiveContext.catalog.unregisterTable(None, tableName) @@ -70,7 +70,7 @@ case class AddJar(path: String) extends LeafNode with Command { override def output = Seq.empty - override 
protected[sql] lazy val sideEffectResult: Seq[Row] = { + override protected lazy val sideEffectResult: Seq[Row] = { hiveContext.runSqlHive(s"ADD JAR $path") hiveContext.sparkContext.addJar(path) Seq.empty[Row] From 4e79970d32f9b917590dba8319bdc677e3bdd63a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 1 Oct 2014 16:03:00 -0700 Subject: [PATCH 05/36] Revert "[SPARK-3755][Core] Do not bind port 1 - 1024 to server in spark" This reverts commit 6390aae4eacbabfb1c53fb828b824f6a6518beff. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e5b83c069d961..b3025c6ec3364 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1439,7 +1439,7 @@ private[spark] object Utils extends Logging { val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" for (offset <- 0 to maxRetries) { // Do not increment port if startPort is 0, which is treated as a special port - val tryPort = if (startPort == 0) startPort else (startPort + offset) % (65536 - 1024) + 1024 + val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536 try { val (service, port) = startService(tryPort) logInfo(s"Successfully started service$serviceString on port $port.") From 45e058ca4babbe3cef6524b6a0f48b466a5139bf Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 1 Oct 2014 16:30:28 -0700 Subject: [PATCH 06/36] [SPARK-3729][SQL] Do all hive session state initialization in lazy val This change avoids a NPE during context initialization when settings are present. 
Author: Michael Armbrust Closes #2583 from marmbrus/configNPE and squashes the following commits: da2ec57 [Michael Armbrust] Do all hive session state initilialization in lazy val --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 8 ++++---- .../main/scala/org/apache/spark/sql/hive/TestHive.scala | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 20ebe4996c207..fdb56901f9ddb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -231,12 +231,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + SessionState.start(ss) + ss.err = new PrintStream(outputBuffer, true, "UTF-8") + ss.out = new PrintStream(outputBuffer, true, "UTF-8") + ss } - sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") - sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") - override def setConf(key: String, value: String): Unit = { super.setConf(key, value) runSqlHive(s"SET $key=$value") @@ -273,7 +274,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { results } - SessionState.start(sessionState) /** * Execute the command using Hive and return the results as a sequence. 
Each element diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 70fb15259e7d7..4a999b98ad92b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -40,8 +40,10 @@ import org.apache.spark.sql.SQLConf /* Implicit conversions */ import scala.collection.JavaConversions._ +// SPARK-3729: Test key required to check for initialization errors with config. object TestHive - extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) + extends TestHiveContext( + new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", ""))) /** * A locally running test instance of Spark's Hive execution engine. From 1b9f0d67f28011cdff316042b344c9891f986aaa Mon Sep 17 00:00:00 2001 From: scwf Date: Wed, 1 Oct 2014 16:38:10 -0700 Subject: [PATCH 07/36] [SPARK-3704][SQL] Fix ColumnValue type for Short values in thrift server case ```ShortType```, we should add short value to hive row. Int value may lead to some problems. 
Author: scwf Closes #2551 from scwf/fix-addColumnValue and squashes the following commits: 08bcc59 [scwf] ColumnValue.shortValue for short type --- .../hive/thriftserver/server/SparkSQLOperationManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index bd3f68d92d8c7..910174a153768 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -113,7 +113,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) case ByteType => to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) case ShortType => - to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal))) + to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) case TimestampType => to.addColumnValue( ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) @@ -145,7 +145,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) case ByteType => to.addColumnValue(ColumnValue.byteValue(null)) case ShortType => - to.addColumnValue(ColumnValue.intValue(null)) + to.addColumnValue(ColumnValue.shortValue(null)) case TimestampType => to.addColumnValue(ColumnValue.timestampValue(null)) case BinaryType | _: ArrayType | _: StructType | _: MapType => From 93861a5e876fa57f509cce82768656ddf8d4ef00 Mon Sep 17 00:00:00 2001 From: aniketbhatnagar Date: Wed, 1 Oct 2014 18:31:18 -0700 Subject: [PATCH 08/36] SPARK-3638 | Forced a compatible version of http client in kinesis-asl profile This patch forces use of commons http client 4.2 in Kinesis-asl profile so that the AWS SDK does not run into 
dependency conflicts Author: aniketbhatnagar Closes #2535 from aniketbhatnagar/Kinesis-HttpClient-Dep-Fix and squashes the following commits: aa2079f [aniketbhatnagar] Merge branch 'Kinesis-HttpClient-Dep-Fix' of https://github.com/aniketbhatnagar/spark into Kinesis-HttpClient-Dep-Fix 73f55f6 [aniketbhatnagar] SPARK-3638 | Forced a compatible version of http client in kinesis-asl profile 70cc75b [aniketbhatnagar] deleted merge files 725dbc9 [aniketbhatnagar] Merge remote-tracking branch 'origin/Kinesis-HttpClient-Dep-Fix' into Kinesis-HttpClient-Dep-Fix 4ed61d8 [aniketbhatnagar] SPARK-3638 | Forced a compatible version of http client in kinesis-asl profile 9cd6103 [aniketbhatnagar] SPARK-3638 | Forced a compatible version of http client in kinesis-asl profile --- assembly/pom.xml | 10 ++++++++++ examples/pom.xml | 5 +++++ pom.xml | 1 + 3 files changed, 16 insertions(+) diff --git a/assembly/pom.xml b/assembly/pom.xml index 5ec9da22ae83f..31a01e4d8e1de 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -349,5 +349,15 @@ + + kinesis-asl + + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + diff --git a/examples/pom.xml b/examples/pom.xml index 2b561857f9f33..eb49a0e5af22d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -43,6 +43,11 @@ spark-streaming-kinesis-asl_${scala.binary.version} ${project.version} + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + diff --git a/pom.xml b/pom.xml index 70cb9729ff6d3..7756c89b00cad 100644 --- a/pom.xml +++ b/pom.xml @@ -138,6 +138,7 @@ 0.7.1 1.8.3 1.1.0 + 4.2.6 64m 512m From 29c3513203218af33bea2f6d99d622cf263d55dd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 1 Oct 2014 19:24:22 -0700 Subject: [PATCH 09/36] [SPARK-3446] Expose underlying job ids in FutureAction. FutureAction is the only type exposed through the async APIs, so for job IDs to be useful they need to be exposed there. The complication is that some async jobs run more than one job (e.g. 
takeAsync), so the exposed ID has to actually be a list of IDs that can actually change over time. So the interface doesn't look very nice, but... Change is actually small, I just added a basic test to make sure it works. Author: Marcelo Vanzin Closes #2337 from vanzin/SPARK-3446 and squashes the following commits: e166a68 [Marcelo Vanzin] Fix comment. 1fed2bc [Marcelo Vanzin] [SPARK-3446] Expose underlying job ids in FutureAction. --- .../scala/org/apache/spark/FutureAction.scala | 19 ++++++- .../org/apache/spark/FutureActionSuite.scala | 49 +++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/FutureActionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 75ea535f2f57b..e8f761eaa5799 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -83,6 +83,15 @@ trait FutureAction[T] extends Future[T] { */ @throws(classOf[Exception]) def get(): T = Await.result(this, Duration.Inf) + + /** + * Returns the job IDs run by the underlying async operation. + * + * This returns the current snapshot of the job list. Certain operations may run multiple + * jobs, so multiple calls to this method may return different lists. + */ + def jobIds: Seq[Int] + } @@ -150,8 +159,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } } - /** Get the corresponding job id for this action. */ - def jobId = jobWaiter.jobId + def jobIds = Seq(jobWaiter.jobId) } @@ -171,6 +179,8 @@ class ComplexFutureAction[T] extends FutureAction[T] { // is cancelled before the action was even run (and thus we have no thread to interrupt). @volatile private var _cancelled: Boolean = false + @volatile private var jobs: Seq[Int] = Nil + // A promise used to signal the future. 
private val p = promise[T]() @@ -219,6 +229,8 @@ class ComplexFutureAction[T] extends FutureAction[T] { } } + this.jobs = jobs ++ job.jobIds + // Wait for the job to complete. If the action is cancelled (with an interrupt), // cancel the job and stop the execution. This is not in a synchronized block because // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. @@ -255,4 +267,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { override def isCompleted: Boolean = p.isCompleted override def value: Option[Try[T]] = p.future.value + + def jobIds = jobs + } diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala new file mode 100644 index 0000000000000..db9c25fc457a4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} + +import org.apache.spark.SparkContext._ + +class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext { + + before { + sc = new SparkContext("local", "FutureActionSuite") + } + + test("simple async action") { + val rdd = sc.parallelize(1 to 10, 2) + val job = rdd.countAsync() + val res = Await.result(job, Duration.Inf) + res should be (10) + job.jobIds.size should be (1) + } + + test("complex async action") { + val rdd = sc.parallelize(1 to 15, 3) + val job = rdd.takeAsync(10) + val res = Await.result(job, Duration.Inf) + res should be (1 to 10) + job.jobIds.size should be (2) + } + +} From f341e1c8a284b55cceb367a432c1fa5203692155 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 1 Oct 2014 23:08:51 -0700 Subject: [PATCH 10/36] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #1375 (close requested by 'pwendell') Closes #476 (close requested by 'mengxr') Closes #2502 (close requested by 'pwendell') Closes #2391 (close requested by 'andrewor14') From bbdf1de84ffdd3bd172f17975d2f1422a9bcf2c6 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Wed, 1 Oct 2014 23:53:21 -0700 Subject: [PATCH 11/36] [SPARK-3371][SQL] Renaming a function expression with group by gives error The following code gives error. ``` sqlContext.registerFunction("len", (s: String) => s.length) sqlContext.sql("select len(foo) as a, count(1) from t1 group by len(foo)").collect() ``` Because SQl parser creates the aliases to the functions in grouping expressions with generated alias names. So if user gives the alias names to the functions inside projection then it does not match the generated alias name of grouping expression. This kind of queries are working in Hive. 
So the fix I have given that if user provides alias to the function in projection then don't generate alias in grouping expression,use the same alias. Author: ravipesala Closes #2511 from ravipesala/SPARK-3371 and squashes the following commits: 9fb973f [ravipesala] Removed aliases to grouping expressions. f8ace79 [ravipesala] Fixed the testcase issue bad2fd0 [ravipesala] SPARK-3371 : Fixed Renaming a function expression with group by gives error --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 862f78702c4e6..26336332c05a2 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -166,7 +166,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { val withFilter = f.map(f => Filter(f, base)).getOrElse(base) val withProjection = g.map {g => - Aggregate(assignAliases(g), assignAliases(p), withFilter) + Aggregate(g, assignAliases(p), withFilter) }.getOrElse(Project(assignAliases(p), withFilter)) val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index fdf3a229a796e..6fb6cb8db0c8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -680,4 +680,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), 
("true", "false") :: Nil) } + + test("SPARK-3371 Renaming a function expression with group by gives error") { + registerFunction("len", (s: String) => s.length) + checkAnswer( + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)} } From 6e27cb630de69fa5acb510b4e2f6b980742b1957 Mon Sep 17 00:00:00 2001 From: Colin Patrick Mccabe Date: Thu, 2 Oct 2014 00:29:31 -0700 Subject: [PATCH 12/36] SPARK-1767: Prefer HDFS-cached replicas when scheduling data-local tasks This change reorders the replicas returned by HadoopRDD#getPreferredLocations so that replicas cached by HDFS are at the start of the list. This requires Hadoop 2.5 or higher; previous versions of Hadoop do not expose the information needed to determine whether a replica is cached. Author: Colin Patrick Mccabe Closes #1486 from cmccabe/SPARK-1767 and squashes the following commits: 338d4f8 [Colin Patrick Mccabe] SPARK-1767: Prefer HDFS-cached replicas when scheduling data-local tasks --- .../org/apache/spark/rdd/HadoopRDD.scala | 60 +++++++++++++++++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 18 +++++- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../apache/spark/scheduler/TaskLocation.scala | 48 +++++++++++++-- .../spark/scheduler/TaskSetManager.scala | 25 +++++++- .../spark/scheduler/TaskSetManagerSuite.scala | 22 +++++++ project/MimaExcludes.scala | 2 + 8 files changed, 162 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 21d0cc7b5cbaa..6b63eb23e9ee1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -23,6 +23,7 @@ import java.io.EOFException import scala.collection.immutable.Map import scala.reflect.ClassTag +import scala.collection.mutable.ListBuffer import org.apache.hadoop.conf.{Configurable, Configuration} 
import org.apache.hadoop.mapred.FileSplit @@ -43,6 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataReadMethod, InputMetrics} import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.{NextIterator, Utils} +import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} /** @@ -249,9 +251,21 @@ class HadoopRDD[K, V]( } override def getPreferredLocations(split: Partition): Seq[String] = { - // TODO: Filtering out "localhost" in case of file:// URLs - val hadoopSplit = split.asInstanceOf[HadoopPartition] - hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") + val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value + val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match { + case Some(c) => + try { + val lsplit = c.inputSplitWithLocationInfo.cast(hsplit) + val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]] + Some(HadoopRDD.convertSplitLocationInfo(infos)) + } catch { + case e: Exception => + logDebug("Failed to use InputSplitWithLocations.", e) + None + } + case None => None + } + locs.getOrElse(hsplit.getLocations.filter(_ != "localhost")) } override def checkpoint() { @@ -261,7 +275,7 @@ class HadoopRDD[K, V]( def getConf: Configuration = getJobConf() } -private[spark] object HadoopRDD { +private[spark] object HadoopRDD extends Logging { /** Constructing Configuration objects is not threadsafe, use this lock to serialize. 
*/ val CONFIGURATION_INSTANTIATION_LOCK = new Object() @@ -309,4 +323,42 @@ private[spark] object HadoopRDD { f(inputSplit, firstParent[T].iterator(split, context)) } } + + private[spark] class SplitInfoReflections { + val inputSplitWithLocationInfo = + Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") + val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") + val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit") + val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") + val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo") + val isInMemory = splitLocationInfo.getMethod("isInMemory") + val getLocation = splitLocationInfo.getMethod("getLocation") + } + + private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try { + Some(new SplitInfoReflections) + } catch { + case e: Exception => + logDebug("SplitLocationInfo and other new Hadoop classes are " + + "unavailable. Using the older Hadoop location info code.", e) + None + } + + private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = { + val out = ListBuffer[String]() + infos.foreach { loc => { + val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get. + getLocation.invoke(loc).asInstanceOf[String] + if (locationStr != "localhost") { + if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory. 
+ invoke(loc).asInstanceOf[Boolean]) { + logDebug("Partition " + locationStr + " is cached by Hadoop.") + out += new HDFSCacheTaskLocation(locationStr).toString + } else { + out += new HostTaskLocation(locationStr).toString + } + } + }} + out.seq + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 4c84b3f62354d..0cccdefc5ee09 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -173,9 +173,21 @@ class NewHadoopRDD[K, V]( new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) } - override def getPreferredLocations(split: Partition): Seq[String] = { - val theSplit = split.asInstanceOf[NewHadoopPartition] - theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") + override def getPreferredLocations(hsplit: Partition): Seq[String] = { + val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value + val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { + case Some(c) => + try { + val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] + Some(HadoopRDD.convertSplitLocationInfo(infos)) + } catch { + case e : Exception => + logDebug("Failed to use InputSplit#getLocationInfo.", e) + None + } + case None => None + } + locs.getOrElse(split.getLocations.filter(_ != "localhost")) } def getConf: Configuration = confBroadcast.value.value diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ab9e97c8fe409..2aba40d152e3e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -208,7 +208,7 @@ abstract class RDD[T: ClassTag]( } /** - * Get the preferred locations of a partition (as hostnames), taking into account whether the + * Get the preferred locations of a partition, taking into account 
whether the * RDD is checkpointed. */ final def preferredLocations(split: Partition): Seq[String] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5a96f52a10cd4..8135cdbb4c31f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1303,7 +1303,7 @@ class DAGScheduler( // If the RDD has some placement preferences (as is the case for input RDDs), get those val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList if (!rddPrefs.isEmpty) { - return rddPrefs.map(host => TaskLocation(host)) + return rddPrefs.map(TaskLocation(_)) } // If the RDD has narrow dependencies, pick the first partition of the first narrow dep // that has any placement preferences. Ideally we would choose based on transfer sizes, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 67c9a6760b1b3..10c685f29d3ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -22,13 +22,51 @@ package org.apache.spark.scheduler * In the latter case, we will prefer to launch the task on that executorID, but our next level * of preference will be executors on the same host if this is not possible. */ -private[spark] -class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable { - override def toString: String = "TaskLocation(" + host + ", " + executorId + ")" +private[spark] sealed trait TaskLocation { + def host: String +} + +/** + * A location that includes both a host and an executor id on that host. 
+ */ +private [spark] case class ExecutorCacheTaskLocation(override val host: String, + val executorId: String) extends TaskLocation { +} + +/** + * A location on a host. + */ +private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation { + override def toString = host +} + +/** + * A location on a host that is cached by HDFS. + */ +private [spark] case class HDFSCacheTaskLocation(override val host: String) + extends TaskLocation { + override def toString = TaskLocation.inMemoryLocationTag + host } private[spark] object TaskLocation { - def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId)) + // We identify hosts on which the block is cached with this prefix. Because this prefix contains + // underscores, which are not legal characters in hostnames, there should be no potential for + // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames. + val inMemoryLocationTag = "hdfs_cache_" + + def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId) - def apply(host: String) = new TaskLocation(host, None) + /** + * Create a TaskLocation from a string returned by getPreferredLocations. + * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the + * location is cached. 
+ */ + def apply(str: String) = { + val hstr = str.stripPrefix(inMemoryLocationTag) + if (hstr.equals(str)) { + new HostTaskLocation(str) + } else { + new HostTaskLocation(hstr) + } + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index d9d53faf843ff..a6c23fc85a1b0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -181,8 +181,24 @@ private[spark] class TaskSetManager( } for (loc <- tasks(index).preferredLocations) { - for (execId <- loc.executorId) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) + loc match { + case e: ExecutorCacheTaskLocation => + addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer)) + case e: HDFSCacheTaskLocation => { + val exe = sched.getExecutorsAliveOnHost(loc.host) + exe match { + case Some(set) => { + for (e <- set) { + addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer)) + } + logInfo(s"Pending task $index has a cached location at ${e.host} " + + ", where there are executors " + set.mkString(",")) + } + case None => logDebug(s"Pending task $index has a cached location at ${e.host} " + + ", but there are no executors alive there.") + } + } + case _ => Unit } addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) for (rack <- sched.getRackForHost(loc.host)) { @@ -283,7 +299,10 @@ private[spark] class TaskSetManager( // on multiple nodes when we replicate cached blocks, as in Spark Streaming for (index <- speculatableTasks if canRunOnHost(index)) { val prefs = tasks(index).preferredLocations - val executors = prefs.flatMap(_.executorId) + val executors = prefs.flatMap(_ match { + case e: ExecutorCacheTaskLocation => Some(e.executorId) + case _ => None + }); if (executors.contains(execId)) { speculatableTasks -= index return Some((index, 
TaskLocality.PROCESS_LOCAL)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 93e8ddacf8865..c0b07649eb6dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -642,6 +642,28 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execC", "host3", ANY) !== None) } + test("Test that locations with HDFSCacheTaskLocation are treated as PROCESS_LOCAL.") { + // Regression test for SPARK-2931 + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, + ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val taskSet = FakeTask.createTaskSet(3, + Seq(HostTaskLocation("host1")), + Seq(HostTaskLocation("host2")), + Seq(HDFSCacheTaskLocation("host3"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + sched.removeExecutor("execA") + manager.executorAdded() + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + sched.removeExecutor("execB") + manager.executorAdded() + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + sched.removeExecutor("execC") + manager.executorAdded() + assert(manager.myLocalityLevels.sameElements(Array(ANY))) + } def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4076ebc6fc8d5..d499302124461 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -41,6 +41,8 @@ object MimaExcludes { MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ 
MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ Seq( + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.TaskLocation"), // Added normL1 and normL2 to trait MultivariateStatisticalSummary ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), From 5b4a5b1acdc439a58aa2a3561ac0e3fb09f529d6 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Thu, 2 Oct 2014 11:13:19 -0700 Subject: [PATCH 13/36] [SPARK-3706][PySpark] Cannot run IPython REPL with IPYTHON set to "1" and PYSPARK_PYTHON unset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Problem The section "Using the shell" in Spark Programming Guide (https://spark.apache.org/docs/latest/programming-guide.html#using-the-shell) says that we can run pyspark REPL through IPython. But a folloing command does not run IPython but a default Python executable. ``` $ IPYTHON=1 ./bin/pyspark Python 2.7.8 (default, Jul 2 2014, 10:14:46) ... ``` the spark/bin/pyspark script on the commit b235e013638685758885842dc3268e9800af3678 decides which executable and options it use folloing way. 1. if PYSPARK_PYTHON unset * → defaulting to "python" 2. if IPYTHON_OPTS set * → set IPYTHON "1" 3. some python scripts passed to ./bin/pyspak → run it with ./bin/spark-submit * out of this issues scope 4. if IPYTHON set as "1" * → execute $PYSPARK_PYTHON (default: ipython) with arguments $IPYTHON_OPTS * otherwise execute $PYSPARK_PYTHON Therefore, when PYSPARK_PYTHON is unset, python is executed though IPYTHON is "1". In other word, when PYSPARK_PYTHON is unset, IPYTHON_OPS and IPYTHON has no effect on decide which command to use. 
PYSPARK_PYTHON | IPYTHON_OPTS | IPYTHON | resulting command | expected command ---- | ---- | ----- | ----- | ----- (unset → defaults to python) | (unset) | (unset) | python | (same) (unset → defaults to python) | (unset) | 1 | python | ipython (unset → defaults to python) | an_option | (unset → set to 1) | python an_option | ipython an_option (unset → defaults to python) | an_option | 1 | python an_option | ipython an_option ipython | (unset) | (unset) | ipython | (same) ipython | (unset) | 1 | ipython | (same) ipython | an_option | (unset → set to 1) | ipython an_option | (same) ipython | an_option | 1 | ipython an_option | (same) ### Suggestion The pyspark script should determine firstly whether a user wants to run IPython or other executables. 1. if IPYTHON_OPTS set * set IPYTHON "1" 2. if IPYTHON has a value "1" * PYSPARK_PYTHON defaults to "ipython" if not set 3. PYSPARK_PYTHON defaults to "python" if not set See the pull request for more detailed modification. Author: cocoatomo Closes #2554 from cocoatomo/issues/cannot-run-ipython-without-options and squashes the following commits: d2a9b06 [cocoatomo] [SPARK-3706][PySpark] Use PYTHONUNBUFFERED environment variable instead of -u option 264114c [cocoatomo] [SPARK-3706][PySpark] Remove the sentence about deprecated environment variables 42e02d5 [cocoatomo] [SPARK-3706][PySpark] Replace environment variables used to customize execution of PySpark REPL 10d56fb [cocoatomo] [SPARK-3706][PySpark] Cannot run IPython REPL with IPYTHON set to "1" and PYSPARK_PYTHON unset --- bin/pyspark | 24 +++++++++---------- .../apache/spark/deploy/PythonRunner.scala | 3 ++- docs/programming-guide.md | 8 +++---- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 5142411e36974..6655725ef8e8e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -52,10 +52,20 @@ fi # Figure out which Python executable to use if [[ -z "$PYSPARK_PYTHON" ]]; then - PYSPARK_PYTHON="python" + if [[ "$IPYTHON" = "1" || -n 
"$IPYTHON_OPTS" ]]; then + # for backward compatibility + PYSPARK_PYTHON="ipython" + else + PYSPARK_PYTHON="python" + fi fi export PYSPARK_PYTHON +if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then + # for backward compatibility + PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS" +fi + # Add the PySpark classes to the Python path: export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" @@ -64,11 +74,6 @@ export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py" -# If IPython options are specified, assume user wants to run IPython -if [[ -n "$IPYTHON_OPTS" ]]; then - IPYTHON=1 -fi - # Build up arguments list manually to preserve quotes and backslashes. # We export Spark submit arguments as an environment variable because shell.py must run as a # PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks. 
@@ -106,10 +111,5 @@ if [[ "$1" =~ \.py$ ]]; then else # PySpark shell requires special handling downstream export PYSPARK_SHELL=1 - # Only use ipython if no command line arguments were provided [SPARK-1134] - if [[ "$IPYTHON" = "1" ]]; then - exec ${PYSPARK_PYTHON:-ipython} $IPYTHON_OPTS - else - exec "$PYSPARK_PYTHON" - fi + exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS fi diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index b66c3ba4d5fb0..79b4d7ea41a33 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -54,9 +54,10 @@ object PythonRunner { val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*) // Launch Python process - val builder = new ProcessBuilder(Seq(pythonExec, "-u", formattedPythonFile) ++ otherArgs) + val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) val env = builder.environment() env.put("PYTHONPATH", pythonPath) + env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 1d61a3c555eaf..8e8cc1dd983f8 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -211,17 +211,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes, It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. 
To -use IPython, set the `IPYTHON` variable to `1` when running `bin/pyspark`: +use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`: {% highlight bash %} -$ IPYTHON=1 ./bin/pyspark +$ PYSPARK_PYTHON=ipython ./bin/pyspark {% endhighlight %} -You can customize the `ipython` command by setting `IPYTHON_OPTS`. For example, to launch +You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ IPYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark {% endhighlight %} From 82a6a083a485140858bcd93d73adec59bb5cca64 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 2 Oct 2014 11:37:24 -0700 Subject: [PATCH 14/36] [SQL][Docs] Update the output of printSchema and fix a typo in SQL programming guide. We have changed the output format of `printSchema`. This PR will update our SQL programming guide to show the updated format. Also, it fixes a typo (the value type of `StructType` in Java API). Author: Yin Huai Closes #2630 from yhuai/sqlDoc and squashes the following commits: 267d63e [Yin Huai] Update the output of printSchema and fix a typo. --- docs/sql-programming-guide.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 818fd5ab80af8..368c3d0008b07 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -620,8 +620,8 @@ val people = sqlContext.jsonFile(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() // root -// |-- age: IntegerType -// |-- name: StringType +// |-- age: integer (nullable = true) +// |-- name: string (nullable = true) // Register this SchemaRDD as a table. 
people.registerTempTable("people") @@ -658,8 +658,8 @@ JavaSchemaRDD people = sqlContext.jsonFile(path); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); // root -// |-- age: IntegerType -// |-- name: StringType +// |-- age: integer (nullable = true) +// |-- name: string (nullable = true) // Register this JavaSchemaRDD as a table. people.registerTempTable("people"); @@ -697,8 +697,8 @@ people = sqlContext.jsonFile(path) # The inferred schema can be visualized using the printSchema() method. people.printSchema() # root -# |-- age: IntegerType -# |-- name: StringType +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) # Register this SchemaRDD as a table. people.registerTempTable("people") @@ -1394,7 +1394,7 @@ please use factory methods provided in StructType - org.apache.spark.sql.api.java + org.apache.spark.sql.api.java.Row DataType.createStructType(fields)
Note: fields is a List or an array of StructFields. From b4fb7b80a0d863500943d788ad3e34d502a6dafa Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Thu, 2 Oct 2014 13:48:35 -0500 Subject: [PATCH 15/36] Modify default YARN memory_overhead-- from an additive constant to a multiplier Redone against the recent master branch (https://github.com/apache/spark/pull/1391) Author: Nishkam Ravi Author: nravi Author: nishkamravi2 Closes #2485 from nishkamravi2/master_nravi and squashes the following commits: 636a9ff [nishkamravi2] Update YarnAllocator.scala 8f76c8b [Nishkam Ravi] Doc change for yarn memory overhead 35daa64 [Nishkam Ravi] Slight change in the doc for yarn memory overhead 5ac2ec1 [Nishkam Ravi] Remove out dac1047 [Nishkam Ravi] Additional documentation for yarn memory overhead issue 42c2c3d [Nishkam Ravi] Additional changes for yarn memory overhead issue 362da5e [Nishkam Ravi] Additional changes for yarn memory overhead c726bd9 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi f00fa31 [Nishkam Ravi] Improving logging for AM memoryOverhead 1cf2d1e [nishkamravi2] Update YarnAllocator.scala ebcde10 [Nishkam Ravi] Modify default YARN memory_overhead-- from an additive constant to a multiplier (redone to resolve merge conflicts) 2e69f11 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi efd688a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the 
Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- docs/running-on-yarn.md | 8 ++++---- .../spark/deploy/yarn/ClientArguments.scala | 16 +++++++++------- .../apache/spark/deploy/yarn/ClientBase.scala | 12 ++++++++---- .../apache/spark/deploy/yarn/YarnAllocator.scala | 16 ++++++++-------- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 8 ++++++-- 5 files changed, 35 insertions(+), 25 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4b3a49eca7007..695813a2ba881 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -79,16 +79,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.executor.memoryOverhead - 384 + executorMemory * 0.07, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). spark.yarn.driver.memoryOverhead - 384 + driverMemory * 0.07, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. + The amount of off heap memory (in megabytes) to be allocated per driver. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). 
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 26dbd6237c6b8..a12f82d2fbe70 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf import org.apache.spark.util.{Utils, IntParam, MemoryParam} - +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { @@ -39,15 +39,17 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var appName: String = "Spark" var priority = 0 + parseArgs(args.toList) + loadEnvironmentArgs() + // Additional memory to allocate to containers // For now, use driver's memory overhead as our AM container's memory overhead - val amMemoryOverhead = sparkConf.getInt( - "spark.yarn.driver.memoryOverhead", YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - val executorMemoryOverhead = sparkConf.getInt( - "spark.yarn.executor.memoryOverhead", YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) + val amMemoryOverhead = sparkConf.getInt("spark.yarn.driver.memoryOverhead", + math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toInt, MEMORY_OVERHEAD_MIN)) + + val executorMemoryOverhead = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) - parseArgs(args.toList) - loadEnvironmentArgs() validateArgs() /** Load any default arguments provided through environment variables and Spark properties. 
*/ diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 1cf19c198509c..6ecac6eae6e03 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -64,14 +64,18 @@ private[spark] trait ClientBase extends Logging { s"memory capability of the cluster ($maxMem MB per container)") val executorMem = args.executorMemory + executorMemoryOverhead if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory ($executorMem MB) " + - s"is above the max threshold ($maxMem MB) of this cluster!") + throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { - throw new IllegalArgumentException(s"Required AM memory ($amMem MB) " + - s"is above the max threshold ($maxMem MB) of this cluster!") + throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster!") } + logInfo("Will allocate AM container, with %d MB memory including %d MB overhead".format( + amMem, + amMemoryOverhead)) + // We could add checks to make sure the entire cluster has enough resources but that involves // getting all the node reports and computing ourselves. 
} diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 299e38a5eb9c0..4f4f1d2aaaade 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -33,6 +33,7 @@ import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ object AllocationType extends Enumeration { type AllocationType = Value @@ -78,10 +79,6 @@ private[yarn] abstract class YarnAllocator( // Containers to be released in next request to RM private val releasedContainers = new ConcurrentHashMap[ContainerId, Boolean] - // Additional memory overhead - in mb. - protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - // Number of container requests that have been sent to, but not yet allocated by the // ApplicationMaster. private val numPendingAllocate = new AtomicInteger() @@ -97,6 +94,10 @@ private[yarn] abstract class YarnAllocator( protected val (preferredHostToCount, preferredRackToCount) = generateNodeToWeight(conf, preferredNodes) + // Additional memory overhead - in mb. 
+ protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) + private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue sparkConf.getInt("spark.yarn.containerLauncherMaxThreads", 25), Integer.MAX_VALUE, @@ -114,12 +115,11 @@ private[yarn] abstract class YarnAllocator( // this is needed by alpha, do it here since we add numPending right after this val executorsPending = numPendingAllocate.get() - if (missing > 0) { + val totalExecutorMemory = executorMemory + memoryOverhead numPendingAllocate.addAndGet(missing) - logInfo("Will Allocate %d executor containers, each with %d memory".format( - missing, - (executorMemory + memoryOverhead))) + logInfo(s"Will allocate $missing executor containers, each with $totalExecutorMemory MB " + + s"memory including $memoryOverhead MB overhead") } else { logDebug("Empty allocation request ...") } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 0b712c201904a..e1e0144f46fe9 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -84,8 +84,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } object YarnSparkHadoopUtil { - // Additional memory overhead - in mb. - val DEFAULT_MEMORY_OVERHEAD = 384 + // Additional memory overhead + // 7% was arrived at experimentally. In the interest of minimizing memory waste while covering + // the common cases. Memory overhead tends to grow with container size. 
+ + val MEMORY_OVERHEAD_FACTOR = 0.07 + val MEMORY_OVERHEAD_MIN = 384 val ANY_HOST = "*" From c6469a02f14e8c23e9b4e1336768f8bbfc15f5d8 Mon Sep 17 00:00:00 2001 From: scwf Date: Thu, 2 Oct 2014 13:47:30 -0700 Subject: [PATCH 16/36] [SPARK-3766][Doc]Snappy is also the default compress codec for broadcast variables Author: scwf Closes #2632 from scwf/compress-doc and squashes the following commits: 7983a1a [scwf] snappy is the default compression codec for broadcast --- docs/configuration.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 791b6f2aa3261..316490f0f43fc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -413,10 +413,11 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec snappy - The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, - Spark provides three codecs: lz4, lzf, and snappy. You - can also use fully qualified class names to specify the codec, e.g. - org.apache.spark.io.LZ4CompressionCodec, + The codec used to compress internal data such as RDD partitions, broadcast variables and + shuffle outputs. By default, Spark provides three codecs: lz4, lzf, + and snappy. You can also use fully qualified class names to specify the codec, + e.g. + org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, and org.apache.spark.io.SnappyCompressionCodec. From 5db78e6b87d33ac2d48a997e69b46e9be3b63137 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 2 Oct 2014 13:49:47 -0700 Subject: [PATCH 17/36] [SPARK-3495] Block replication fails continuously when the replication target node is dead AND [SPARK-3496] Block replication by mistake chooses driver as target If a block manager (say, A) wants to replicate a block and the node chosen for replication (say, B) is dead, then the attempt to send the block to B fails. 
However, this continues to fail indefinitely. Even if the driver learns about the demise of the B, A continues to try replicating to B and failing miserably. The reason behind this bug is that A initially fetches a list of peers from the driver (when B was active), but never updates it after B is dead. This affects Spark Streaming as its receiver uses block replication. The solution in this patch adds the following. - Changed BlockManagerMaster to return all the peers of a block manager, rather than the requested number. It also filters out driver BlockManager. - Refactored BlockManager's replication code to handle peer caching correctly. + The peer for replication is randomly selected. This is different from past behavior where for a node A, a node B was deterministically chosen for the lifetime of the application. + If replication fails to one node, the peers are refetched. + The peer cached has a TTL of 1 second to enable discovery of new peers and using them for replication. - Refactored use of \ in BlockManager into a new method `BlockManagerId.isDriver` - Added replication unit tests (replication was not tested till now, duh!) This should not make a difference in performance of Spark workloads where replication is not used. @andrewor14 @JoshRosen Author: Tathagata Das Closes #2366 from tdas/replication-fix and squashes the following commits: 9690f57 [Tathagata Das] Moved replication tests to a new BlockManagerReplicationSuite. 0661773 [Tathagata Das] Minor changes based on PR comments. a55a65c [Tathagata Das] Added a unit test to test replication behavior. 012afa3 [Tathagata Das] Bug fix 89f91a0 [Tathagata Das] Minor change. 68e2c72 [Tathagata Das] Made replication peer selection logic more efficient. 08afaa9 [Tathagata Das] Made peer selection for replication deterministic to block id 3821ab9 [Tathagata Das] Fixes based on PR comments. 08e5646 [Tathagata Das] More minor changes. d402506 [Tathagata Das] Fixed imports. 
4a20531 [Tathagata Das] Filtered driver block manager from peer list, and also consolidated the use of in BlockManager. 7598f91 [Tathagata Das] Minor changes. 03de02d [Tathagata Das] Change replication logic to correctly refetch peers from master on failure and on new worker addition. d081bf6 [Tathagata Das] Fixed bug in get peers and unit tests to test get-peers and replication under executor churn. 9f0ac9f [Tathagata Das] Modified replication tests to fail on replication bug. af0c1da [Tathagata Das] Added replication unit tests to BlockManagerSuite --- .../apache/spark/storage/BlockManager.scala | 122 ++++- .../apache/spark/storage/BlockManagerId.scala | 2 + .../spark/storage/BlockManagerMaster.scala | 9 +- .../storage/BlockManagerMasterActor.scala | 29 +- .../spark/storage/BlockManagerMessages.scala | 2 +- .../spark/broadcast/BroadcastSuite.scala | 2 +- .../BlockManagerReplicationSuite.scala | 418 ++++++++++++++++++ .../spark/storage/BlockManagerSuite.scala | 9 +- 8 files changed, 544 insertions(+), 49 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d1bee3d2c033c..3f5d06e1aeee7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -22,6 +22,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.concurrent.ExecutionContext.Implicits.global +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ @@ -112,6 +113,11 @@ private[spark] class BlockManager( private val broadcastCleaner = new MetadataCleaner( MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf) + // Field related to peer block managers that are necessary for block 
replication + @volatile private var cachedPeers: Seq[BlockManagerId] = _ + private val peerFetchLock = new Object + private var lastPeerFetchTime = 0L + initialize() /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay @@ -787,31 +793,111 @@ private[spark] class BlockManager( } /** - * Replicate block to another node. + * Get peer block managers in the system. + */ + private def getPeers(forceFetch: Boolean): Seq[BlockManagerId] = { + peerFetchLock.synchronized { + val cachedPeersTtl = conf.getInt("spark.storage.cachedPeersTtl", 60 * 1000) // milliseconds + val timeout = System.currentTimeMillis - lastPeerFetchTime > cachedPeersTtl + if (cachedPeers == null || forceFetch || timeout) { + cachedPeers = master.getPeers(blockManagerId).sortBy(_.hashCode) + lastPeerFetchTime = System.currentTimeMillis + logDebug("Fetched peers from master: " + cachedPeers.mkString("[", ",", "]")) + } + cachedPeers + } + } + + /** + * Replicate block to another node. Not that this is a blocking call that returns after + * the block has been replicated. 
*/ - @volatile var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel): Unit = { + val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) + val numPeersToReplicateTo = level.replication - 1 + val peersForReplication = new ArrayBuffer[BlockManagerId] + val peersReplicatedTo = new ArrayBuffer[BlockManagerId] + val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId] val tLevel = StorageLevel( level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1) - if (cachedPeers == null) { - cachedPeers = master.getPeers(blockManagerId, level.replication - 1) + val startTime = System.currentTimeMillis + val random = new Random(blockId.hashCode) + + var replicationFailed = false + var failures = 0 + var done = false + + // Get cached list of peers + peersForReplication ++= getPeers(forceFetch = false) + + // Get a random peer. Note that this selection of a peer is deterministic on the block id. + // So assuming the list of peers does not change and no replication failures, + // if there are multiple attempts in the same node to replicate the same block, + // the same set of peers will be selected. + def getRandomPeer(): Option[BlockManagerId] = { + // If replication had failed, then force update the cached list of peers and remove the peers + // that have been already used + if (replicationFailed) { + peersForReplication.clear() + peersForReplication ++= getPeers(forceFetch = true) + peersForReplication --= peersReplicatedTo + peersForReplication --= peersFailedToReplicateTo + } + if (!peersForReplication.isEmpty) { + Some(peersForReplication(random.nextInt(peersForReplication.size))) + } else { + None + } } - for (peer: BlockManagerId <- cachedPeers) { - val start = System.nanoTime - data.rewind() - logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. 
" + - s"To node: $peer") - try { - blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) - } catch { - case e: Exception => - logError(s"Failed to replicate block to $peer", e) + // One by one choose a random peer and try uploading the block to it + // If replication fails (e.g., target peer is down), force the list of cached peers + // to be re-fetched from driver and then pick another random peer for replication. Also + // temporarily black list the peer for which replication failed. + // + // This selection of a peer and replication is continued in a loop until one of the + // following 3 conditions is fulfilled: + // (i) specified number of peers have been replicated to + // (ii) too many failures in replicating to peers + // (iii) no peer left to replicate to + // + while (!done) { + getRandomPeer() match { + case Some(peer) => + try { + val onePeerStartTime = System.currentTimeMillis + data.rewind() + logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") + blockTransferService.uploadBlockSync( + peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) + logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms" + .format((System.currentTimeMillis - onePeerStartTime))) + peersReplicatedTo += peer + peersForReplication -= peer + replicationFailed = false + if (peersReplicatedTo.size == numPeersToReplicateTo) { + done = true // specified number of peers have been replicated to + } + } catch { + case e: Exception => + logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e) + failures += 1 + replicationFailed = true + peersFailedToReplicateTo += peer + if (failures > maxReplicationFailures) { // too many failures in replcating to peers + done = true + } + } + case None => // no peer left to replicate to + done = true } - - logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes." 
- .format(blockId, (System.nanoTime - start) / 1e6, data.limit())) + } + val timeTakeMs = (System.currentTimeMillis - startTime) + logDebug(s"Replicating $blockId of ${data.limit()} bytes to " + + s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms") + if (peersReplicatedTo.size < numPeersToReplicateTo) { + logWarning(s"Block $blockId replicated to only " + + s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index d4487fce49ab6..142285094342c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -59,6 +59,8 @@ class BlockManagerId private ( def port: Int = port_ + def isDriver: Boolean = (executorId == "") + override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) out.writeUTF(host_) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 2e262594b3538..d08e1419e3e41 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -84,13 +84,8 @@ class BlockManagerMaster( } /** Get ids of other nodes in the cluster from the driver */ - def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { - val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) - if (result.length != numPeers) { - throw new SparkException( - "Error getting peers, only got " + result.size + " instead of " + numPeers) - } - result + def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { + askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } /** diff --git 
a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 1a6c7cb24f9ac..6a06257ed0c08 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -83,8 +83,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetLocationsMultipleBlockIds(blockIds) => sender ! getLocationsMultipleBlockIds(blockIds) - case GetPeers(blockManagerId, size) => - sender ! getPeers(blockManagerId, size) + case GetPeers(blockManagerId) => + sender ! getPeers(blockManagerId) case GetMemoryStatus => sender ! memoryStatus @@ -173,11 +173,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - // TODO: Consolidate usages of import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => - removeFromDriver || info.blockManagerId.executorId != "" + removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => @@ -212,7 +211,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val minSeenTime = now - slaveTimeout val toRemove = new mutable.HashSet[BlockManagerId] for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime && info.blockManagerId.executorId != "") { + if (info.lastSeenMs < minSeenTime && !info.blockManagerId.isDriver) { logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") toRemove += info.blockManagerId @@ -232,7 +231,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, 
conf: SparkConf, listenerBus */ private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = { if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerId.executorId == "" && !isLocal + blockManagerId.isDriver && !isLocal } else { blockManagerInfo(blockManagerId).updateLastSeenMs() true @@ -355,7 +354,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus tachyonSize: Long) { if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.executorId == "" && !isLocal) { + if (blockManagerId.isDriver && !isLocal) { // We intentionally do not register the master (except in local mode), // so we should not indicate failure. sender ! true @@ -403,16 +402,14 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockIds.map(blockId => getLocations(blockId)) } - private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = { - val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - - val selfIndex = peers.indexOf(blockManagerId) - if (selfIndex == -1) { - throw new SparkException("Self index for " + blockManagerId + " not found") + /** Get the list of the peers of the given block manager */ + private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { + val blockManagerIds = blockManagerInfo.keySet + if (blockManagerIds.contains(blockManagerId)) { + blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq + } else { + Seq.empty } - - // Note that this logic will select the same node multiple times if there aren't enough peers - Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 2ba16b8476600..3db5dd9774ae8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -88,7 +88,7 @@ private[spark] object BlockManagerMessages { case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster - case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster + case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster case class RemoveExecutor(execId: String) extends ToBlockManagerMaster diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 978a6ded80829..acaf321de52fb 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -132,7 +132,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { val statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === 1) statuses.head match { case (bm, status) => - assert(bm.executorId === "", "Block should only be on the driver") + assert(bm.isDriver, "Block should only be on the driver") assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) assert(status.memSize > 0, "Block should be in memory store on the driver") assert(status.diskSize === 0, "Block should not be in disk store on the driver") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala new file mode 100644 index 0000000000000..1f1d53a1ee3b0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + +import akka.actor.{ActorSystem, Props} +import org.mockito.Mockito.{mock, when} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.util.{AkkaUtils, SizeEstimator} + +/** Testsuite that tests block replication in BlockManager */ +class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { + + private val conf = new SparkConf(false) + var actorSystem: ActorSystem = null + var master: BlockManagerMaster = null + val securityMgr = new SecurityManager(conf) + val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) + + // List of block manager created during an unit test, so that all of the them can be stopped + // after the 
unit test. + val allStores = new ArrayBuffer[BlockManager] + + // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + conf.set("spark.kryoserializer.buffer.mb", "1") + val serializer = new KryoSerializer(conf) + + // Implicitly convert strings to BlockIds for test clarity. + implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + + private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { + val transfer = new NioBlockTransferService(conf, securityMgr) + val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer) + allStores += store + store + } + + before { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + "test", "localhost", 0, conf = conf, securityManager = securityMgr) + this.actorSystem = actorSystem + + conf.set("spark.authenticate", "false") + conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.storage.unrollFraction", "0.4") + conf.set("spark.storage.unrollMemoryThreshold", "512") + + // to make a replication attempt to inactive store fail fast + conf.set("spark.core.connection.ack.wait.timeout", "1") + // to make cached peers refresh frequently + conf.set("spark.storage.cachedPeersTtl", "10") + + master = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + conf, true) + allStores.clear() + } + + after { + allStores.foreach { _.stop() } + allStores.clear() + actorSystem.shutdown() + actorSystem.awaitTermination() + actorSystem = null + master = null + } + + + test("get peers with addition and removal of block managers") { + val numStores = 4 + val stores = (1 to numStores - 1).map { i => makeBlockManager(1000, s"store$i") } + val storeIds = stores.map { _.blockManagerId }.toSet + assert(master.getPeers(stores(0).blockManagerId).toSet === + storeIds.filterNot { _ == stores(0).blockManagerId }) + 
assert(master.getPeers(stores(1).blockManagerId).toSet === + storeIds.filterNot { _ == stores(1).blockManagerId }) + assert(master.getPeers(stores(2).blockManagerId).toSet === + storeIds.filterNot { _ == stores(2).blockManagerId }) + + // Add driver store and test whether it is filtered out + val driverStore = makeBlockManager(1000, "") + assert(master.getPeers(stores(0).blockManagerId).forall(!_.isDriver)) + assert(master.getPeers(stores(1).blockManagerId).forall(!_.isDriver)) + assert(master.getPeers(stores(2).blockManagerId).forall(!_.isDriver)) + + // Add a new store and test whether get peers returns it + val newStore = makeBlockManager(1000, s"store$numStores") + assert(master.getPeers(stores(0).blockManagerId).toSet === + storeIds.filterNot { _ == stores(0).blockManagerId } + newStore.blockManagerId) + assert(master.getPeers(stores(1).blockManagerId).toSet === + storeIds.filterNot { _ == stores(1).blockManagerId } + newStore.blockManagerId) + assert(master.getPeers(stores(2).blockManagerId).toSet === + storeIds.filterNot { _ == stores(2).blockManagerId } + newStore.blockManagerId) + assert(master.getPeers(newStore.blockManagerId).toSet === storeIds) + + // Remove a store and test whether get peers returns it + val storeIdToRemove = stores(0).blockManagerId + master.removeExecutor(storeIdToRemove.executorId) + assert(!master.getPeers(stores(1).blockManagerId).contains(storeIdToRemove)) + assert(!master.getPeers(stores(2).blockManagerId).contains(storeIdToRemove)) + assert(!master.getPeers(newStore.blockManagerId).contains(storeIdToRemove)) + + // Test whether asking for peers of a unregistered block manager id returns empty list + assert(master.getPeers(stores(0).blockManagerId).isEmpty) + assert(master.getPeers(BlockManagerId("", "", 1)).isEmpty) + } + + + test("block replication - 2x replication") { + testReplication(2, + Seq(MEMORY_ONLY, MEMORY_ONLY_SER, DISK_ONLY, MEMORY_AND_DISK_2, MEMORY_AND_DISK_SER_2) + ) + } + + test("block replication - 3x 
replication") { + // Generate storage levels with 3x replication + val storageLevels = { + Seq(MEMORY_ONLY, MEMORY_ONLY_SER, DISK_ONLY, MEMORY_AND_DISK, MEMORY_AND_DISK_SER).map { + level => StorageLevel( + level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 3) + } + } + testReplication(3, storageLevels) + } + + test("block replication - mixed between 1x to 5x") { + // Generate storage levels with varying replication + val storageLevels = Seq( + MEMORY_ONLY, + MEMORY_ONLY_SER_2, + StorageLevel(true, false, false, false, 3), + StorageLevel(true, true, false, true, 4), + StorageLevel(true, true, false, false, 5), + StorageLevel(true, true, false, true, 4), + StorageLevel(true, false, false, false, 3), + MEMORY_ONLY_SER_2, + MEMORY_ONLY + ) + testReplication(5, storageLevels) + } + + test("block replication - 2x replication without peers") { + intercept[org.scalatest.exceptions.TestFailedException] { + testReplication(1, + Seq(StorageLevel.MEMORY_AND_DISK_2, StorageLevel(true, false, false, false, 3))) + } + } + + test("block replication - deterministic node selection") { + val blockSize = 1000 + val storeSize = 10000 + val stores = (1 to 5).map { + i => makeBlockManager(storeSize, s"store$i") + } + val storageLevel2x = StorageLevel.MEMORY_AND_DISK_2 + val storageLevel3x = StorageLevel(true, true, false, true, 3) + val storageLevel4x = StorageLevel(true, true, false, true, 4) + + def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = { + stores.head.putSingle(blockId, new Array[Byte](blockSize), level) + val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet + stores.foreach { _.removeBlock(blockId) } + master.removeBlock(blockId) + locations + } + + // Test if two attempts to 2x replication returns same set of locations + val a1Locs = putBlockAndGetLocations("a1", storageLevel2x) + assert(putBlockAndGetLocations("a1", storageLevel2x) === a1Locs, + "Inserting a 2x replicated block second time gave 
different locations from the first") + + // Test if two attempts to 3x replication returns same set of locations + val a2Locs3x = putBlockAndGetLocations("a2", storageLevel3x) + assert(putBlockAndGetLocations("a2", storageLevel3x) === a2Locs3x, + "Inserting a 3x replicated block second time gave different locations from the first") + + // Test if 2x replication of a2 returns a strict subset of the locations of 3x replication + val a2Locs2x = putBlockAndGetLocations("a2", storageLevel2x) + assert( + a2Locs2x.subsetOf(a2Locs3x), + "Inserting a with 2x replication gave locations that are not a subset of locations" + + s" with 3x replication [3x: ${a2Locs3x.mkString(",")}; 2x: ${a2Locs2x.mkString(",")}" + ) + + // Test if 4x replication of a2 returns a strict superset of the locations of 3x replication + val a2Locs4x = putBlockAndGetLocations("a2", storageLevel4x) + assert( + a2Locs3x.subsetOf(a2Locs4x), + "Inserting a with 4x replication gave locations that are not a superset of locations " + + s"with 3x replication [3x: ${a2Locs3x.mkString(",")}; 4x: ${a2Locs4x.mkString(",")}" + ) + + // Test if 3x replication of two different blocks gives two different sets of locations + val a3Locs3x = putBlockAndGetLocations("a3", storageLevel3x) + assert(a3Locs3x !== a2Locs3x, "Two blocks gave same locations with 3x replication") + } + + test("block replication - replication failures") { + /* + Create a system of three block managers / stores. One of them (say, failableStore) + cannot receive blocks. So attempts to use that as replication target fails. + + +-----------/fails/-----------> failableStore + | + normalStore + | + +-----------/works/-----------> anotherNormalStore + + We are first going to add a normal block manager (i.e. normalStore) and the failable block + manager (i.e. failableStore), and test whether 2x replication fails to create two + copies of a block. 
Then we are going to add another normal block manager + (i.e., anotherNormalStore), and test that now 2x replication works as the + new store will be used for replication. + */ + + // Add a normal block manager + val store = makeBlockManager(10000, "store") + + // Insert a block with 2x replication and return the number of copies of the block + def replicateAndGetNumCopies(blockId: String): Int = { + store.putSingle(blockId, new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK_2) + val numLocations = master.getLocations(blockId).size + allStores.foreach { _.removeBlock(blockId) } + numLocations + } + + // Add a failable block manager with a mock transfer service that does not + // allow receiving of blocks. So attempts to use it as a replication target will fail. + val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work + when(failableTransfer.hostName).thenReturn("some-hostname") + when(failableTransfer.port).thenReturn(1000) + val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + 10000, conf, mapOutputTracker, shuffleManager, failableTransfer) + allStores += failableStore // so that this gets stopped after test + assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) + + // Test that 2x replication fails by creating only one copy of the block + assert(replicateAndGetNumCopies("a1") === 1) + + // Add another normal block manager and test that 2x replication works + makeBlockManager(10000, "anotherStore") + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + assert(replicateAndGetNumCopies("a2") === 2) + } + } + + test("block replication - addition and deletion of block managers") { + val blockSize = 1000 + val storeSize = 10000 + val initialStores = (1 to 2).map { i => makeBlockManager(storeSize, s"store$i") } + + // Insert a block with given replication factor and return the number of copies of the block\ + def replicateAndGetNumCopies(blockId: 
String, replicationFactor: Int): Int = { + val storageLevel = StorageLevel(true, true, false, true, replicationFactor) + initialStores.head.putSingle(blockId, new Array[Byte](blockSize), storageLevel) + val numLocations = master.getLocations(blockId).size + allStores.foreach { _.removeBlock(blockId) } + numLocations + } + + // 2x replication should work, 3x replication should only replicate 2x + assert(replicateAndGetNumCopies("a1", 2) === 2) + assert(replicateAndGetNumCopies("a2", 3) === 2) + + // Add another store, 3x replication should work now, 4x replication should only replicate 3x + val newStore1 = makeBlockManager(storeSize, s"newstore1") + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + assert(replicateAndGetNumCopies("a3", 3) === 3) + } + assert(replicateAndGetNumCopies("a4", 4) === 3) + + // Add another store, 4x replication should work now + val newStore2 = makeBlockManager(storeSize, s"newstore2") + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + assert(replicateAndGetNumCopies("a5", 4) === 4) + } + + // Remove all but the 1st store, 2x replication should fail + (initialStores.tail ++ Seq(newStore1, newStore2)).foreach { + store => + master.removeExecutor(store.blockManagerId.executorId) + store.stop() + } + assert(replicateAndGetNumCopies("a6", 2) === 1) + + // Add new stores, 3x replication should work + val newStores = (3 to 5).map { + i => makeBlockManager(storeSize, s"newstore$i") + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + assert(replicateAndGetNumCopies("a7", 3) === 3) + } + } + + /** + * Test replication of blocks with different storage levels (various combinations of + * memory, disk & serialization). For each storage level, this function tests every store + * whether the block is present and also tests the master whether its knowledge of blocks + * is correct. 
Then it also drops the block from memory of each store (using LRU) and + * again checks whether the master's knowledge gets updated. + */ + private def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) { + import org.apache.spark.storage.StorageLevel._ + + assert(maxReplication > 1, + s"Cannot test replication factor $maxReplication") + + // storage levels to test with the given replication factor + + val storeSize = 10000 + val blockSize = 1000 + + // As many stores as the replication factor + val stores = (1 to maxReplication).map { + i => makeBlockManager(storeSize, s"store$i") + } + + storageLevels.foreach { storageLevel => + // Put the block into one of the stores + val blockId = new TestBlockId( + "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) + stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel) + + // Assert that master know two locations for the block + val blockLocations = master.getLocations(blockId).map(_.executorId).toSet + assert(blockLocations.size === storageLevel.replication, + s"master did not have ${storageLevel.replication} locations for $blockId") + + // Test state of the stores that contain the block + stores.filter { + testStore => blockLocations.contains(testStore.blockManagerId.executorId) + }.foreach { testStore => + val testStoreName = testStore.blockManagerId.executorId + assert(testStore.getLocal(blockId).isDefined, s"$blockId was not found in $testStoreName") + assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), + s"master does not have status for ${blockId.name} in $testStoreName") + + val blockStatus = master.getBlockStatus(blockId)(testStore.blockManagerId) + + // Assert that block status in the master for this store has expected storage level + assert( + blockStatus.storageLevel.useDisk === storageLevel.useDisk && + blockStatus.storageLevel.useMemory === storageLevel.useMemory && + blockStatus.storageLevel.useOffHeap === 
storageLevel.useOffHeap && + blockStatus.storageLevel.deserialized === storageLevel.deserialized, + s"master does not know correct storage level for ${blockId.name} in $testStoreName") + + // Assert that the block status in the master for this store has correct memory usage info + assert(!blockStatus.storageLevel.useMemory || blockStatus.memSize >= blockSize, + s"master does not know size of ${blockId.name} stored in memory of $testStoreName") + + + // If the block is supposed to be in memory, then drop the copy of the block in + // this store test whether master is updated with zero memory usage this store + if (storageLevel.useMemory) { + // Force the block to be dropped by adding a number of dummy blocks + (1 to 10).foreach { + i => + testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER) + } + (1 to 10).foreach { + i => testStore.removeBlock(s"dummy-block-$i") + } + + val newBlockStatusOption = master.getBlockStatus(blockId).get(testStore.blockManagerId) + + // Assert that the block status in the master either does not exist (block removed + // from every store) or has zero memory usage for this store + assert( + newBlockStatusOption.isEmpty || newBlockStatusOption.get.memSize === 0, + s"after dropping, master does not know size of ${blockId.name} " + + s"stored in memory of $testStoreName" + ) + } + + // If the block is supposed to be in disk (after dropping or otherwise, then + // test whether master has correct disk usage for this store + if (storageLevel.useDisk) { + assert(master.getBlockStatus(blockId)(testStore.blockManagerId).diskSize >= blockSize, + s"after dropping, master does not know size of ${blockId.name} " + + s"stored in disk of $testStoreName" + ) + } + } + master.removeBlock(blockId) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index e251660dae5de..9d96202a3e7ac 100644 --- 
a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,8 +21,6 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import java.util.concurrent.TimeUnit -import org.apache.spark.network.nio.NioBlockTransferService - import scala.collection.mutable.ArrayBuffer import scala.concurrent.Await import scala.concurrent.duration._ @@ -35,13 +33,13 @@ import akka.util.Timeout import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod +import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -189,7 +187,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter store = makeBlockManager(2000, "exec1") store2 = makeBlockManager(2000, "exec2") - val peers = master.getPeers(store.blockManagerId, 1) + val peers = master.getPeers(store.blockManagerId) assert(peers.size === 1, "master did not return the other manager as a peer") assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager") @@ -448,7 +446,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter val list2DiskGet = store.get("list2disk") assert(list2DiskGet.isDefined, "list2memory expected to be in store") assert(list2DiskGet.get.data.size === 3) - System.out.println(list2DiskGet) // We don't know the exact size of the data on disk, but it should certainly be > 0. 
assert(list2DiskGet.get.inputMetrics.bytesRead > 0) assert(list2DiskGet.get.inputMetrics.readMethod === DataReadMethod.Disk) From 127e97bee1e6aae7b70263bc5944b7be6f4e6fea Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Thu, 2 Oct 2014 13:52:54 -0700 Subject: [PATCH 18/36] [SPARK-3632] ConnectionManager can run out of receive threads with authentication on If you turn authentication on and you are using a lot of executors. There is a chance that all the of the threads in the handleMessageExecutor could be waiting to send a message because they are blocked waiting on authentication to happen. This can cause a temporary deadlock until the connection times out. To fix it, I got rid of the wait/notify and use a single outbox but only send security messages from it until authentication has completed. Author: Thomas Graves Closes #2484 from tgravescs/cm_threads_auth and squashes the following commits: a0a961d [Thomas Graves] give it a type b6bc80b [Thomas Graves] Rework comments d6d4175 [Thomas Graves] update from comments 081b765 [Thomas Graves] cleanup 4d7f8f5 [Thomas Graves] Change to not use wait/notify while waiting for authentication --- .../org/apache/spark/SecurityManager.scala | 5 +- .../apache/spark/network/nio/Connection.scala | 65 +++++++++++------ .../spark/network/nio/ConnectionManager.scala | 72 +++++-------------- 3 files changed, 63 insertions(+), 79 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 3832a780ec4bc..0e0f1a7b2377e 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil * and a Server, so for a particular connection is has to determine what to do. * A ConnectionId was added to be able to track connections and is used to * match up incoming messages with connections waiting for authentication. 
- * If its acting as a client and trying to send a message to another ConnectionManager, - * it blocks the thread calling sendMessage until the SASL negotiation has occurred. * The ConnectionManager tracks all the sendingConnections using the ConnectionId - * and waits for the response from the server and does the handshake. + * and waits for the response from the server and does the handshake before sending + * the real message. * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 18172d359cb35..f368209980f93 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,23 +20,27 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.LinkedList import org.apache.spark._ -import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap} private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) + val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, + val securityMgr: SecurityManager) extends Logging { var sparkSaslServer: SparkSaslServer = null var sparkSaslClient: SparkSaslClient = null - def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { + def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId, + securityMgr_ : SecurityManager) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), 
id_) + channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), + id_, securityMgr_) } channel.configureBlocking(false) @@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, val remoteAddress = getRemoteAddress() - /** - * Used to synchronize client requests: client's work-related requests must - * wait until SASL authentication completes. - */ - private val authenticated = new Object() - - def getAuthenticated(): Object = authenticated - def isSaslComplete(): Boolean def resetForceReregister(): Boolean @@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, private[nio] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId, id_ : ConnectionId) - extends Connection(SocketChannel.open, selector_, remoteId_, id_) { + remoteId_ : ConnectionManagerId, id_ : ConnectionId, + securityMgr_ : SecurityManager) + extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) { def isSaslComplete(): Boolean = { if (sparkSaslClient != null) sparkSaslClient.isComplete() else false } private class Outbox { - val messages = new Queue[Message]() + val messages = new LinkedList[Message]() val defaultChunkSize = 65536 var nextMessageToBeUsed = 0 def addMessage(message: Message) { messages.synchronized { - /* messages += message */ - messages.enqueue(message) + messages.add(message) logDebug("Added [" + message + "] to outbox for sending to " + "[" + getRemoteConnectionManagerId() + "]") } @@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, while (!messages.isEmpty) { /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ /* val message = messages(nextMessageToBeUsed) */ - val message = messages.dequeue() + + val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) { + // only allow sending of security messages until sasl is complete 
+ var pos = 0 + var securityMsg: Message = null + while (pos < messages.size() && securityMsg == null) { + if (messages.get(pos).isSecurityNeg) { + securityMsg = messages.remove(pos) + } + pos = pos + 1 + } + // didn't find any security messages and auth isn't completed so return + if (securityMsg == null) return None + securityMsg + } else { + messages.removeFirst() + } + val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { - messages.enqueue(message) + messages.add(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { logDebug( @@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, changeConnectionKeyInterest(DEFAULT_INTEREST) } + def registerAfterAuth(): Unit = { + outbox.synchronized { + needForceReregister = true + } + if (channel.isConnected) { + registerInterest() + } + } + def send(message: Message) { outbox.synchronized { outbox.addMessage(message) @@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, private[spark] class ReceivingConnection( channel_ : SocketChannel, selector_ : Selector, - id_ : ConnectionId) - extends Connection(channel_, selector_, id_) { + id_ : ConnectionId, + securityMgr_ : SecurityManager) + extends Connection(channel_, selector_, id_, securityMgr_) { def isSaslComplete(): Boolean = { if (sparkSaslServer != null) sparkSaslServer.isComplete() else false diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 5aa7e94943561..01cd27a907eea 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps import org.apache.spark._ -import org.apache.spark.util.{SystemClock, 
Utils} +import org.apache.spark.util.Utils private[nio] class ConnectionManager( @@ -65,8 +65,6 @@ private[nio] class ConnectionManager( private val selector = SelectorProvider.provider.openSelector() private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) - // default to 30 second timeout waiting for authentication - private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) private val handleMessageExecutor = new ThreadPoolExecutor( @@ -409,7 +407,8 @@ private[nio] class ConnectionManager( while (newChannel != null) { try { val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId) + val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId, + securityManager) newConnection.onReceive(receiveMessage) addListeners(newConnection) addConnection(newConnection) @@ -527,9 +526,8 @@ private[nio] class ConnectionManager( if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.getAuthenticated().synchronized { - waitingConn.getAuthenticated().notifyAll() - } + waitingConn.registerAfterAuth() + wakeupSelector() return } else { var replyToken : Array[Byte] = null @@ -538,9 +536,8 @@ private[nio] class ConnectionManager( if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.getAuthenticated().synchronized { - waitingConn.getAuthenticated().notifyAll() - } + waitingConn.registerAfterAuth() + wakeupSelector() return } val securityMsgResp = SecurityMessage.fromResponse(replyToken, @@ -574,9 +571,11 @@ private[nio] class ConnectionManager( } replyToken = 
connection.sparkSaslServer.response(securityMsg.getToken) if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId) + logDebug("Server sasl completed: " + connection.connectionId + + " for: " + connectionId) } else { - logDebug("Server sasl not completed: " + connection.connectionId) + logDebug("Server sasl not completed: " + connection.connectionId + + " for: " + connectionId) } if (replyToken != null) { val securityMsgResp = SecurityMessage.fromResponse(replyToken, @@ -723,7 +722,8 @@ private[nio] class ConnectionManager( if (message == null) throw new Exception("Error creating security message") connectionsAwaitingSasl += ((conn.connectionId, conn)) sendSecurityMessage(connManagerId, message) - logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId) + logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId + + " to: " + connManagerId) } catch { case e: Exception => { logError("Error getting first response from the SaslClient.", e) @@ -744,7 +744,7 @@ private[nio] class ConnectionManager( val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, - newConnectionId) + newConnectionId, securityManager) logInfo("creating new sending connection for security! 
" + newConnectionId ) registerRequests.enqueue(newConnection) @@ -769,61 +769,23 @@ private[nio] class ConnectionManager( connectionManagerId.port) val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, - newConnectionId) + newConnectionId, securityManager) logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) newConnection } val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - if (authEnabled) { - checkSendAuthFirst(connectionManagerId, connection) - } + message.senderAddress = id.toSocketAddress() logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + "connectionid: " + connection.connectionId) if (authEnabled) { - // if we aren't authenticated yet lets block the senders until authentication completes - try { - connection.getAuthenticated().synchronized { - val clock = SystemClock - val startTime = clock.getTime() - - while (!connection.isSaslComplete()) { - logDebug("getAuthenticated wait connectionid: " + connection.connectionId) - // have timeout in case remote side never responds - connection.getAuthenticated().wait(500) - if (((clock.getTime() - startTime) >= (authTimeout * 1000)) - && (!connection.isSaslComplete())) { - // took to long to authenticate the connection, something probably went wrong - throw new Exception("Took to long for authentication to " + connectionManagerId + - ", waited " + authTimeout + "seconds, failing.") - } - } - } - } catch { - case e: Exception => logError("Exception while waiting for authentication.", e) - - // need to tell sender it failed - messageStatuses.synchronized { - val s = messageStatuses.get(message.id) - s match { - case Some(msgStatus) => { - messageStatuses -= message.id - logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.markDone(None) - } - case None => { - 
logError("no messageStatus for failed message id: " + message.id) - } - } - } - } + checkSendAuthFirst(connectionManagerId, connection) } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) - wakeupSelector() } From 8081ce8bd111923db143abc55bb6ef9793eece35 Mon Sep 17 00:00:00 2001 From: scwf Date: Thu, 2 Oct 2014 17:47:56 -0700 Subject: [PATCH 19/36] [SPARK-3755][Core] avoid trying privileged port when request a non-privileged port pwendell, ```tryPort``` is not compatible with old code in last PR, this is to fix it. And after discuss with srowen renamed the title to "avoid trying privileged port when request a non-privileged port". Plz refer to the discuss for detail. Author: scwf Closes #2623 from scwf/1-1024 and squashes the following commits: 10a4437 [scwf] add comment de3fd17 [scwf] do not try privileged port when request a non-privileged port 42cb0fa [scwf] make tryPort compatible with old code cb8cc76 [scwf] do not use port 1 - 1024 --- core/src/main/scala/org/apache/spark/util/Utils.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b3025c6ec3364..9399ddab76331 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1439,7 +1439,12 @@ private[spark] object Utils extends Logging { val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" for (offset <- 0 to maxRetries) { // Do not increment port if startPort is 0, which is treated as a special port - val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536 + val tryPort = if (startPort == 0) { + startPort + } else { + // If the new port wraps around, do not try a privilege port + ((startPort + offset - 1024) % (65536 - 1024)) + 1024 + } try { val (service, port) = startService(tryPort) logInfo(s"Successfully started 
service$serviceString on port $port.") From 42d5077fd3f2c37d1cd23f4c81aa89286a74cb40 Mon Sep 17 00:00:00 2001 From: Eric Eijkelenboom Date: Thu, 2 Oct 2014 18:04:38 -0700 Subject: [PATCH 20/36] [DEPLOY] SPARK-3759: Return the exit code of the driver process SparkSubmitDriverBootstrapper.scala now returns the exit code of the driver process, instead of always returning 0. Author: Eric Eijkelenboom Closes #2628 from ericeijkelenboom/master and squashes the following commits: cc4a571 [Eric Eijkelenboom] Return the exit code of the driver process --- .../apache/spark/deploy/SparkSubmitDriverBootstrapper.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index 38b5d8e1739d0..a64170a47bc1c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -154,7 +154,8 @@ private[spark] object SparkSubmitDriverBootstrapper { process.destroy() } } - process.waitFor() + val returnCode = process.waitFor() + sys.exit(returnCode) } } From 7de4e50a01e90bcf88e0b721b2b15a5162373d56 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 2 Oct 2014 19:32:21 -0700 Subject: [PATCH 21/36] [SQL] Initilize session state before creating CommandProcessor With the old ordering it was possible for commands in the HiveDriver to NPE due to the lack of configuration in the threadlocal session state. 
Author: Michael Armbrust Closes #2635 from marmbrus/initOrder and squashes the following commits: 9749850 [Michael Armbrust] Initilize session state before creating CommandProcessor --- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fdb56901f9ddb..8bcc098bbb620 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -281,13 +281,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = { try { + // Session state must be initilized before the CommandProcessor is created . + SessionState.start(sessionState) + val cmd_trimmed: String = cmd.trim() val tokens: Array[String] = cmd_trimmed.split("\\s+") val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hiveconf) - SessionState.start(sessionState) - proc match { case driver: Driver => driver.init() From 1c90347a4bba12df7b76d282a7dbac8e555e049f Mon Sep 17 00:00:00 2001 From: ravipesala Date: Thu, 2 Oct 2014 20:04:33 -0700 Subject: [PATCH 22/36] [SPARK-3654][SQL] Implement all extended HiveQL statements/commands with a separate parser combinator Created separate parser for hql. It preparses the commands like cache,uncache,add jar etc.. and then parses with HiveQl Author: ravipesala Closes #2590 from ravipesala/SPARK-3654 and squashes the following commits: bbca7dd [ravipesala] Fixed code as per admin comments. 
ae9290a [ravipesala] Fixed style issues as per Admin comments 898ed81 [ravipesala] Removed spaces fb24edf [ravipesala] Updated the code as per admin comments 8947d37 [ravipesala] Removed duplicate code ba26cd1 [ravipesala] Created seperate parser for hql.It pre parses the commands like cache,uncache,add jar etc.. and then parses with HiveQl --- .../spark/sql/hive/ExtendedHiveQlParser.scala | 135 ++++++++++++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 57 ++------ .../spark/sql/hive/CachedTableSuite.scala | 6 + 3 files changed, 154 insertions(+), 44 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala new file mode 100644 index 0000000000000..e7e1cb980c2ae --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import scala.language.implicitConversions +import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.combinator.PackratParsers +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SqlLexical + +/** + * A parser that recognizes all HiveQL constructs together with several Spark SQL specific + * extensions like CACHE TABLE and UNCACHE TABLE. + */ +private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers { + + def apply(input: String): LogicalPlan = { + // Special-case out set commands since the value fields can be + // complex to handle without RegexParsers. Also this approach + // is clearer for the several possible cases of set commands. + if (input.trim.toLowerCase.startsWith("set")) { + input.trim.drop(3).split("=", 2).map(_.trim) match { + case Array("") => // "set" + SetCommand(None, None) + case Array(key) => // "set key" + SetCommand(Some(key), None) + case Array(key, value) => // "set key=value" + SetCommand(Some(key), Some(value)) + } + } else if (input.trim.startsWith("!")) { + ShellCommand(input.drop(1)) + } else { + phrase(query)(new lexical.Scanner(input)) match { + case Success(r, x) => r + case x => sys.error(x.toString) + } + } + } + + protected case class Keyword(str: String) + + protected val CACHE = Keyword("CACHE") + protected val SET = Keyword("SET") + protected val ADD = Keyword("ADD") + protected val JAR = Keyword("JAR") + protected val TABLE = Keyword("TABLE") + protected val AS = Keyword("AS") + protected val UNCACHE = Keyword("UNCACHE") + protected val FILE = Keyword("FILE") + protected val DFS = Keyword("DFS") + protected val SOURCE = Keyword("SOURCE") + + protected implicit def asParser(k: Keyword): Parser[String] = + lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + + protected def allCaseConverse(k: String): Parser[String] = + lexical.allCaseVersions(k).map(x 
=> x : Parser[String]).reduce(_ | _) + + protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword].str) + + override val lexical = new SqlLexical(reservedWords) + + protected lazy val query: Parser[LogicalPlan] = + cache | uncache | addJar | addFile | dfs | source | hiveQl + + protected lazy val hiveQl: Parser[LogicalPlan] = + remainingQuery ^^ { + case r => HiveQl.createPlan(r.trim()) + } + + /** It returns all remaining query */ + protected lazy val remainingQuery: Parser[String] = new Parser[String] { + def apply(in: Input) = + Success( + in.source.subSequence(in.offset, in.source.length).toString, + in.drop(in.source.length())) + } + + /** It returns all query */ + protected lazy val allQuery: Parser[String] = new Parser[String] { + def apply(in: Input) = + Success(in.source.toString, in.drop(in.source.length())) + } + + protected lazy val cache: Parser[LogicalPlan] = + CACHE ~ TABLE ~> ident ~ opt(AS ~> hiveQl) ^^ { + case tableName ~ None => CacheCommand(tableName, true) + case tableName ~ Some(plan) => + CacheTableAsSelectCommand(tableName, plan) + } + + protected lazy val uncache: Parser[LogicalPlan] = + UNCACHE ~ TABLE ~> ident ^^ { + case tableName => CacheCommand(tableName, false) + } + + protected lazy val addJar: Parser[LogicalPlan] = + ADD ~ JAR ~> remainingQuery ^^ { + case rq => AddJar(rq.trim()) + } + + protected lazy val addFile: Parser[LogicalPlan] = + ADD ~ FILE ~> remainingQuery ^^ { + case rq => AddFile(rq.trim()) + } + + protected lazy val dfs: Parser[LogicalPlan] = + DFS ~> allQuery ^^ { + case aq => NativeCommand(aq.trim()) + } + + protected lazy val source: Parser[LogicalPlan] = + SOURCE ~> remainingQuery ^^ { + case rq => SourceCommand(rq.trim()) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 4f3f808c93dc8..6bb42eeb0550d 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -126,6 +126,9 @@ private[hive] object HiveQl { "TOK_CREATETABLE", "TOK_DESCTABLE" ) ++ nativeCommands + + // It parses hive sql query along with with several Spark SQL specific extensions + protected val hiveSqlParser = new ExtendedHiveQlParser /** * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations @@ -215,40 +218,19 @@ private[hive] object HiveQl { def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = { + def parseSql(sql: String): LogicalPlan = hiveSqlParser(sql) + + /** Creates LogicalPlan for a given HiveQL string. */ + def createPlan(sql: String) = { try { - if (sql.trim.toLowerCase.startsWith("set")) { - // Split in two parts since we treat the part before the first "=" - // as key, and the part after as value, which may contain other "=" signs. 
- sql.trim.drop(3).split("=", 2).map(_.trim) match { - case Array("") => // "set" - SetCommand(None, None) - case Array(key) => // "set key" - SetCommand(Some(key), None) - case Array(key, value) => // "set key=value" - SetCommand(Some(key), Some(value)) - } - } else if (sql.trim.toLowerCase.startsWith("cache table")) { - sql.trim.drop(12).trim.split(" ").toSeq match { - case Seq(tableName) => - CacheCommand(tableName, true) - case Seq(tableName, _, select @ _*) => - CacheTableAsSelectCommand(tableName, createPlan(select.mkString(" ").trim)) - } - } else if (sql.trim.toLowerCase.startsWith("uncache table")) { - CacheCommand(sql.trim.drop(14).trim, false) - } else if (sql.trim.toLowerCase.startsWith("add jar")) { - AddJar(sql.trim.drop(8).trim) - } else if (sql.trim.toLowerCase.startsWith("add file")) { - AddFile(sql.trim.drop(9)) - } else if (sql.trim.toLowerCase.startsWith("dfs")) { + val tree = getAst(sql) + if (nativeCommands contains tree.getText) { NativeCommand(sql) - } else if (sql.trim.startsWith("source")) { - SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath }) - } else if (sql.trim.startsWith("!")) { - ShellCommand(sql.drop(1)) } else { - createPlan(sql) + nodeToPlan(tree) match { + case NativePlaceholder => NativeCommand(sql) + case other => other + } } } catch { case e: Exception => throw new ParseException(sql, e) @@ -259,19 +241,6 @@ private[hive] object HiveQl { """.stripMargin) } } - - /** Creates LogicalPlan for a given HiveQL string. 
*/ - def createPlan(sql: String) = { - val tree = getAst(sql) - if (nativeCommands contains tree.getText) { - NativeCommand(sql) - } else { - nodeToPlan(tree) match { - case NativePlaceholder => NativeCommand(sql) - case other => other - } - } - } def parseDdl(ddl: String): Seq[Attribute] = { val tree = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 188579edd7bdd..b3057cd618c66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -88,4 +88,10 @@ class CachedTableSuite extends HiveComparisonTest { } assert(!TestHive.isCached("src"), "Table 'src' should not be cached") } + + test("'CACHE TABLE tableName AS SELECT ..'") { + TestHive.sql("CACHE TABLE testCacheTable AS SELECT * FROM src") + assert(TestHive.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + TestHive.uncacheTable("testCacheTable") + } } From 2e4eae3a52e3d04895b00447d1ac56ae3c1b98ae Mon Sep 17 00:00:00 2001 From: "qiping.lqp" Date: Fri, 3 Oct 2014 03:26:17 -0700 Subject: [PATCH 23/36] [SPARK-3366][MLLIB]Compute best splits distributively in decision tree Currently, all best splits are computed on the driver, which makes the driver a bottleneck for both communication and computation. This PR fix this problem by computed best splits on executors. Instead of send all aggregate stats to the driver node, we can send aggregate stats for a node to a particular executor, using `reduceByKey` operation, then we can compute best split for this node there. Implementation details: Each node now has a nodeStatsAggregator, which save aggregate stats for all features and bins. First use mapPartition to compute node aggregate stats for all nodes in each partition. 
Then transform node aggregate stats to (nodeIndex, nodeStatsAggregator) pairs and use to `reduceByKey` operation to combine nodeStatsAggregator for the same node. After all stats have been combined, best splits can be computed for each node based on the node aggregate stats. Best split result is collected to driver to construct the decision tree. CC: mengxr manishamde jkbradley, please help me review this, thanks. Author: qiping.lqp Author: chouqin Closes #2595 from chouqin/dt-dist-agg and squashes the following commits: db0d24a [chouqin] fix a minor bug and adjust code a0d9de3 [chouqin] adjust code based on comments 9f201a6 [chouqin] fix bug: statsSize -> allStatsSize a8a7ed0 [chouqin] Merge branch 'master' of https://github.com/apache/spark into dt-dist-agg f13b346 [chouqin] adjust randomforest comments c32636e [chouqin] adjust code based on comments ac6a505 [chouqin] adjust code based on comments 7bbb787 [chouqin] add comments bdd2a63 [qiping.lqp] fix test suite a75df27 [qiping.lqp] fix test suite b5b0bc2 [qiping.lqp] fix style e76414f [qiping.lqp] fix testsuite 748bd45 [qiping.lqp] fix type-mismatch bug 24eacd8 [qiping.lqp] fix type-mismatch bug 5f63d6c [qiping.lqp] add multiclassification using One-Vs-All strategy 4f56496 [qiping.lqp] fix bug f00fc22 [qiping.lqp] fix bug 532993a [qiping.lqp] Compute best splits distributively in decision tree --- .../spark/mllib/tree/DecisionTree.scala | 140 ++++++--- .../spark/mllib/tree/RandomForest.scala | 5 +- .../mllib/tree/impl/DTStatsAggregator.scala | 292 +++++------------- .../tree/model/InformationGainStats.scala | 11 + .../spark/mllib/tree/RandomForestSuite.scala | 1 + 5 files changed, 182 insertions(+), 267 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b7dc373ebd9cc..b311d10023894 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -23,7 +23,6 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy @@ -36,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.SparkContext._ /** @@ -328,9 +328,8 @@ object DecisionTree extends Serializable with Logging { * for each subset is updated. * * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). + * each (feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param unorderedFeatures Set of indices of unordered features. * @param instanceWeight Weight (importance) of instance in dataset. @@ -338,7 +337,6 @@ object DecisionTree extends Serializable with Logging { private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - nodeIndex: Int, bins: Array[Array[Bin]], unorderedFeatures: Set[Int], instanceWeight: Double, @@ -350,7 +348,6 @@ object DecisionTree extends Serializable with Logging { // Use all features agg.metadata.numFeatures } - val nodeOffset = agg.getNodeOffset(nodeIndex) // Iterate over features. 
var featureIndexIdx = 0 while (featureIndexIdx < numFeaturesPerNode) { @@ -363,16 +360,16 @@ object DecisionTree extends Serializable with Logging { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) + agg.getLeftRightFeatureOffsets(featureIndexIdx) // Update the left or right bin for each split. val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { - agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } else { - agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, + agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } splitIndex += 1 @@ -380,8 +377,7 @@ object DecisionTree extends Serializable with Logging { } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label, - instanceWeight) + agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) } featureIndexIdx += 1 } @@ -393,26 +389,24 @@ object DecisionTree extends Serializable with Logging { * For each feature, the sufficient statistics of one bin are updated. * * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (node, feature, bin). + * each (feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). * @param instanceWeight Weight (importance) of instance in dataset. 
*/ private def orderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - nodeIndex: Int, instanceWeight: Double, featuresForNode: Option[Array[Int]]): Unit = { val label = treePoint.label - val nodeOffset = agg.getNodeOffset(nodeIndex) + // Iterate over features. if (featuresForNode.nonEmpty) { // Use subsampled features var featureIndexIdx = 0 while (featureIndexIdx < featuresForNode.get.size) { val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) - agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight) + agg.update(featureIndexIdx, binIndex, label, instanceWeight) featureIndexIdx += 1 } } else { @@ -421,7 +415,7 @@ object DecisionTree extends Serializable with Logging { var featureIndex = 0 while (featureIndex < numFeatures) { val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight) + agg.update(featureIndex, binIndex, label, instanceWeight) featureIndex += 1 } } @@ -496,8 +490,8 @@ object DecisionTree extends Serializable with Logging { * @return agg */ def binSeqOp( - agg: DTStatsAggregator, - baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = { + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, bins, metadata.unorderedFeatures) @@ -508,9 +502,9 @@ object DecisionTree extends Serializable with Logging { val featuresForNode = nodeInfo.featureSubset val instanceWeight = baggedPoint.subsampleWeights(treeIndex) if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode) + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) } else { - mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, 
metadata.unorderedFeatures, + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, instanceWeight, featuresForNode) } } @@ -518,30 +512,76 @@ object DecisionTree extends Serializable with Logging { agg } - // Calculate bin aggregates. - timer.start("aggregation") - val binAggregates: DTStatsAggregator = { - val initAgg = if (metadata.subsamplingFeatures) { - new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo) - } else { - new DTStatsAggregatorFixedFeatures(metadata, numNodes) + /** + * Get node index in group --> features indices map, + * which is a short cut to find feature indices for a node given node index in group + * @param treeToNodeToIndexInfo + * @return + */ + def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) + : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) { + None + } else { + val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]() + treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => + nodeIdToNodeInfo.values.foreach { nodeIndexInfo => + assert(nodeIndexInfo.featureSubset.isDefined) + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + } } - input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp) + Some(mutableNodeToFeatures.toMap) } - timer.stop("aggregation") // Calculate best splits for all nodes in the group timer.start("chooseSplits") + // In each partition, iterate all instances and compute aggregate stats for each node, + // yield an (nodeIndex, nodeAggregateStats) pair for each node. + // After a `reduceByKey` operation, + // stats of a node will be shuffled to a particular partition and be combined together, + // then best splits for nodes are found there. + // Finally, only best Splits for nodes are collected to driver to construct decision tree. 
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) + val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) + val nodeToBestSplits = + input.mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOp(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + }.reduceByKey((a, b) => a.merge(b)) + .map { case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + + // find best split for each node + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(aggStats, splits, featuresForNode) + (nodeIndex, (split, stats, predict)) + }.collectAsMap() + + timer.stop("chooseSplits") + // Iterate over all nodes in this group. nodesForGroup.foreach { case (treeIndex, nodesForTree) => nodesForTree.foreach { node => val nodeIndex = node.id val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) val aggNodeIndex = nodeInfo.nodeIndexInGroup - val featuresForNode = nodeInfo.featureSubset val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode) + nodeToBestSplits(aggNodeIndex) logDebug("best split = " + split) // Extract info for this node. Create children if not leaf. 
@@ -565,7 +605,7 @@ object DecisionTree extends Serializable with Logging { } } } - timer.stop("chooseSplits") + } /** @@ -633,36 +673,33 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. * @param binAggregates Bin statistics. - * @param nodeIndex Index into aggregates for node to split in this group. * @return tuple for best split: (Split, information gain, prediction at node) */ private def binsToBestSplit( binAggregates: DTStatsAggregator, - nodeIndex: Int, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { - val metadata: DecisionTreeMetadata = binAggregates.metadata - // calculate predict only once var predict: Option[Predict] = None // For each (feature, split), calculate the gain, and select the best (feature, split). - val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx => + val (bestSplit, bestSplitStats) = + Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => val featureIndex = if (featuresForNode.nonEmpty) { featuresForNode.get.apply(featureIndexIdx) } else { featureIndexIdx } - val numSplits = metadata.numSplits(featureIndex) - if (metadata.isContinuous(featureIndex)) { + val numSplits = binAggregates.metadata.numSplits(featureIndex) + if (binAggregates.metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. // Afterwards, binAggregates for a bin is the sum of aggregates for // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) var splitIndex = 0 while (splitIndex < numSplits) { - binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) splitIndex += 1 } // Find best split. 
@@ -672,27 +709,29 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (metadata.isUnordered(featureIndex)) { + } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) + binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) - val numBins = metadata.numBins(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numBins = binAggregates.metadata.numBins(featureIndex) /* Each bin is one category (feature value). 
* The bins are ordered based on centroidForCategories, and this ordering determines which @@ -700,7 +739,7 @@ object DecisionTree extends Serializable with Logging { * * centroidForCategories is a list: (category, centroid) */ - val centroidForCategories = if (metadata.isMulticlass) { + val centroidForCategories = if (binAggregates.metadata.isMulticlass) { // For categorical variables in multiclass classification, // the bins are ordered by the impurity of their corresponding labels. Range(0, numBins).map { case featureValue => @@ -741,7 +780,7 @@ object DecisionTree extends Serializable with Logging { while (splitIndex < numSplits) { val currentCategory = categoriesSortedByCentroid(splitIndex)._1 val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory) + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) splitIndex += 1 } // lastCategory = index of bin with total aggregates for this (node, feature) @@ -756,7 +795,8 @@ object DecisionTree extends Serializable with Logging { binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) + val gainStats = calculateGainForSplit(leftChildStats, + rightChildStats, binAggregates.metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 7fa7725e79e46..fa7a26f17c3ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -171,8 +171,8 @@ private class RandomForest ( // Choose node splits, and enqueue new nodes as 
needed. timer.start("findBestSplits") - DecisionTree.findBestSplits(baggedInput, - metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) + DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, + treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) timer.stop("findBestSplits") } @@ -382,6 +382,7 @@ object RandomForest extends Serializable with Logging { * @param maxMemoryUsage Bound on size of aggregate statistics. * @return (nodesForGroup, treeToNodeToIndexInfo). * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * * treeToNodeToIndexInfo holds indices selected features for each node: * treeIndex --> (global) node index --> (node index in group, feature indices). * The (global) node index is the index in the tree; the node index in group is the diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index d49df7a016375..55f422dff0d71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -17,17 +17,19 @@ package org.apache.spark.mllib.tree.impl -import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.impurity._ + + /** - * DecisionTree statistics aggregator. - * This holds a flat array of statistics for a set of (nodes, features, bins) + * DecisionTree statistics aggregator for a node. + * This holds a flat array of statistics for a set of (features, bins) * and helps with indexing. * This class is abstract to support learning with and without feature subsampling. 
*/ -private[tree] abstract class DTStatsAggregator( - val metadata: DecisionTreeMetadata) extends Serializable { +private[tree] class DTStatsAggregator( + val metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]) extends Serializable { /** * [[ImpurityAggregator]] instance specifying the impurity type. @@ -42,7 +44,25 @@ private[tree] abstract class DTStatsAggregator( /** * Number of elements (Double values) used for the sufficient statistics of each bin. */ - val statsSize: Int = impurityAggregator.statsSize + private val statsSize: Int = impurityAggregator.statsSize + + /** + * Number of bins for each feature. This is indexed by the feature index. + */ + private val numBins: Array[Int] = { + if (featureSubset.isDefined) { + featureSubset.get.map(metadata.numBins(_)) + } else { + metadata.numBins + } + } + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + */ + private val featureOffsets: Array[Int] = { + numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) + } /** * Indicator for each feature of whether that feature is an unordered feature. @@ -51,107 +71,95 @@ private[tree] abstract class DTStatsAggregator( def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) /** - * Total number of elements stored in this aggregator. + * Total number of elements stored in this aggregator */ - def allStatsSize: Int + private val allStatsSize: Int = featureOffsets.last /** - * Get flat array of elements stored in this aggregator. + * Flat array of elements. 
+ * Index for start of stats for a (feature, bin) is: + * index = featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, + * the left child stats have binIndex in [0, numBins(featureIndex) / 2)) + * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex)) */ - protected def allStats: Array[Double] + private val allStats: Array[Double] = new Array[Double](allStatsSize) + /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). - * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset - * from [[getNodeFeatureOffset]]. + * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. * For unordered features, this is a pre-computed * (node, feature, left/right child) offset from - * [[getLeftRightNodeFeatureOffsets]]. + * [[getLeftRightFeatureOffsets]]. */ - def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = { - impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize) + def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { + impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } /** - * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + * Update the stats for a given (feature, bin) for ordered features, using the given label. */ - def update( - nodeIndex: Int, - featureIndex: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize + def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { + val i = featureOffsets(featureIndex) + binIndex * statsSize impurityAggregator.update(allStats, i, label, instanceWeight) } - /** - * Pre-compute node offset for use with [[nodeUpdate]]. 
- */ - def getNodeOffset(nodeIndex: Int): Int - /** * Faster version of [[update]]. - * Update the stats for a given (node, feature, bin) for ordered features, using the given label. - * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. + * Update the stats for a given (feature, bin), using the given label. + * @param featureOffset For ordered features, this is a pre-computed feature offset + * from [[getFeatureOffset]]. + * For unordered features, this is a pre-computed + * (feature, left/right child) offset from + * [[getLeftRightFeatureOffsets]]. */ - def nodeUpdate( - nodeOffset: Int, - nodeIndex: Int, - featureIndex: Int, + def featureUpdate( + featureOffset: Int, binIndex: Int, label: Double, - instanceWeight: Double): Unit + instanceWeight: Double): Unit = { + impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, + label, instanceWeight) + } /** - * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * Pre-compute feature offset for use with [[featureUpdate]]. * For ordered features only. */ - def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int + def getFeatureOffset(featureIndex: Int): Int = { + require(!isUnordered(featureIndex), + s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" + + s" for unordered feature $featureIndex.") + featureOffsets(featureIndex) + } /** - * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * Pre-compute feature offset for use with [[featureUpdate]]. * For unordered features only. 
*/ - def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { + def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { require(isUnordered(featureIndex), - s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + + s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," + s" but was called for ordered feature $featureIndex.") - val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex) - (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize) - } - - /** - * Faster version of [[update]]. - * Update the stats for a given (node, feature, bin), using the given label. - * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset - * from [[getNodeFeatureOffset]]. - * For unordered features, this is a pre-computed - * (node, feature, left/right child) offset from - * [[getLeftRightNodeFeatureOffsets]]. - */ - def nodeFeatureUpdate( - nodeFeatureOffset: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label, - instanceWeight) + val baseOffset = featureOffsets(featureIndex) + (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } /** - * For a given (node, feature), merge the stats for two bins. - * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset - * from [[getNodeFeatureOffset]]. + * For a given feature, merge the stats for two bins. + * @param featureOffset For ordered features, this is a pre-computed feature offset + * from [[getFeatureOffset]]. * For unordered features, this is a pre-computed - * (node, feature, left/right child) offset from - * [[getLeftRightNodeFeatureOffsets]]. + * (feature, left/right child) offset from + * [[getLeftRightFeatureOffsets]]. * @param binIndex The other bin is merged into this bin. 
* @param otherBinIndex This bin is not modified. */ - def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { - impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize, - nodeFeatureOffset + otherBinIndex * statsSize) + def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { + impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize, + featureOffset + otherBinIndex * statsSize) } /** @@ -161,7 +169,7 @@ private[tree] abstract class DTStatsAggregator( def merge(other: DTStatsAggregator): DTStatsAggregator = { require(allStatsSize == other.allStatsSize, s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." - + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") + + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") var i = 0 // TODO: Test BLAS.axpy while (i < allStatsSize) { @@ -171,149 +179,3 @@ private[tree] abstract class DTStatsAggregator( this } } - -/** - * DecisionTree statistics aggregator. - * This holds a flat array of statistics for a set of (nodes, features, bins) - * and helps with indexing. - * - * This instance of [[DTStatsAggregator]] is used when not subsampling features. - * - * @param numNodes Number of nodes to collect statistics for. - */ -private[tree] class DTStatsAggregatorFixedFeatures( - metadata: DecisionTreeMetadata, - numNodes: Int) extends DTStatsAggregator(metadata) { - - /** - * Offset for each feature for calculating indices into the [[allStats]] array. - * Mapping: featureIndex --> offset - */ - private val featureOffsets: Array[Int] = { - metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - - /** - * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. 
- */ - private val nodeStride: Int = featureOffsets.last - - override val allStatsSize: Int = numNodes * nodeStride - - /** - * Flat array of elements. - * Index for start of stats for a (node, feature, bin) is: - * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, the left child stats precede the right child stats - * in the binIndex order. - */ - override protected val allStats: Array[Double] = new Array[Double](allStatsSize) - - override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride - - override def nodeUpdate( - nodeOffset: Int, - nodeIndex: Int, - featureIndex: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) - } - - override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { - nodeIndex * nodeStride + featureOffsets(featureIndex) - } -} - -/** - * DecisionTree statistics aggregator. - * This holds a flat array of statistics for a set of (nodes, features, bins) - * and helps with indexing. - * - * This instance of [[DTStatsAggregator]] is used when subsampling features. - * - * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, - * where nodeIndexInfo stores the index in the group and the - * feature subsets (if using feature subsets). - */ -private[tree] class DTStatsAggregatorSubsampledFeatures( - metadata: DecisionTreeMetadata, - treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) { - - /** - * For each node, offset for each feature for calculating indices into the [[allStats]] array. 
- * Mapping: nodeIndex --> featureIndex --> offset - */ - private val featureOffsets: Array[Array[Int]] = { - val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum - val offsets = new Array[Array[Int]](numNodes) - treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) => - nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) => - offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_)) - .scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - } - offsets - } - - /** - * For each node, offset for each feature for calculating indices into the [[allStats]] array. - */ - protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _) - - override val allStatsSize: Int = nodeOffsets.last - - /** - * Flat array of elements. - * Index for start of stats for a (node, feature, bin) is: - * index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, the left child stats precede the right child stats - * in the binIndex order. - */ - override protected val allStats: Array[Double] = new Array[Double](allStatsSize) - - override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex) - - /** - * Faster version of [[update]]. - * Update the stats for a given (node, feature, bin) for ordered features, using the given label. - * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. - * @param featureIndex Index of feature in featuresForNodes(nodeIndex). - * Note: This is NOT the original feature index. - */ - override def nodeUpdate( - nodeOffset: Int, - nodeIndex: Int, - featureIndex: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) - } - - /** - * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. 
- * For ordered features only. - * @param featureIndex Index of feature in featuresForNodes(nodeIndex). - * Note: This is NOT the original feature index. - */ - override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { - nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex) - } -} - -private[tree] object DTStatsAggregator extends Serializable { - - /** - * Combines two aggregates (modifying the first) and returns the combination. - */ - def binCombOp( - agg1: DTStatsAggregator, - agg2: DTStatsAggregator): DTStatsAggregator = { - agg1.merge(agg2) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f3e2619bd8ba0..a89e71e115806 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -38,6 +38,17 @@ class InformationGainStats( "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" .format(gain, impurity, leftImpurity, rightImpurity) } + + override def equals(o: Any) = + o match { + case other: InformationGainStats => { + gain == other.gain && + impurity == other.impurity && + leftImpurity == other.leftImpurity && + rightImpurity == other.rightImpurity + } + case _ => false + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 30669fcd1c75b..20d372dc1d3ca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -145,6 +145,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { assert(nodesForGroup.size === numTrees, failString) assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree + if 
(numFeaturesPerNode == numFeatures) { // featureSubset values should all be None assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), From f0811f928e5b608e1a2cba3b6828ba0ed03b701d Mon Sep 17 00:00:00 2001 From: EugenCepoi Date: Fri, 3 Oct 2014 10:03:15 -0700 Subject: [PATCH 24/36] SPARK-2058: Overriding SPARK_HOME/conf with SPARK_CONF_DIR Update of PR #997. With this PR, setting SPARK_CONF_DIR overrides SPARK_HOME/conf (not only spark-defaults.conf and spark-env). Author: EugenCepoi Closes #2481 from EugenCepoi/SPARK-2058 and squashes the following commits: 0bb32c2 [EugenCepoi] use orElse orNull and fixing trailing percent in compute-classpath.cmd 77f35d7 [EugenCepoi] SPARK-2058: Overriding SPARK_HOME/conf with SPARK_CONF_DIR --- bin/compute-classpath.cmd | 8 +++- bin/compute-classpath.sh | 8 +++- .../spark/deploy/SparkSubmitArguments.scala | 42 ++++++++----------- .../spark/deploy/SparkSubmitSuite.scala | 34 ++++++++++++++- docs/configuration.md | 7 ++++ 5 files changed, 71 insertions(+), 28 deletions(-) diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 5ad52452a5c98..9b9e40321ea93 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -36,7 +36,13 @@ rem Load environment variables from conf\spark-env.cmd, if it exists if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" rem Build up classpath -set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%;%FWDIR%conf +set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% + +if "x%SPARK_CONF_DIR%"!="x" ( + set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% +) else ( + set CLASSPATH=%CLASSPATH%;%FWDIR%conf +) if exist "%FWDIR%RELEASE" ( for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 0f63e36d8aeca..905bbaf99b374 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -27,8 +27,14 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" . 
"$FWDIR"/bin/load-spark-env.sh +CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH" + # Build up classpath -CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH:$FWDIR/conf" +if [ -n "$SPARK_CONF_DIR" ]; then + CLASSPATH="$CLASSPATH:$SPARK_CONF_DIR" +else + CLASSPATH="$CLASSPATH:$FWDIR/conf" +fi ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION" diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 2b72c61cc8177..57b251ff47714 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -29,8 +29,9 @@ import org.apache.spark.util.Utils /** * Parses and encapsulates arguments from the spark-submit script. + * The env argument is used for testing. */ -private[spark] class SparkSubmitArguments(args: Seq[String]) { +private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, String] = sys.env) { var master: String = null var deployMode: String = null var executorMemory: String = null @@ -90,20 +91,12 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { private def mergeSparkProperties(): Unit = { // Use common defaults file, if not specified by user if (propertiesFile == null) { - sys.env.get("SPARK_CONF_DIR").foreach { sparkConfDir => - val sep = File.separator - val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" - val file = new File(defaultPath) - if (file.exists()) { - propertiesFile = file.getAbsolutePath - } - } - } + val sep = File.separator + val sparkHomeConfig = env.get("SPARK_HOME").map(sparkHome => s"${sparkHome}${sep}conf") + val confDir = env.get("SPARK_CONF_DIR").orElse(sparkHomeConfig) - if (propertiesFile == null) { - sys.env.get("SPARK_HOME").foreach { sparkHome => - val sep = File.separator - val defaultPath = s"${sparkHome}${sep}conf${sep}spark-defaults.conf" + confDir.foreach { sparkConfDir 
=> + val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" val file = new File(defaultPath) if (file.exists()) { propertiesFile = file.getAbsolutePath @@ -117,19 +110,18 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { // Use properties file as fallback for values which have a direct analog to // arguments in this script. - master = Option(master).getOrElse(properties.get("spark.master").orNull) - executorMemory = Option(executorMemory) - .getOrElse(properties.get("spark.executor.memory").orNull) - executorCores = Option(executorCores) - .getOrElse(properties.get("spark.executor.cores").orNull) + master = Option(master).orElse(properties.get("spark.master")).orNull + executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull + executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull totalExecutorCores = Option(totalExecutorCores) - .getOrElse(properties.get("spark.cores.max").orNull) - name = Option(name).getOrElse(properties.get("spark.app.name").orNull) - jars = Option(jars).getOrElse(properties.get("spark.jars").orNull) + .orElse(properties.get("spark.cores.max")) + .orNull + name = Option(name).orElse(properties.get("spark.app.name")).orNull + jars = Option(jars).orElse(properties.get("spark.jars")).orNull // This supports env vars in older versions of Spark - master = Option(master).getOrElse(System.getenv("MASTER")) - deployMode = Option(deployMode).getOrElse(System.getenv("DEPLOY_MODE")) + master = Option(master).orElse(env.get("MASTER")).orNull + deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && primaryResource != null) { @@ -182,7 +174,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { } if (master.startsWith("yarn")) { - val hasHadoopEnv = sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR") + val 
hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR") if (!hasHadoopEnv && !Utils.isTesting) { throw new Exception(s"When running with master '$master' " + "either HADOOP_CONF_DIR or YARN_CONF_DIR must be set in the environment.") diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 0c324d8bdf6a4..4cba90e8f2afe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import java.io.{File, OutputStream, PrintStream} +import java.io._ import scala.collection.mutable.ArrayBuffer @@ -26,6 +26,7 @@ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite import org.scalatest.Matchers +import com.google.common.io.Files class SparkSubmitSuite extends FunSuite with Matchers { def beforeAll() { @@ -306,6 +307,21 @@ class SparkSubmitSuite extends FunSuite with Matchers { runSparkSubmit(args) } + test("SPARK_CONF_DIR overrides spark-defaults.conf") { + forConfDir(Map("spark.executor.memory" -> "2.3g")) { path => + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + unusedJar.toString) + val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) + assert(appArgs.propertiesFile != null) + assert(appArgs.propertiesFile.startsWith(path)) + appArgs.executorMemory should be ("2.3g") + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
def runSparkSubmit(args: Seq[String]): String = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -314,6 +330,22 @@ class SparkSubmitSuite extends FunSuite with Matchers { new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) } + + def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { + val tmpDir = Files.createTempDir() + + val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") + val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf)) + for ((key, value) <- defaults) writer.write(s"$key $value\n") + + writer.close() + + try { + f(tmpDir.getAbsolutePath) + } finally { + Utils.deleteRecursively(tmpDir) + } + } } object JarCreationTest { diff --git a/docs/configuration.md b/docs/configuration.md index 316490f0f43fc..a782809a55ec0 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1108,3 +1108,10 @@ compute `SPARK_LOCAL_IP` by looking up the IP of a specific network interface. Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can configure it by adding a `log4j.properties` file in the `conf` directory. One way to start is to copy the existing `log4j.properties.template` located there. + +# Overriding configuration directory + +To specify a different configuration directory other than the default "SPARK_HOME/conf", +you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) +from this directory. + From 9d320e222c221e5bb827cddf01a83e64a16d74ff Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Fri, 3 Oct 2014 10:42:41 -0700 Subject: [PATCH 25/36] [SPARK-3696]Do not override the user-difined conf_dir https://issues.apache.org/jira/browse/SPARK-3696 We see if SPARK_CONF_DIR is already defined before assignment. 
Author: WangTaoTheTonic Closes #2541 from WangTaoTheTonic/confdir and squashes the following commits: c3f31e0 [WangTaoTheTonic] Do not override the user-difined conf_dir --- sbin/spark-config.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 2718d6cba1c9a..1d154e62ed5b6 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -33,7 +33,7 @@ this="$config_bin/$script" export SPARK_PREFIX="`dirname "$this"`"/.. export SPARK_HOME="${SPARK_PREFIX}" -export SPARK_CONF_DIR="$SPARK_HOME/conf" +export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}" # Add the PySpark classes to the PYTHONPATH: export PYTHONPATH="$SPARK_HOME/python:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" From 22f8e1ee7c4ea7b3bd4c6faaf0fe5b88a134ae12 Mon Sep 17 00:00:00 2001 From: ravipesala Date: Fri, 3 Oct 2014 11:25:18 -0700 Subject: [PATCH 26/36] [SPARK-2693][SQL] Supported for UDAF Hive Aggregates like PERCENTILE Implemented UDAF Hive aggregates by adding wrapper to Spark Hive. 
Author: ravipesala Closes #2620 from ravipesala/SPARK-2693 and squashes the following commits: a8df326 [ravipesala] Removed resolver from constructor arguments caf25c6 [ravipesala] Fixed style issues 5786200 [ravipesala] Supported for UDAF Hive Aggregates like PERCENTILE --- .../org/apache/spark/sql/hive/hiveUdfs.scala | 46 +++++++++++++++++-- .../sql/hive/execution/HiveUdfSuite.scala | 4 ++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 732e4976f6843..68f93f247d9bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.ql.exec.UDF +import org.apache.hadoop.hive.ql.exec.{UDF, UDAF} import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ @@ -57,7 +57,8 @@ private[hive] abstract class HiveFunctionRegistry } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdaf(functionClassName, children) - + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUdaf(functionClassName, children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdtf(functionClassName, Nil, children) } else { @@ -194,6 +195,37 @@ private[hive] case class HiveGenericUdaf( def newInstance() = new HiveUdafFunction(functionClassName, children, this) } +/** It is used as a wrapper for the hive functions which uses UDAF interface */ +private[hive] case class HiveUdaf( + functionClassName: 
String, + children: Seq[Expression]) extends AggregateExpression + with HiveInspectors + with HiveFunctionFactory { + + type UDFType = UDAF + + @transient + protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction()) + + @transient + protected lazy val objectInspector = { + resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + @transient + protected lazy val inspectors = children.map(_.dataType).map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" + + def newInstance() = + new HiveUdafFunction(functionClassName, children, this, true) +} + /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow @@ -275,14 +307,20 @@ private[hive] case class HiveGenericUdtf( private[hive] case class HiveUdafFunction( functionClassName: String, exprs: Seq[Expression], - base: AggregateExpression) + base: AggregateExpression, + isUDAFBridgeRequired: Boolean = false) extends AggregateFunction with HiveInspectors with HiveFunctionFactory { def this() = this(null, null, null) - private val resolver = createFunction[AbstractGenericUDAFResolver]() + private val resolver = + if (isUDAFBridgeRequired) { + new GenericUDAFBridge(createFunction[UDAF]()) + } else { + createFunction[AbstractGenericUDAFResolver]() + } private val inspectors = exprs.map(_.dataType).map(toInspector).toArray diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index cc125d539c3c2..e4324e9528f9b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -79,6 +79,10 @@ class HiveUdfSuite extends HiveComparisonTest { sql("SELECT testUdf(pair) FROM hiveUdfTestTable") sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") } + + test("SPARK-2693 udaf aggregates test") { + assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first) + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { From fbe8e9856b23262193105e7bf86075f516f0db25 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 3 Oct 2014 11:36:24 -0700 Subject: [PATCH 27/36] [SPARK-2778] [yarn] Add workaround for race in MiniYARNCluster. Sometimes the cluster's start() method returns before the configuration having been updated, which is done by ClientRMService in, I assume, a separate thread (otherwise there would be no race). That can cause tests to fail if the old configuration data is read, since it will contain the wrong RM address. Author: Marcelo Vanzin Closes #2605 from vanzin/SPARK-2778 and squashes the following commits: 8d02ce0 [Marcelo Vanzin] Minor cleanup. 5bebee7 [Marcelo Vanzin] [SPARK-2778] [yarn] Add workaround for race in MiniYARNCluster. 
--- .../spark/deploy/yarn/YarnClusterSuite.scala | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 4b6635679f053..a826b2a78a8f5 100644 --- a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.io.File +import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ @@ -32,7 +33,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { +class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { // log4j configuration for the Yarn containers, so that their output is collected // by Yarn instead of trying to overwrite unit-tests.log. @@ -66,7 +67,33 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) yarnCluster.init(new YarnConfiguration()) yarnCluster.start() - yarnCluster.getConfig().foreach { e => + + // There's a race in MiniYARNCluster in which start() may return before the RM has updated + // its address in the configuration. You can see this in the logs by noticing that when + // MiniYARNCluster prints the address, it still has port "0" assigned, although later the + // test works sometimes: + // + // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 + // + // That log message prints the contents of the RM_ADDRESS config variable. 
If you check it + // later on, it looks something like this: + // + // INFO YarnClusterSuite: RM address in configuration is blah:42631 + // + // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't + // done so in a timely manner (defined to be 10 seconds). + val config = yarnCluster.getConfig() + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) + while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { + if (System.currentTimeMillis() > deadline) { + throw new IllegalStateException("Timed out waiting for RM to come up.") + } + logDebug("RM address still not set in configuration, waiting...") + TimeUnit.MILLISECONDS.sleep(100) + } + + logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") + config.foreach { e => sys.props += ("spark.hadoop." + e.getKey() -> e.getValue()) } @@ -86,13 +113,13 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { super.afterAll() } - ignore("run Spark in yarn-client mode") { + test("run Spark in yarn-client mode") { var result = File.createTempFile("result", null, tempDir) YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath())) checkResult(result) } - ignore("run Spark in yarn-cluster mode") { + test("run Spark in yarn-cluster mode") { val main = YarnClusterDriver.getClass.getName().stripSuffix("$") var result = File.createTempFile("result", null, tempDir) From bec0d0eaa33811fde72b84f7d53a6f6031e7b5d3 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 3 Oct 2014 12:26:02 -0700 Subject: [PATCH 28/36] [SPARK-3007][SQL] Adds dynamic partitioning support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #2226 was reverted because it broke Jenkins builds for unknown reason. This debugging PR aims to fix the Jenkins build. This PR also fixes two bugs: 1. 
Compression configurations in `InsertIntoHiveTable` are disabled by mistake The `FileSinkDesc` object passed to the writer container doesn't have compression related configurations. These configurations are not taken care of until `saveAsHiveFile` is called. This PR moves compression code forward, right after instantiation of the `FileSinkDesc` object. 1. `PreInsertionCasts` doesn't take table partitions into account In `castChildOutput`, `table.attributes` only contains non-partition columns, thus for partitioned table `childOutputDataTypes` never equals to `tableOutputDataTypes`. This results funny analyzed plan like this: ``` == Analyzed Logical Plan == InsertIntoTable Map(partcol1 -> None, partcol2 -> None), false MetastoreRelation default, dynamic_part_table, None Project [c_0#1164,c_1#1165,c_2#1166] Project [c_0#1164,c_1#1165,c_2#1166] Project [c_0#1164,c_1#1165,c_2#1166] ... (repeats 99 times) ... Project [c_0#1164,c_1#1165,c_2#1166] Project [c_0#1164,c_1#1165,c_2#1166] Project [1 AS c_0#1164,1 AS c_1#1165,1 AS c_2#1166] Filter (key#1170 = 150) MetastoreRelation default, src, None ``` Awful though this logical plan looks, it's harmless because all projects will be eliminated by optimizer. Guess that's why this issue hasn't been caught before. 
Author: Cheng Lian Author: baishuo(白硕) Author: baishuo Closes #2616 from liancheng/dp-fix and squashes the following commits: 21935b6 [Cheng Lian] Adds back deleted trailing space f471c4b [Cheng Lian] PreInsertionCasts should take table partitions into account a132c80 [Cheng Lian] Fixes output compression 9c6eb2d [Cheng Lian] Adds tests to verify dynamic partitioning folder layout 0eed349 [Cheng Lian] Addresses @yhuai's comments 26632c3 [Cheng Lian] Adds more tests 9227181 [Cheng Lian] Minor refactoring c47470e [Cheng Lian] Refactors InsertIntoHiveTable to a Command 6fb16d7 [Cheng Lian] Fixes typo in test name, regenerated golden answer files d53daa5 [Cheng Lian] Refactors dynamic partitioning support b821611 [baishuo] pass check style 997c990 [baishuo] use HiveConf.DEFAULTPARTITIONNAME to replace hive.exec.default.partition.name 761ecf2 [baishuo] modify according micheal's advice 207c6ac [baishuo] modify for some bad indentation caea6fb [baishuo] modify code to pass scala style checks b660e74 [baishuo] delete a empty else branch cd822f0 [baishuo] do a little modify 8e7268c [baishuo] update file after test 3f91665 [baishuo(白硕)] Update Cast.scala 8ad173c [baishuo(白硕)] Update InsertIntoHiveTable.scala 051ba91 [baishuo(白硕)] Update Cast.scala d452eb3 [baishuo(白硕)] Update HiveQuerySuite.scala 37c603b [baishuo(白硕)] Update InsertIntoHiveTable.scala 98cfb1f [baishuo(白硕)] Update HiveCompatibilitySuite.scala 6af73f4 [baishuo(白硕)] Update InsertIntoHiveTable.scala adf02f1 [baishuo(白硕)] Update InsertIntoHiveTable.scala 1867e23 [baishuo(白硕)] Update SparkHadoopWriter.scala 6bb5880 [baishuo(白硕)] Update HiveQl.scala --- .../execution/HiveCompatibilitySuite.scala | 17 ++ .../org/apache/spark/SparkHadoopWriter.scala | 195 ---------------- .../spark/sql/hive/HiveMetastoreCatalog.scala | 3 +- .../org/apache/spark/sql/hive/HiveQl.scala | 5 - .../hive/execution/InsertIntoHiveTable.scala | 218 ++++++++++-------- .../spark/sql/hive/hiveWriterContainers.scala | 217 +++++++++++++++++ 
...rtition-0-be33aaa7253c8f248ff3921cd7dae340 | 0 ...rtition-1-640552dd462707563fd255a713f83b41 | 0 ...rtition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 | 1 + ...rtition-3-b7f7fa7ebf666f4fee27e149d8c6961f | 0 ...rtition-4-8bdb71ad8cb3cc3026043def2525de3a | 0 ...rtition-5-c630dce438f3792e7fb0f523fbbb3e1e | 0 ...rtition-6-7abc9ec8a36cdc5e89e955265a7fd7cf | 0 ...rtition-7-be33aaa7253c8f248ff3921cd7dae340 | 0 .../sql/hive/execution/HiveQuerySuite.scala | 100 +++++++- 15 files changed, 450 insertions(+), 306 deletions(-) delete mode 100644 sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 556c984ad392b..35e9c9939d4b7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -220,6 +220,23 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { */ override def whiteList = Seq( "add_part_exist", + "dynamic_partition_skip_default", + "infer_bucket_sort_dyn_part", + "load_dyn_part1", + "load_dyn_part2", + "load_dyn_part3", + "load_dyn_part4", + "load_dyn_part5", + "load_dyn_part6", + "load_dyn_part7", + "load_dyn_part8", + "load_dyn_part9", + "load_dyn_part10", + "load_dyn_part11", + "load_dyn_part12", + "load_dyn_part13", + "load_dyn_part14", + "load_dyn_part14_win", "add_part_multiple", "add_partition_no_whitelist", "add_partition_with_whitelist", diff --git a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala deleted file mode 100644 index ab7862f4f9e06..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import java.io.IOException -import java.text.NumberFormat -import java.util.Date - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} -import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} -import org.apache.hadoop.hive.ql.plan.FileSinkDesc -import org.apache.hadoop.mapred._ -import org.apache.hadoop.io.Writable - -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} - -/** - * Internal helper class that saves an RDD using a Hive OutputFormat. - * It is based on [[SparkHadoopWriter]]. - */ -private[hive] class SparkHiveHadoopWriter( - @transient jobConf: JobConf, - fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { - - private val now = new Date() - private val conf = new SerializableWritable(jobConf) - - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null - - @transient private var writer: FileSinkOperator.RecordWriter = null - @transient private var format: HiveOutputFormat[AnyRef, Writable] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null - - def preSetup() { - setIDs(0, 0, 0) - setConfParams() - - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) - } - - - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - setConfParams() - } - - def open() { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val extension = Utilities.getFileExtension( - conf.value, - fileSinkConf.getCompressed, - getOutputFormat()) - - val outputName = "part-" + numfmt.format(splitID) + extension - val path = 
FileOutputFormat.getTaskOutputPath(conf.value, outputName) - - getOutputCommitter().setupTask(getTaskContext()) - writer = HiveFileFormatUtils.getHiveRecordWriter( - conf.value, - fileSinkConf.getTableInfo, - conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], - fileSinkConf, - path, - null) - } - - def write(value: Writable) { - if (writer != null) { - writer.write(value) - } else { - throw new IOException("Writer is null, open() has not been called") - } - } - - def close() { - // Seems the boolean value passed into close does not matter. - writer.close(false) - } - - def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } else { - logWarning ("No need to commit output of task: " + taID.value) - } - } - - def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? 
- val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) - } - - // ********* Private Functions ********* - - private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[HiveOutputFormat[AnyRef,Writable]] - } - format - } - - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter - } - committer - } - - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) - } - jobContext - } - - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) - } - taskContext - } - - private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { - jobID = jobId - splitID = splitId - attemptID = attemptId - - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) - } - - private def setConfParams() { - conf.value.set("mapred.job.id", jID.value.toString) - conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) - conf.value.set("mapred.task.id", taID.value.toString) - conf.value.setBoolean("mapred.task.is.map", true) - conf.value.setInt("mapred.task.partition", splitID) - } -} - -private[hive] object SparkHiveHadoopWriter { - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 06b1446ccbd39..989a9784a438d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -144,7 +144,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val childOutputDataTypes = child.output.map(_.dataType) // Only check attributes, not partitionKeys since they are always strings. // TODO: Fully support inserting into partitioned tables. - val tableOutputDataTypes = table.attributes.map(_.dataType) + val tableOutputDataTypes = + table.attributes.map(_.dataType) ++ table.partitionKeys.map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { p diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 6bb42eeb0550d..32c9175f181bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -806,11 +806,6 @@ private[hive] object HiveQl { cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - if (partitionKeys.values.exists(p => p.isEmpty)) { - throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" + - s"dynamic partitioning.") - } - InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) case a: ASTNode => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index a284a91a91e31..16a8c782acdfa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution import 
scala.collection.JavaConversions._ -import java.util.{HashMap => JHashMap} - import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector -import org.apache.hadoop.io.Writable +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{JavaHiveDecimalObjectInspector, JavaHiveVarcharObjectInspector} import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, SparkHiveHadoopWriter} +import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} +import org.apache.spark.sql.hive._ +import org.apache.spark.{SerializableWritable, SparkException, TaskContext} /** * :: DeveloperApi :: @@ -51,7 +49,7 @@ case class InsertIntoHiveTable( child: SparkPlan, overwrite: Boolean) (@transient sc: HiveContext) - extends UnaryNode { + extends UnaryNode with Command { @transient lazy val outputClass = 
newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @@ -101,66 +99,61 @@ case class InsertIntoHiveTable( } def saveAsHiveFile( - rdd: RDD[Writable], + rdd: RDD[Row], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: JobConf, - isCompressed: Boolean) { - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - conf.setOutputValueClass(valueClass) - if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { - throw new SparkException("Output format class not set") - } - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) - if (isCompressed) { - // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", - // and "mapred.output.compression.type" have no impact on ORC because it uses table properties - // to store compression information. 
- conf.set("mapred.output.compress", "true") - fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) - } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath( - conf, - SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) + conf: SerializableWritable[JobConf], + writerContainer: SparkHiveWriterContainer) { + assert(valueClass != null, "Output value class not set") + conf.value.setOutputValueClass(valueClass) + + val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName + assert(outputFileFormatClassName != null, "Output format class not set") + conf.value.set("mapred.output.format.class", outputFileFormatClassName) + conf.value.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath( + conf.value, + SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) - writer.preSetup() + writerContainer.driverSideSetup() + sc.sparkContext.runJob(rdd, writeToFile _) + writerContainer.commitJob() + + // Note that this function is executed on executor side + def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) - def writeToFile(context: TaskContext, iter: Iterator[Writable]) { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // 
around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt + writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber) - writer.setup(context.stageId, context.partitionId, attemptNumber) - writer.open() + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record) + val writer = writerContainer.getLocalFileWriter(row) + writer.write(serializer.serialize(outputData, standardOI)) } - writer.close() - writer.commit() + writerContainer.close() } - - sc.sparkContext.runJob(rdd, writeToFile _) - writer.commitJob() } - override def execute() = result - /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the @@ -168,50 +161,69 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - private lazy val result: RDD[Row] = { - val childRdd = child.execute() - assert(childRdd != null) - + override protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. 
val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val rdd = childRdd.mapPartitions { iter => - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] + val isCompressed = sc.hiveconf.getBoolean( + ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) + if (isCompressed) { + // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", + // and "mapred.output.compression.type" have no impact on ORC because it uses table properties + // to store compression information. + sc.hiveconf.set("mapred.output.compress", "true") + fileSinkConf.setCompressed(true) + fileSinkConf.setCompressCodec(sc.hiveconf.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(sc.hiveconf.get("mapred.output.compression.type")) + } - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val outputData = new Array[Any](fieldOIs.length) - iter.map { row => - var i = 0 - while (i < row.length) { - // Casts Strings to HiveVarchars when necessary. - outputData(i) = wrap(row(i), fieldOIs(i)) - i += 1 - } + val numDynamicPartitions = partition.values.count(_.isEmpty) + val numStaticPartitions = partition.values.count(_.nonEmpty) + val partitionSpec = partition.map { + case (key, Some(value)) => key -> value + case (key, None) => key -> "" + } + + // All partition column names in the format of "//..." 
+ val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns") + val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull + + // Validate partition spec if there exist any dynamic partitions + if (numDynamicPartitions > 0) { + // Report error if dynamic partitioning is not enabled + if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) + } - serializer.serialize(outputData, standardOI) + // Report error if dynamic partition strict mode is on but no static partition is found + if (numStaticPartitions == 0 && + sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) + } + + // Report error if any static partition appears after a dynamic partition + val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) + isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } - // ORC stores compression information in table properties. While, there are other formats - // (e.g. RCFile) that rely on hadoop configurations to store compression information. val jobConf = new JobConf(sc.hiveconf) - saveAsHiveFile( - rdd, - outputClass, - fileSinkConf, - jobConf, - sc.hiveconf.getBoolean("hive.exec.compress.output", false)) - - // TODO: Handle dynamic partitioning. 
+ val jobConfSer = new SerializableWritable(jobConf) + + val writerContainer = if (numDynamicPartitions > 0) { + val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) + new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) + } else { + new SparkHiveWriterContainer(jobConf, fileSinkConf) + } + + saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) + val outputPath = FileOutputFormat.getOutputPath(jobConf) // Have to construct the format of dbname.tablename. val qualifiedTableName = s"${table.databaseName}.${table.tableName}" @@ -220,10 +232,6 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. val holdDDLTime = false if (partition.nonEmpty) { - val partitionSpec = partition.map { - case (key, Some(value)) => key -> value - case (key, None) => key -> "" // Should not reach here right now. - } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -231,14 +239,26 @@ case class InsertIntoHiveTable( val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. 
val isSkewedStoreAsSubdir = false - db.loadPartition( - outputPath, - qualifiedTableName, - partitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + if (numDynamicPartitions > 0) { + db.loadDynamicPartitions( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir + ) + } else { + db.loadPartition( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } } else { db.loadTable( outputPath, @@ -251,6 +271,6 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - sc.sparkContext.makeRDD(Nil, 1) + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala new file mode 100644 index 0000000000000..ac5c7a8220296 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.IOException +import java.text.NumberFormat +import java.util.Date + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} +import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred._ + +import org.apache.spark.sql.Row +import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} + +/** + * Internal helper class that saves an RDD using a Hive OutputFormat. + * It is based on [[SparkHadoopWriter]]. + */ +private[hive] class SparkHiveWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { + + private val now = new Date() + protected val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: FileSinkOperator.RecordWriter = null + @transient private lazy val committer = conf.value.getOutputCommitter + @transient private lazy val jobContext = newJobContext(conf.value, jID.value) + @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient private lazy val outputFormat = + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + + def driverSideSetup() { + setIDs(0, 0, 0) + setConfParams() + committer.setupJob(jobContext) + } + + def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) { + setIDs(jobId, splitId, attemptId) + setConfParams() + 
committer.setupTask(taskContext) + initWriters() + } + + protected def getOutputName: String = { + val numberFormat = NumberFormat.getInstance() + numberFormat.setMinimumIntegerDigits(5) + numberFormat.setGroupingUsed(false) + val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) + "part-" + numberFormat.format(splitID) + extension + } + + def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + + def close() { + // Seems the boolean value passed into close does not matter. + writer.close(false) + commit() + } + + def commitJob() { + committer.commitJob(jobContext) + } + + protected def initWriters() { + // NOTE this method is executed at the executor side. + // For Hive tables without partitions or with only static partitions, only 1 writer is needed. + writer = HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + fileSinkConf, + FileOutputFormat.getTaskOutputPath(conf.value, getOutputName), + Reporter.NULL) + } + + protected def commit() { + if (committer.needsTaskCommit(taskContext)) { + try { + committer.commitTask(taskContext) + logInfo (taID + ": Committed") + } catch { + case e: IOException => + logError("Error committing the output of task: " + taID.value, e) + committer.abortTask(taskContext) + throw e + } + } else { + logInfo("No need to commit output of task: " + taID.value) + } + } + + // ********* Private Functions ********* + + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { + jobID = jobId + splitID = splitId + attemptID = attemptId + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", 
taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +private[hive] object SparkHiveWriterContainer { + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } +} + +private[spark] class SparkHiveDynamicPartitionWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc, + dynamicPartColNames: Array[String]) + extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + + private val defaultPartName = jobConf.get( + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + + @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ + + override protected def initWriters(): Unit = { + // NOTE: This method is executed at the executor side. + // Actual writers are created for each dynamic partition on the fly. 
+ writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] + } + + override def close(): Unit = { + writers.values.foreach(_.close(false)) + commit() + } + + override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + val dynamicPartPath = dynamicPartColNames + .zip(row.takeRight(dynamicPartColNames.length)) + .map { case (col, rawVal) => + val string = if (rawVal == null) null else String.valueOf(rawVal) + s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}" + } + .mkString + + def newWriter = { + val newFileSinkDesc = new FileSinkDesc( + fileSinkConf.getDirName + dynamicPartPath, + fileSinkConf.getTableInfo, + fileSinkConf.getCompressed) + newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) + newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) + + val path = { + val outputPath = FileOutputFormat.getOutputPath(conf.value) + assert(outputPath != null, "Undefined job output-path") + val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) + new Path(workPath, getOutputName) + } + + HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + newFileSinkDesc, + path, + Reporter.NULL) + } + + writers.getOrElseUpdate(dynamicPartPath, newWriter) + } +} diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 b/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 
b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f b/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a b/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e b/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf b/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index f5868bff22f13..2e282a9ade40c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import 
org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.SparkException import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -380,7 +383,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.exists(_ == "== Physical Plan ==") + explanation.contains("== Physical Plan ==") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -568,6 +571,91 @@ class HiveQuerySuite extends HiveComparisonTest { case class LogEntry(filename: String, message: String) case class LogFile(name: String) + createQueryTest("dynamic_partition", + """ + |DROP TABLE IF EXISTS dynamic_part_table; + |CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT); + | + |SET hive.exec.dynamic.partition.mode=nonstrict; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, NULL FROM src WHERE key=150; + | + |INSERT INTO TABLe dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, NULL FROM src WHERE key=150; + | + |DROP TABLE IF EXISTS dynamic_part_table; + """.stripMargin) + + test("Dynamic partition folder layout") { + sql("DROP TABLE IF EXISTS dynamic_part_table") + sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + val data = Map( + Seq("1", "1") -> 1, + Seq("1", "NULL") -> 2, + Seq("NULL", "1") -> 3, + Seq("NULL", "NULL") -> 4) + + data.foreach { case (parts, value) => + sql( + s"""INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + 
|SELECT $value, ${parts.mkString(", ")} FROM src WHERE key=150 + """.stripMargin) + + val partFolder = Seq("partcol1", "partcol2") + .zip(parts) + .map { case (k, v) => + if (v == "NULL") { + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + } else { + s"$k=$v" + } + } + .mkString("/") + + // Loads partition data to a temporary table to verify contents + val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" + + sql("DROP TABLE IF EXISTS dp_verify") + sql("CREATE TABLE dp_verify(intcol INT)") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") + + assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + } + } + + test("Partition spec validation") { + sql("DROP TABLE IF EXISTS dp_test") + sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") + sql("SET hive.exec.dynamic.partition.mode=strict") + + // Should throw when using strict dynamic partition mode without any static partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + // Should throw when a static partition appears after a dynamic partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + } + test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") @@ -625,27 +713,27 @@ class HiveQuerySuite extends HiveComparisonTest { assert(sql("SET").collect().size == 0) assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + collectResults(sql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal)) { - 
collectResults(hql("SET")) + collectResults(sql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + collectResults(sql("SET")) } // "set key" assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + collectResults(sql(s"SET $testKey")) } assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + collectResults(sql(s"SET $nonexistentKey")) } // Assert that sql() should have the same effects as sql() by repeating the above using sql(). From 6a1d48f4f02c4498b64439c3dd5f671286a90e30 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 3 Oct 2014 12:34:27 -0700 Subject: [PATCH 29/36] [SPARK-3212][SQL] Use logical plan matching instead of temporary tables for table caching _Also addresses: SPARK-1671, SPARK-1379 and SPARK-3641_ This PR introduces a new trait, `CacheManger`, which replaces the previous temporary table based caching system. Instead of creating a temporary table that shadows an existing table with and equivalent cached representation, the cached manager maintains a separate list of logical plans and their cached data. After optimization, this list is searched for any matching plan fragments. When a matching plan fragment is found it is replaced with the cached data. There are several advantages to this approach: - Calling .cache() on a SchemaRDD now works as you would expect, and uses the more efficient columnar representation. - Its now possible to provide a list of temporary tables, without having to decide if a given table is actually just a cached persistent table. (To be done in a follow-up PR) - In some cases it is possible that cached data will be used, even if a cached table was not explicitly requested. This is because we now look at the logical structure instead of the table name. 
- We now correctly invalidate when data is inserted into a hive table. Author: Michael Armbrust Closes #2501 from marmbrus/caching and squashes the following commits: 63fbc2c [Michael Armbrust] Merge remote-tracking branch 'origin/master' into caching. 0ea889e [Michael Armbrust] Address comments. 1e23287 [Michael Armbrust] Add support for cache invalidation for hive inserts. 65ed04a [Michael Armbrust] fix tests. bdf9a3f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into caching b4b77f2 [Michael Armbrust] Address comments 6923c9d [Michael Armbrust] More comments / tests 80f26ac [Michael Armbrust] First draft of improved semantics for Spark SQL caching. --- .../sql/catalyst/analysis/Analyzer.scala | 3 + .../expressions/namedExpressions.scala | 4 +- .../catalyst/plans/logical/LogicalPlan.scala | 42 ++++++ .../catalyst/plans/logical/TestRelation.scala | 6 + .../plans/logical/basicOperators.scala | 4 +- .../sql/catalyst/plans/SameResultSuite.scala | 62 ++++++++ .../org/apache/spark/sql/CacheManager.scala | 139 ++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 51 +------ .../org/apache/spark/sql/SchemaRDD.scala | 23 ++- .../org/apache/spark/sql/SchemaRDDLike.scala | 5 +- .../spark/sql/api/java/JavaSQLContext.scala | 10 +- .../columnar/InMemoryColumnarTableScan.scala | 28 +++- .../spark/sql/execution/ExistingRDD.scala | 119 +++++++++++++++ .../spark/sql/execution/SparkPlan.scala | 33 ----- .../spark/sql/execution/SparkStrategies.scala | 9 +- .../spark/sql/execution/basicOperators.scala | 39 ----- .../apache/spark/sql/CachedTableSuite.scala | 103 +++++++------ .../columnar/InMemoryColumnarQuerySuite.scala | 7 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 7 +- .../spark/sql/hive/HiveStrategies.scala | 6 +- .../org/apache/spark/sql/hive/TestHive.scala | 5 +- .../hive/execution/InsertIntoHiveTable.scala | 3 + .../spark/sql/hive/CachedTableSuite.scala | 100 ++++++++----- 23 files changed, 567 insertions(+), 241 deletions(-) create mode 
100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 71810b798bd04..fe83eb12502dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -93,6 +93,9 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool */ object ResolveRelations extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case i @ InsertIntoTable(UnresolvedRelation(databaseName, name, alias), _, _, _) => + i.copy( + table = EliminateAnalysisOperators(catalog.lookupRelation(databaseName, name, alias))) case UnresolvedRelation(databaseName, name, alias) => catalog.lookupRelation(databaseName, name, alias) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 59fb0311a9c44..e5a958d599393 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -62,7 +62,7 @@ abstract class Attribute extends NamedExpression { def withName(newName: String): Attribute def toAttribute = this - def newInstance: Attribute + def newInstance(): Attribute } @@ -131,7 +131,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea h } - override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = 
qualifiers) + override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) /** * Returns a copy of this [[AttributeReference]] with changed nullability. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 28d863e58beca..4f8ad8a7e0223 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.catalyst.trees @@ -72,6 +73,47 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def childrenResolved: Boolean = !children.exists(!_.resolved) + /** + * Returns true when the given logical plan will return the same results as this logical plan. + * + * Since it's likely undecidable to generally determine if two given plans will produce the same + * results, it is okay for this function to return false, even if the results are actually + * the same. Such behavior will not affect correctness, only the application of performance + * enhancements like caching. However, it is not acceptable to return true if the results could + * possibly be different. + * + * By default this function performs a modified version of equality that is tolerant of cosmetic + * differences like attribute naming and/or expression id differences. Logical operators that + * can do better should override this function. 
+ */ + def sameResult(plan: LogicalPlan): Boolean = { + plan.getClass == this.getClass && + plan.children.size == children.size && { + logDebug(s"[${cleanArgs.mkString(", ")}] == [${plan.cleanArgs.mkString(", ")}]") + cleanArgs == plan.cleanArgs + } && + (plan.children, children).zipped.forall(_ sameResult _) + } + + /** Args that have been cleaned such that differences in expression id should not affect equality */ + protected lazy val cleanArgs: Seq[Any] = { + val input = children.flatMap(_.output) + productIterator.map { + // Children are checked using sameResult above. + case tn: TreeNode[_] if children contains tn => null + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case s: Option[_] => s.map { + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case other => other + } + case s: Seq[_] => s.map { + case e: Expression => BindReferences.bindReference(e, input, allowFailures = true) + case other => other + } + case other => other + }.toSeq + } + /** * Optionally resolves the given string to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. 
The attribute is expressed as diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala index f8fe558511bfd..19769986ef58c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala @@ -41,4 +41,10 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) } override protected def stringArgs = Iterator(output) + + override def sameResult(plan: LogicalPlan): Boolean = plan match { + case LocalRelation(otherOutput, otherData) => + otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 391508279bb80..f8e9930ac270d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -105,8 +105,8 @@ case class InsertIntoTable( child: LogicalPlan, overwrite: Boolean) extends LogicalPlan { - // The table being inserted into is a child for the purposes of transformations. 
- override def children = table :: child :: Nil + + override def children = child :: Nil override def output = child.output override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala new file mode 100644 index 0000000000000..e8a793d107451 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.util._ + +/** + * Provides helper methods for comparing plans. 
+ */ +class SameResultSuite extends FunSuite { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) + + def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true) = { + val aAnalyzed = a.analyze + val bAnalyzed = b.analyze + + if (aAnalyzed.sameResult(bAnalyzed) != result) { + val comparison = sideBySide(aAnalyzed.toString, bAnalyzed.toString).mkString("\n") + fail(s"Plans should return sameResult = $result\n$comparison") + } + } + + test("relations") { + assertSameResult(testRelation, testRelation2) + } + + test("projections") { + assertSameResult(testRelation.select('a), testRelation2.select('a)) + assertSameResult(testRelation.select('b), testRelation2.select('b)) + assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b)) + assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a)) + + assertSameResult(testRelation, testRelation2.select('a), false) + assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), false) + } + + test("filters") { + assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala new file mode 100644 index 0000000000000..aebdbb68e49b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.locks.ReentrantReadWriteLock + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StorageLevel.MEMORY_ONLY + +/** Holds a cached logical plan and its data */ +private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) + +/** + * Provides support in a SQLContext for caching query results and automatically using these cached + * results when subsequent queries are executed. Data is cached using byte buffers stored in an + * InMemoryRelation. This relation is automatically substituted query plans that return the + * `sameResult` as the originally cached query. + */ +private[sql] trait CacheManager { + self: SQLContext => + + @transient + private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] + + @transient + private val cacheLock = new ReentrantReadWriteLock + + /** Returns true if the table is currently cached in-memory. */ + def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = cacheQuery(table(tableName)) + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName)) + + /** Acquires a read lock on the cache for the duration of `f`. 
*/ + private def readLock[A](f: => A): A = { + val lock = cacheLock.readLock() + lock.lock() + try f finally { + lock.unlock() + } + } + + /** Acquires a write lock on the cache for the duration of `f`. */ + private def writeLock[A](f: => A): A = { + val lock = cacheLock.writeLock() + lock.lock() + try f finally { + lock.unlock() + } + } + + private[sql] def clearCache(): Unit = writeLock { + cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.clear() + } + + /** Caches the data produced by the logical representation of the given schema rdd. */ + private[sql] def cacheQuery( + query: SchemaRDD, + storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock { + val planToCache = query.queryExecution.optimizedPlan + if (lookupCachedData(planToCache).nonEmpty) { + logWarning("Asked to cache already cached data.") + } else { + cachedData += + CachedData( + planToCache, + InMemoryRelation( + useCompression, columnBatchSize, storageLevel, query.queryExecution.executedPlan)) + } + } + + /** Removes the data for the given SchemaRDD from the cache */ + private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = false): Unit = writeLock { + val planToCache = query.queryExecution.optimizedPlan + val dataIndex = cachedData.indexWhere(_.plan.sameResult(planToCache)) + + if (dataIndex < 0) { + throw new IllegalArgumentException(s"Table $query is not cached.") + } + + cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData.remove(dataIndex) + } + + + /** Optionally returns cached data for the given SchemaRDD */ + private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { + lookupCachedData(query.queryExecution.optimizedPlan) + } + + /** Optionally returns cached data for the given LogicalPlan. 
*/ + private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { + cachedData.find(_.plan.sameResult(plan)) + } + + /** Replaces segments of the given logical plan with cached versions where possible. */ + private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = { + plan transformDown { + case currentFragment => + lookupCachedData(currentFragment) + .map(_.cachedRepresentation.withOutput(currentFragment.output)) + .getOrElse(currentFragment) + } + } + + /** + * Invalidates the cache of any data that contains `plan`. Note that it is possible that this + * function will over invalidate. + */ + private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock { + cachedData.foreach { + case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => + data.cachedRepresentation.recache() + case _ => + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a42bedbe6c04e..7a55c5bf97a71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -50,6 +50,7 @@ import org.apache.spark.{Logging, SparkContext} class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging with SQLConf + with CacheManager with ExpressionConversions with UDFRegistration with Serializable { @@ -96,7 +97,8 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))(self)) + new SchemaRDD(this, + LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) } /** @@ -133,7 +135,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { // TODO: use 
MutableProjection when rowRDD is another SchemaRDD and the applied // schema differs from the existing schema on any field data type. - val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD))(self) + val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) new SchemaRDD(this, logicalPlan) } @@ -272,45 +274,6 @@ class SQLContext(@transient val sparkContext: SparkContext) def table(tableName: String): SchemaRDD = new SchemaRDD(this, catalog.lookupRelation(None, tableName)) - /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = { - val currentTable = table(tableName).queryExecution.analyzed - val asInMemoryRelation = currentTable match { - case _: InMemoryRelation => - currentTable - - case _ => - InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) - } - - catalog.registerTable(None, tableName, asInMemoryRelation) - } - - /** Removes the specified table from the in-memory cache. */ - def uncacheTable(tableName: String): Unit = { - table(tableName).queryExecution.analyzed match { - // This is kind of a hack to make sure that if this was just an RDD registered as a table, - // we reregister the RDD as a table. - case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) => - inMem.cachedColumnBuffers.unpersist() - catalog.unregisterTable(None, tableName) - catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) - case inMem: InMemoryRelation => - inMem.cachedColumnBuffers.unpersist() - catalog.unregisterTable(None, tableName) - case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") - } - } - - /** Returns true if the table is currently cached in-memory. 
*/ - def isCached(tableName: String): Boolean = { - val relation = table(tableName).queryExecution.analyzed - relation match { - case _: InMemoryRelation => true - case _ => false - } - } - protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -401,10 +364,12 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) lazy val optimizedPlan = optimizer(analyzed) + lazy val withCachedData = useCachedData(optimizedPlan) + // TODO: Don't just pick the first one... lazy val sparkPlan = { SparkPlan.currentContext.set(self) - planner(optimizedPlan).next() + planner(withCachedData).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. @@ -526,6 +491,6 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self)) + new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 3b873f7c62cb6..594bf8ffc20e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.{Map => JMap, List => JList} +import org.apache.spark.storage.StorageLevel + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -32,7 +34,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.execution.LogicalRDD import 
org.apache.spark.api.java.JavaRDD /** @@ -442,8 +444,7 @@ class SchemaRDD( */ private def applySchema(rdd: RDD[Row]): SchemaRDD = { new SchemaRDD(sqlContext, - SparkLogicalPlan( - ExistingRdd(queryExecution.analyzed.output.map(_.newInstance), rdd))(sqlContext)) + LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext)) } // ======================================================================= @@ -497,4 +498,20 @@ class SchemaRDD( override def subtract(other: RDD[Row], p: Partitioner) (implicit ord: Ordering[Row] = null): SchemaRDD = applySchema(super.subtract(other, p)(ord)) + + /** Overridden cache function will always use the in-memory columnar caching. */ + override def cache(): this.type = { + sqlContext.cacheQuery(this) + this + } + + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheQuery(this, newLevel) + this + } + + override def unpersist(blocking: Boolean): this.type = { + sqlContext.uncacheQuery(this, blocking) + this + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index e52eeb3e1c47e..25ba7d88ba538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.execution.LogicalRDD /** * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) @@ -55,8 +55,7 @@ private[sql] trait SchemaRDDLike { // For various commands (like DDL) and queries with side effects, we force query optimization to // happen right away to let these side effects take place eagerly. 
case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile => - queryExecution.toRdd - SparkLogicalPlan(queryExecution.executedPlan)(sqlContext) + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => baseLogicalPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 150ff8a42063d..c006c4330ff66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils @@ -100,7 +100,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow } } - new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))(sqlContext)) + new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext)) } /** @@ -114,7 +114,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { val scalaRowRDD = rowRDD.rdd.map(r => r.row) val scalaSchema = asScalaDataType(schema).asInstanceOf[SStructType] val logicalPlan = - SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(scalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } @@ -151,7 +151,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends 
UDFRegistration { val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = - SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } @@ -167,7 +167,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType] val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = - SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD))(sqlContext) + LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext) new JavaSchemaRDD(sqlContext, logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 8a3612cdf19be..cec82a7f2df94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -27,10 +27,15 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{LeafNode, SparkPlan} +import org.apache.spark.storage.StorageLevel private[sql] object InMemoryRelation { - def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, child)() + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, 
storageLevel, child)() } private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) @@ -39,6 +44,7 @@ private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, batchSize: Int, + storageLevel: StorageLevel, child: SparkPlan) (private var _cachedColumnBuffers: RDD[CachedBatch] = null) extends LogicalPlan with MultiInstanceRelation { @@ -51,6 +57,16 @@ private[sql] case class InMemoryRelation( // If the cached column buffers were not passed in, we calculate them in the constructor. // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { + buildBuffers() + } + + def recache() = { + _cachedColumnBuffers.unpersist() + _cachedColumnBuffers = null + buildBuffers() + } + + private def buildBuffers(): Unit = { val output = child.output val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { @@ -80,12 +96,17 @@ private[sql] case class InMemoryRelation( def hasNext = rowIterator.hasNext } - }.cache() + }.persist(storageLevel) cached.setName(child.toString) _cachedColumnBuffers = cached } + def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { + InMemoryRelation( + newOutput, useCompression, batchSize, storageLevel, child)(_cachedColumnBuffers) + } + override def children = Seq.empty override def newInstance() = { @@ -93,6 +114,7 @@ private[sql] case class InMemoryRelation( output.map(_.newInstance), useCompression, batchSize, + storageLevel, child)( _cachedColumnBuffers).asInstanceOf[this.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala new file mode 100644 index 0000000000000..2ddf513b6fc98 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +object RDDConversions { + def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) + + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) + i += 1 + } + + mutableRow + } + } + } + } + + /* + def toLogicalPlan[A <: Product : TypeTag](productRdd: RDD[A]): LogicalPlan = { + LogicalRDD(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) + } + */ +} + +case class LogicalRDD(output: Seq[Attribute], rdd: RDD[Row])(sqlContext: SQLContext) 
+ extends LogicalPlan with MultiInstanceRelation { + + def children = Nil + + def newInstance() = + LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] + + override def sameResult(plan: LogicalPlan) = plan match { + case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. + sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) +} + +case class PhysicalRDD(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { + override def execute() = rdd +} + +@deprecated("Use LogicalRDD", "1.2.0") +case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { + override def execute() = rdd +} + +@deprecated("Use LogicalRDD", "1.2.0") +case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) + extends LogicalPlan with MultiInstanceRelation { + + def output = alreadyPlanned.output + override def children = Nil + + override final def newInstance(): this.type = { + SparkLogicalPlan( + alreadyPlanned match { + case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) + case _ => sys.error("Multiple instance of the same relation detected.") + })(sqlContext).asInstanceOf[this.type] + } + + override def sameResult(plan: LogicalPlan) = plan match { + case SparkLogicalPlan(ExistingRdd(_, rdd)) => + rdd.id == alreadyPlanned.asInstanceOf[ExistingRdd].rdd.id + case _ => false + } + + @transient override lazy val statistics = Statistics( + // TODO: Instead of returning a default value here, find a way to return a meaningful size + // estimate for RDDs. See PR 1238 for more discussions. 
+ sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2b8913985b028..b1a7948b66cb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -126,39 +126,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } -/** - * :: DeveloperApi :: - * Allows already planned SparkQueries to be linked into logical query plans. - * - * Note that in general it is not valid to use this class to link multiple copies of the same - * physical operator into the same query plan as this violates the uniqueness of expression ids. - * Special handling exists for ExistingRdd as these are already leaf operators and thus we can just - * replace the output attributes with new copies of themselves without breaking any attribute - * linking. - */ -@DeveloperApi -case class SparkLogicalPlan(alreadyPlanned: SparkPlan)(@transient sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { - - def output = alreadyPlanned.output - override def children = Nil - - override final def newInstance(): this.type = { - SparkLogicalPlan( - alreadyPlanned match { - case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) - case _ => sys.error("Multiple instance of the same relation detected.") - })(sqlContext).asInstanceOf[this.type] - } - - @transient override lazy val statistics = Statistics( - // TODO: Instead of returning a default value here, find a way to return a meaningful size - // estimate for RDDs. See PR 1238 for more discussions. 
- sizeInBytes = BigInt(sqlContext.defaultSizeInBytes) - ) - -} - private[sql] trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { self: Product => } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 45687d960404c..cf93d5ad7b503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -272,10 +272,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil + case SparkLogicalPlan(alreadyPlanned) => alreadyPlanned :: Nil case logical.LocalRelation(output, data) => - ExistingRdd( + PhysicalRDD( output, - ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => @@ -287,12 +288,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Generate(generator, join, outer, _, child) => execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil case logical.NoRelation => - execution.ExistingRdd(Nil, singleRowRdd) :: Nil + execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil case e @ EvaluatePython(udf, child) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil - case SparkLogicalPlan(existingPlan) => existingPlan :: Nil + case LogicalRDD(output, 
rdd) => PhysicalRDD(output, rdd) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index cac376608be29..977f3c9f32096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -210,45 +210,6 @@ case class Sort( override def output = child.output } -/** - * :: DeveloperApi :: - */ -@DeveloperApi -object ExistingRdd { - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { - data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) - - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) - i += 1 - } - - mutableRow - } - } - } - } - - def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = { - ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) - } -} - -/** - * :: DeveloperApi :: - */ -@DeveloperApi -case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - override def execute() = rdd -} - /** * :: DeveloperApi :: * Computes the set of distinct input rows using a HashSet. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 591592841e9fe..957388e99bd85 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -20,13 +20,30 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ case class BigData(s: String) class CachedTableSuite extends QueryTest { + import TestSQLContext._ TestData // Load test tables. + /** + * Throws a test failed exception when the number of cached tables differs from the expected + * number. + */ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + if (cachedData.size != numCachedTables) { + fail( + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } + test("too big for memory") { val data = "*" * 10000 sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") @@ -35,19 +52,21 @@ class CachedTableSuite extends QueryTest { uncacheTable("bigData") } + test("calling .cache() should use inmemory columnar caching") { + table("testData").cache() + + assertCached(table("testData")) + } + test("SPARK-1669: cacheTable should be idempotent") { assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) cacheTable("testData") - table("testData").queryExecution.analyzed match { - case _: InMemoryRelation => - case _ => - fail("testData should be cached") - } + assertCached(table("testData")) cacheTable("testData") 
table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) => + case InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan) => fail("cacheTable is not idempotent") case _ => @@ -55,81 +74,69 @@ class CachedTableSuite extends QueryTest { } test("read from cached table and uncache") { - TestSQLContext.cacheTable("testData") + cacheTable("testData") checkAnswer( - TestSQLContext.table("testData"), + table("testData"), testData.collect().toSeq ) - TestSQLContext.table("testData").queryExecution.analyzed match { - case _ : InMemoryRelation => // Found evidence of caching - case noCache => fail(s"No cache node found in plan $noCache") - } + assertCached(table("testData")) - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") checkAnswer( - TestSQLContext.table("testData"), + table("testData"), testData.collect().toSeq ) - TestSQLContext.table("testData").queryExecution.analyzed match { - case cachePlan: InMemoryRelation => - fail(s"Table still cached after uncache: $cachePlan") - case noCache => // Table uncached successfully - } + assertCached(table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") } } test("SELECT Star Cached Table") { - TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar") - TestSQLContext.cacheTable("selectStar") - TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() - TestSQLContext.uncacheTable("selectStar") + sql("SELECT * FROM testData").registerTempTable("selectStar") + cacheTable("selectStar") + sql("SELECT * FROM selectStar WHERE key = 1").collect() + uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = - TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - TestSQLContext.cacheTable("testData") + sql("SELECT * FROM testData a JOIN testData b 
ON a.key = b.key").collect() + cacheTable("testData") checkAnswer( - TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), + sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - TestSQLContext.uncacheTable("testData") + uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { - TestSQLContext.sql("CACHE TABLE testData") - TestSQLContext.table("testData").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => // Found evidence of caching - case _ => fail(s"Table 'testData' should be cached") - } - assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached") + sql("CACHE TABLE testData") + assertCached(table("testData")) - TestSQLContext.sql("UNCACHE TABLE testData") - TestSQLContext.table("testData").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached") - case _ => // Found evidence of uncaching - } - assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached") + assert(isCached("testData"), "Table 'testData' should be cached") + + sql("UNCACHE TABLE testData") + assertCached(table("testData"), 0) + assert(!isCached("testData"), "Table 'testData' should not be cached") } test("CACHE TABLE tableName AS SELECT Star Table") { - TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect() - assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestSQLContext.uncacheTable("testCacheTable") + sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + sql("SELECT * FROM testCacheTable WHERE key = 1").collect() + assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + uncacheTable("testCacheTable") } test("'CACHE TABLE tableName AS SELECT ..'") { - TestSQLContext.sql("CACHE TABLE testCacheTable AS 
SELECT * FROM testData") - assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestSQLContext.uncacheTable("testCacheTable") + sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + assert(isCached("testCacheTable"), "Table 'testCacheTable' should be cached") + uncacheTable("testCacheTable") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index c1278248ef655..9775dd26b7773 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.{QueryTest, TestData} +import org.apache.spark.storage.StorageLevel.MEMORY_ONLY class InMemoryColumnarQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ @@ -27,7 +28,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) } @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -51,7 +52,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: 
in-memory columns must be able to be accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, 5, plan) + val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 989a9784a438d..cc0605b0adb35 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -133,11 +133,6 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => castChildOutput(p, table, child) - - case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, _, - HiveTableScan(_, table, _)), _, child, _) => - castChildOutput(p, table, child) } def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) = { @@ -306,7 +301,7 @@ private[hive] case class MetastoreRelation HiveMetastoreTypes.toDataType(f.getType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true - )(qualifiers = tableName +: alias.toSeq) + )(qualifiers = Seq(alias.getOrElse(tableName))) } // Must be a stable value since new attributes are born here. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 8ac17f37201a8..508d8239c7628 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.types.StringType -import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} import org.apache.spark.sql.hive import org.apache.spark.sql.hive.execution._ @@ -161,10 +160,7 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil - case logical.InsertIntoTable( - InMemoryRelation(_, _, _, - HiveTableScan(_, table, _)), partition, child, overwrite) => - InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil + case logical.CreateTableAsSelect(database, tableName, child) => val query = planLater(child) CreateTableAsSelect( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 4a999b98ad92b..c0e69393cc2e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -353,7 +353,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { var cacheTables: Boolean = false def loadTestTable(name: String) { if (!(loadedTables contains name)) { - // Marks the table as loaded first to prevent infite mutually 
recursive table loading. + // Marks the table as loaded first to prevent infinite mutually recursive table loading. loadedTables += name logInfo(s"Loading test table $name") val createCmds = @@ -383,6 +383,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } + clearCache() loadedTables.clear() catalog.client.getAllTables("default").foreach { t => logDebug(s"Deleting table $t") @@ -428,7 +429,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - logError(s"FATAL ERROR: Failed to reset TestDB state. $e") + logError("FATAL ERROR: Failed to reset TestDB state.", e) // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 16a8c782acdfa..f8b4e898ec41d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -267,6 +267,9 @@ case class InsertIntoHiveTable( holdDDLTime) } + // Invalidate the cache. + sqlContext.invalidateCache(table) + // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index b3057cd618c66..158cfb5bbee7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,22 +17,60 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.{QueryTest, SchemaRDD} import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} -import org.apache.spark.sql.hive.execution.HiveComparisonTest import org.apache.spark.sql.hive.test.TestHive -class CachedTableSuite extends HiveComparisonTest { +class CachedTableSuite extends QueryTest { import TestHive._ - TestHive.loadTestTable("src") + /** + * Throws a test failed exception when the number of cached tables differs from the expected + * number. + */ + def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + val planWithCaching = query.queryExecution.withCachedData + val cachedData = planWithCaching collect { + case cached: InMemoryRelation => cached + } + + if (cachedData.size != numCachedTables) { + fail( + s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + + planWithCaching) + } + } test("cache table") { - TestHive.cacheTable("src") + val preCacheResults = sql("SELECT * FROM src").collect().toSeq + + cacheTable("src") + assertCached(sql("SELECT * FROM src")) + + checkAnswer( + sql("SELECT * FROM src"), + preCacheResults) + + uncacheTable("src") + assertCached(sql("SELECT * FROM src"), 0) } - createQueryTest("read from cached table", - "SELECT * FROM src LIMIT 1", reset = false) + test("cache invalidation") { + sql("CREATE TABLE cachedTable(key INT, value STRING)") + + sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + checkAnswer(sql("SELECT * FROM cachedTable"), 
table("src").collect().toSeq) + + cacheTable("cachedTable") + checkAnswer(sql("SELECT * FROM cachedTable"), table("src").collect().toSeq) + + sql("INSERT INTO TABLE cachedTable SELECT * FROM src") + checkAnswer( + sql("SELECT * FROM cachedTable"), + table("src").collect().toSeq ++ table("src").collect().toSeq) + + sql("DROP TABLE cachedTable") + } test("Drop cached table") { sql("CREATE TABLE test(a INT)") @@ -48,25 +86,6 @@ class CachedTableSuite extends HiveComparisonTest { sql("DROP TABLE IF EXISTS nonexistantTable") } - test("check that table is cached and uncache") { - TestHive.table("src").queryExecution.analyzed match { - case _ : InMemoryRelation => // Found evidence of caching - case noCache => fail(s"No cache node found in plan $noCache") - } - TestHive.uncacheTable("src") - } - - createQueryTest("read from uncached table", - "SELECT * FROM src LIMIT 1", reset = false) - - test("make sure table is uncached") { - TestHive.table("src").queryExecution.analyzed match { - case cachePlan: InMemoryRelation => - fail(s"Table still cached after uncache: $cachePlan") - case noCache => // Table uncached successfully - } - } - test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { TestHive.uncacheTable("src") @@ -75,23 +94,24 @@ class CachedTableSuite extends HiveComparisonTest { test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { TestHive.sql("CACHE TABLE src") - TestHive.table("src").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => // Found evidence of caching - case _ => fail(s"Table 'src' should be cached") - } + assertCached(table("src")) assert(TestHive.isCached("src"), "Table 'src' should be cached") TestHive.sql("UNCACHE TABLE src") - TestHive.table("src").queryExecution.executedPlan match { - case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached") - case _ => // Found evidence of uncaching - } + assertCached(table("src"), 0) assert(!TestHive.isCached("src"), 
"Table 'src' should not be cached") } - - test("'CACHE TABLE tableName AS SELECT ..'") { - TestHive.sql("CACHE TABLE testCacheTable AS SELECT * FROM src") - assert(TestHive.isCached("testCacheTable"), "Table 'testCacheTable' should be cached") - TestHive.uncacheTable("testCacheTable") - } + + test("CACHE TABLE AS SELECT") { + assertCached(sql("SELECT * FROM src"), 0) + sql("CACHE TABLE test AS SELECT key FROM src") + + checkAnswer( + sql("SELECT * FROM test"), + sql("SELECT key FROM src").collect().toSeq) + + assertCached(sql("SELECT * FROM test")) + + assertCached(sql("SELECT * FROM test JOIN test"), 2) + } } From a8c52d5343e19731909e73db5de151a324d31cd5 Mon Sep 17 00:00:00 2001 From: Brenden Matthews Date: Fri, 3 Oct 2014 12:58:04 -0700 Subject: [PATCH 30/36] [SPARK-3535][Mesos] Fix resource handling. Author: Brenden Matthews Closes #2401 from brndnmtthws/master and squashes the following commits: 4abaa5d [Brenden Matthews] [SPARK-3535][Mesos] Fix resource handling. --- .../mesos/CoarseMesosSchedulerBackend.scala | 7 ++-- .../scheduler/cluster/mesos/MemoryUtils.scala | 35 +++++++++++++++++++ .../cluster/mesos/MesosSchedulerBackend.scala | 34 ++++++++++++++---- docs/configuration.md | 11 ++++++ 4 files changed, 79 insertions(+), 8 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 64568409dbafd..3161f1ee9fa8a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -198,7 +198,9 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveId = offer.getSlaveId.toString val mem = getResource(offer.getResourcesList, "mem") val cpus = 
getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && mem >= sc.executorMemory && cpus >= 1 && + if (totalCoresAcquired < maxCores && + mem >= MemoryUtils.calculateTotalMemory(sc) && + cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { // Launch an executor on the slave @@ -214,7 +216,8 @@ private[spark] class CoarseMesosSchedulerBackend( .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", sc.executorMemory)) + .addResources(createResource("mem", + MemoryUtils.calculateTotalMemory(sc))) .build() d.launchTasks( Collections.singleton(offer.getId), Collections.singletonList(task), filters) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala new file mode 100644 index 0000000000000..5101ec8352e79 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MemoryUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.spark.SparkContext + +private[spark] object MemoryUtils { + // These defaults copied from YARN + val OVERHEAD_FRACTION = 1.07 + val OVERHEAD_MINIMUM = 384 + + def calculateTotalMemory(sc: SparkContext) = { + math.max( + sc.conf.getOption("spark.mesos.executor.memoryOverhead") + .getOrElse(OVERHEAD_MINIMUM.toString) + .toInt + sc.executorMemory, + OVERHEAD_FRACTION * sc.executorMemory + ) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index a9ef126f5de0e..4c49aa074ebc0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -124,15 +124,24 @@ private[spark] class MesosSchedulerBackend( command.setValue("cd %s*; ./sbin/spark-executor".format(basename)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } + val cpus = Resource.newBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder() + .setValue(scheduler.CPUS_PER_TASK).build()) + .build() val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(sc.executorMemory).build()) + .setScalar( + Value.Scalar.newBuilder() + .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) .build() ExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) + .addResources(cpus) .addResources(memory) .build() } @@ -204,18 +213,31 @@ private[spark] class MesosSchedulerBackend( val offerableWorkers = new ArrayBuffer[WorkerOffer] val offerableIndices = new HashMap[String, Int] - def enoughMemory(o: Offer) = { + def sufficientOffer(o: Offer) = { val mem = 
getResource(o.getResourcesList, "mem") + val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - mem >= sc.executorMemory || slaveIdsWithExecutors.contains(slaveId) + (mem >= MemoryUtils.calculateTotalMemory(sc) && + // need at least 1 for executor, 1 for task + cpus >= 2 * scheduler.CPUS_PER_TASK) || + (slaveIdsWithExecutors.contains(slaveId) && + cpus >= scheduler.CPUS_PER_TASK) } - for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { - offerableIndices.put(offer.getSlaveId.getValue, index) + for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) { + val slaveId = offer.getSlaveId.getValue + offerableIndices.put(slaveId, index) + val cpus = if (slaveIdsWithExecutors.contains(slaveId)) { + getResource(offer.getResourcesList, "cpus").toInt + } else { + // If the executor doesn't exist yet, subtract CPU for executor + getResource(offer.getResourcesList, "cpus").toInt - + scheduler.CPUS_PER_TASK + } offerableWorkers += new WorkerOffer( offer.getSlaveId.getValue, offer.getHostname, - getResource(offer.getResourcesList, "cpus").toInt) + cpus) } // Call into the TaskSchedulerImpl diff --git a/docs/configuration.md b/docs/configuration.md index a782809a55ec0..1c33855365170 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -253,6 +253,17 @@ Apart from these, the following properties are also available, and may be useful spark.executor.uri. + + spark.mesos.executor.memoryOverhead + executor memory * 0.07, with minimum of 384 + + This value is an additive for spark.executor.memory, specified in MiB, + which is used to calculate the total Mesos task memory. A value of 384 + implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum + overhead. The final overhead will be the larger of either + `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`. 
+ + #### Shuffle Behavior From 358d7ffd01b4a3fbae313890522cf662c71af6e5 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Fri, 3 Oct 2014 13:09:48 -0700 Subject: [PATCH 31/36] [SPARK-3775] Not suitable error message in spark-shell.cmd Modified some sentence of error message in bin\*.cmd. Author: Masayoshi TSUZUKI Closes #2640 from tsudukim/feature/SPARK-3775 and squashes the following commits: 3458afb [Masayoshi TSUZUKI] [SPARK-3775] Not suitable error message in spark-shell.cmd --- bin/pyspark2.cmd | 2 +- bin/run-example2.cmd | 2 +- bin/spark-class | 2 +- bin/spark-class2.cmd | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 2c4b08af8d4c3..a0e66abcc26c9 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -33,7 +33,7 @@ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop* ) if [%FOUND_JAR%] == [0] ( echo Failed to find Spark assembly JAR. - echo You need to build Spark with sbt\sbt assembly before running this program. + echo You need to build Spark before running this program. goto exit ) :skip_build_test diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index b29bf90c64e90..b49d0dcb4ff2d 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -52,7 +52,7 @@ if exist "%FWDIR%RELEASE" ( ) if "x%SPARK_EXAMPLES_JAR%"=="x" ( echo Failed to find Spark examples assembly JAR. - echo You need to build Spark with sbt\sbt assembly before running this program. + echo You need to build Spark before running this program. goto exit ) diff --git a/bin/spark-class b/bin/spark-class index 613dc9c4566f2..e8201c18d52de 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -146,7 +146,7 @@ fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then if test -z "$SPARK_TOOLS_JAR"; then echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2 - echo "You need to build spark before running $1." 1>&2 + echo "You need to build Spark before running $1." 
1>&2 exit 1 fi CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 6c5672819172b..da46543647efd 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -104,7 +104,7 @@ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop* ) if "%FOUND_JAR%"=="0" ( echo Failed to find Spark assembly JAR. - echo You need to build Spark with sbt\sbt assembly before running this program. + echo You need to build Spark before running this program. goto exit ) :skip_build_test From e5566e05b1ac99aa6caf1701e47ebcdb68a002c6 Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Fri, 3 Oct 2014 13:12:37 -0700 Subject: [PATCH 32/36] [SPARK-3774] typo comment in bin/utils.sh Modified the comment of bin/utils.sh. Author: Masayoshi TSUZUKI Closes #2639 from tsudukim/feature/SPARK-3774 and squashes the following commits: 707b779 [Masayoshi TSUZUKI] [SPARK-3774] typo comment in bin/utils.sh --- bin/utils.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/utils.sh b/bin/utils.sh index 0804b1ed9f231..22ea2b9a6d586 100755 --- a/bin/utils.sh +++ b/bin/utils.sh @@ -17,7 +17,7 @@ # limitations under the License. # -# Gather all all spark-submit options into SUBMISSION_OPTS +# Gather all spark-submit options into SUBMISSION_OPTS function gatherSparkSubmitOpts() { if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then From 30abef154768e5c4c6062f3341933dbda990f6cc Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 3 Oct 2014 13:18:35 -0700 Subject: [PATCH 33/36] [SPARK-3606] [yarn] Correctly configure AmIpFilter for Yarn HA. The existing code only considered one of the RMs when running in Yarn HA mode, so it was possible to get errors if the active RM was not registered in the filter. The change makes use of a new API added to Yarn that returns all proxy addresses, and falls back to the old behavior if the API is not present. 
While there, I also made a change to look for the scheme (http or https) being used by Yarn when building the proxy URIs. Since, in the case of multiple RMs, Yarn uses commas as a separator, it was not possible anymore to use spark.filter.params to propagate this information (which used commas to delimit different config params). Instead, I added a new param (spark.filter.jsonParams) which expects a JSON string containing a map with the config data. I chose not to add it to the documentation at this point since I don't believe users will use it directly. Author: Marcelo Vanzin Closes #2469 from vanzin/SPARK-3606 and squashes the following commits: aeb458a [Marcelo Vanzin] Undelete needed import. 65e400d [Marcelo Vanzin] Remove unused import. d121883 [Marcelo Vanzin] Use separate config for each param instead of json. 04bc156 [Marcelo Vanzin] Review feedback. 4d4d6b9 [Marcelo Vanzin] [SPARK-3606] [yarn] Correctly configure AmIpFilter for Yarn HA. --- .../cluster/CoarseGrainedClusterMessage.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 8 ++++-- .../org/apache/spark/ui/JettyUtils.scala | 14 ++++++---- .../spark/deploy/yarn/YarnRMClientImpl.scala | 8 ++++-- .../spark/deploy/yarn/ApplicationMaster.scala | 12 +++----- .../spark/deploy/yarn/YarnRMClient.scala | 4 +-- .../spark/deploy/yarn/YarnRMClientImpl.scala | 28 ++++++++++++++++++- 7 files changed, 53 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 6abf6d930c155..fb8160abc59db 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -66,7 +66,7 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends 
CoarseGrainedClusterMessage - case class AddWebUIFilter(filterName:String, filterParams: String, proxyBase :String) + case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase :String) extends CoarseGrainedClusterMessage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 89089e7d6f8a8..59aed6b72fe42 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -275,15 +275,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } // Add filters to the SparkUI - def addWebUIFilter(filterName: String, filterParams: String, proxyBase: String) { + def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) { if (proxyBase != null && proxyBase.nonEmpty) { System.setProperty("spark.ui.proxyBase", proxyBase) } - if (Seq(filterName, filterParams).forall(t => t != null && t.nonEmpty)) { + val hasFilter = (filterName != null && filterName.nonEmpty && + filterParams != null && filterParams.nonEmpty) + if (hasFilter) { logInfo(s"Add WebUI Filter. 
$filterName, $filterParams, $proxyBase") conf.set("spark.ui.filters", filterName) - conf.set(s"spark.$filterName.params", filterParams) + filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 6b4689291097f..2a27d49d2de05 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -21,9 +21,7 @@ import java.net.{InetSocketAddress, URL} import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} -import scala.annotation.tailrec import scala.language.implicitConversions -import scala.util.{Failure, Success, Try} import scala.xml.Node import org.eclipse.jetty.server.Server @@ -147,15 +145,19 @@ private[spark] object JettyUtils extends Logging { val holder : FilterHolder = new FilterHolder() holder.setClassName(filter) // Get any parameters for each filter - val paramName = "spark." + filter + ".params" - val params = conf.get(paramName, "").split(',').map(_.trim()).toSet - params.foreach { - case param : String => + conf.get("spark." + filter + ".params", "").split(',').map(_.trim()).toSet.foreach { + param: String => if (!param.isEmpty) { val parts = param.split("=") if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) } } + + val prefix = s"spark.$filter.param." 
+ conf.getAll + .filter { case (k, v) => k.length() > prefix.length() && k.startsWith(prefix) } + .foreach { case (k, v) => holder.setInitParameter(k.substring(prefix.length()), v) } + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index acf26505e4cf9..9bd1719cb1808 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -76,8 +76,12 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC resourceManager.finishApplicationMaster(finishReq) } - override def getProxyHostAndPort(conf: YarnConfiguration) = - YarnConfiguration.getProxyHostAndPort(conf) + override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = { + val proxy = YarnConfiguration.getProxyHostAndPort(conf) + val parts = proxy.split(":") + val uriBase = "http://" + proxy + proxyBase + Map("PROXY_HOST" -> parts(0), "PROXY_URI_BASE" -> uriBase) + } override def getMaxRegAttempts(conf: YarnConfiguration) = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b51daeb437516..caceef5d4b5b0 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -368,18 +368,14 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, /** Add 
the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter() = { - val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" - val proxy = client.getProxyHostAndPort(yarnConf) - val parts = proxy.split(":") val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) - val uriBase = "http://" + proxy + proxyBase - val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase - + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val params = client.getAmIpFilterParams(yarnConf, proxyBase) if (isDriver) { System.setProperty("spark.ui.filters", amFilter) - System.setProperty(s"spark.$amFilter.params", params) + params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) } } else { - actor ! AddWebUIFilter(amFilter, params, proxyBase) + actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase) } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index ed65e56b3e413..943dc56202a37 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -59,8 +59,8 @@ trait YarnRMClient { /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId - /** Returns the RM's proxy host and port. */ - def getProxyHostAndPort(conf: YarnConfiguration): String + /** Returns the configuration for the AmIpFilter to add to the Spark UI. */ + def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String): Map[String, String] /** Returns the maximum number of attempts to register the AM. 
*/ def getMaxRegAttempts(conf: YarnConfiguration): Int diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index 54bc6b14c44ce..b581790e158ac 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -17,8 +17,13 @@ package org.apache.spark.deploy.yarn +import java.util.{List => JList} + import scala.collection.{Map, Set} +import scala.collection.JavaConversions._ +import scala.util._ +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ @@ -69,7 +74,28 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC appAttemptId } - override def getProxyHostAndPort(conf: YarnConfiguration) = WebAppUtils.getProxyHostAndPort(conf) + override def getAmIpFilterParams(conf: YarnConfiguration, proxyBase: String) = { + // Figure out which scheme Yarn is using. Note the method seems to have been added after 2.2, + // so not all stable releases have it. + val prefix = Try(classOf[WebAppUtils].getMethod("getHttpSchemePrefix", classOf[Configuration]) + .invoke(null, conf).asInstanceOf[String]).getOrElse("http://") + + // If running a new enough Yarn, use the HA-aware API for retrieving the RM addresses. 
+ try { + val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter", + classOf[Configuration]) + val proxies = method.invoke(null, conf).asInstanceOf[JList[String]] + val hosts = proxies.map { proxy => proxy.split(":")(0) } + val uriBases = proxies.map { proxy => prefix + proxy + proxyBase } + Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) + } catch { + case e: NoSuchMethodException => + val proxy = WebAppUtils.getProxyHostAndPort(conf) + val parts = proxy.split(":") + val uriBase = prefix + proxy + proxyBase + Map("PROXY_HOST" -> parts(0), "PROXY_URI_BASE" -> uriBase) + } + } override def getMaxRegAttempts(conf: YarnConfiguration) = conf.getInt(YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) From 1eb8389cb4ad40a405149b16e2719e12367d667a Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 3 Oct 2014 13:26:30 -0700 Subject: [PATCH 34/36] [SPARK-3763] The example of building with sbt should be "sbt assembly" instead of "sbt compile" In building-spark.md, there are some examples for making assembled package with maven but the example for building with sbt is only about for compiling. Author: Kousuke Saruta Closes #2627 from sarutak/SPARK-3763 and squashes the following commits: fadb990 [Kousuke Saruta] Modified the example to build with sbt in building-spark.md --- docs/building-spark.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 2378092d4a1a8..901c157162fee 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -169,7 +169,7 @@ compilation. More advanced developers may wish to use SBT. The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables can be set to control the SBT build. 
For example: - sbt/sbt -Pyarn -Phadoop-2.3 compile + sbt/sbt -Pyarn -Phadoop-2.3 assembly # Speeding up Compilation with Zinc From 79e45c9323455a51f25ed9acd0edd8682b4bbb88 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 3 Oct 2014 13:48:56 -0700 Subject: [PATCH 35/36] [SPARK-3377] [SPARK-3610] Metrics can be accidentally aggregated / History server log name should not be based on user input This PR is another solution for #2250 I'm using codahale base MetricsSystem of Spark with JMX or Graphite, and I saw following 2 problems. (1) When applications which have same spark.app.name run on cluster at the same time, some metrics names are mixed. For instance, if 2+ application is running on the cluster at the same time, each application emits the same named metric like "SparkPi.DAGScheduler.stage.failedStages" and Graphite cannot distinguish the metrics is for which application. (2) When 2+ executors run on the same machine, JVM metrics of each executors are mixed. For instance, 2+ executors running on the same node can emit the same named metric "jvm.memory" and Graphite cannot distinguish the metrics is from which application. And there is an similar issue. The directory for event logs is named using application name. Application name is defined by user and the name can includes illegal character for path names. Further more, the directory name consists of application name and System.currentTimeMillis even though each application has unique Application ID so if we run jobs which have same name, it's difficult to identify which directory is for which application. 
Closes #2250 Closes #1067 Author: Kousuke Saruta Closes #2432 from sarutak/metrics-structure-improvement2 and squashes the following commits: 3288b2b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 39169e4 [Kousuke Saruta] Fixed style 6570494 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 817e4f0 [Kousuke Saruta] Simplified MetricsSystem#buildRegistryName 67fa5eb [Kousuke Saruta] Unified MetricsSystem#registerSources and registerSinks in start 10be654 [Kousuke Saruta] Fixed style. 990c078 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 f0c7fba [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 59cc2cd [Kousuke Saruta] Modified SparkContextSchedulerCreationSuite f9b6fb3 [Kousuke Saruta] Modified style. 2cf8a0f [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 389090d [Kousuke Saruta] Replaced taskScheduler.applicationId() with getApplicationId in SparkContext#postApplicationStart ff45c89 [Kousuke Saruta] Added some test cases to MetricsSystemSuite 69c46a6 [Kousuke Saruta] Added warning logging logic to MetricsSystem#buildRegistryName 5cca0d2 [Kousuke Saruta] Added Javadoc comment to SparkContext#getApplicationId 16a9f01 [Kousuke Saruta] Added data types to be returned to some methods 6434b06 [Kousuke Saruta] Reverted changes related to ApplicationId 0413b90 [Kousuke Saruta] Deleted ApplicationId.java and ApplicationIdSuite.java a42300c [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 0fc1b09 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 42bea55 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 248935d [Kousuke Saruta] Merge 
branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 f6af132 [Kousuke Saruta] Modified SchedulerBackend and TaskScheduler to return System.currentTimeMillis as an unique Application Id 1b8b53e [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 97cb85c [Kousuke Saruta] Modified confliction of MimExcludes 2cdd009 [Kousuke Saruta] Modified defailt implementation of applicationId 9aadb0b [Kousuke Saruta] Modified NetworkReceiverSuite to ensure "executor.start()" is finished in test "network receiver life cycle" 3011efc [Kousuke Saruta] Added ApplicationIdSuite.scala d009c55 [Kousuke Saruta] Modified ApplicationId#equals to compare appIds dfc83fd [Kousuke Saruta] Modified ApplicationId to implement Serializable 9ff4851 [Kousuke Saruta] Modified MimaExcludes.scala to ignore createTaskScheduler method in SparkContext 4567ffc [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 6a91b14 [Kousuke Saruta] Modified SparkContextSchedulerCreationSuite, ExecutorRunnerTest and EventLoggingListenerSuite 0325caf [Kousuke Saruta] Added ApplicationId.scala 0a2fc14 [Kousuke Saruta] Modified style eabda80 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 0f890e6 [Kousuke Saruta] Modified SparkDeploySchedulerBackend and Master to pass baseLogDir instead f eventLogDir bcf25bf [Kousuke Saruta] Modified directory name for EventLogs 28d4d93 [Kousuke Saruta] Modified SparkContext and EventLoggingListener so that the directory for EventLogs is named same for Application ID 203634e [Kousuke Saruta] Modified comment in SchedulerBackend#applicationId and TaskScheduler#applicationId 424fea4 [Kousuke Saruta] Modified the subclasses of TaskScheduler and SchedulerBackend so that they can return non-optional Unique Application ID b311806 [Kousuke Saruta] Swapped last 2 arguments passed to 
CoarseGrainedExecutorBackend 8a2b6ec [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 086ee25 [Kousuke Saruta] Merge branch 'metrics-structure-improvement2' of github.com:sarutak/spark into metrics-structure-improvement2 e705386 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 36d2f7a [Kousuke Saruta] Added warning message for the situation we cannot get application id for the prefix for the name of metrics eea6e19 [Kousuke Saruta] Modified CoarseGrainedMesosSchedulerBackend and MesosSchedulerBackend so that we can get Application ID c229fbe [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 e719c39 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 4a93c7f [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement2 4776f9e [Kousuke Saruta] Modified MetricsSystemSuite.scala efcb6e1 [Kousuke Saruta] Modified to add application id to metrics name 2ec848a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 3ea7896 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement ead8966 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 08e627e [Kousuke Saruta] Revert "tmp" 7b67f5a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 45bd33d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 93e263a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 848819c [Kousuke Saruta] Merge branch 'metrics-structure-improvement' of github.com:sarutak/spark into metrics-structure-improvement 912a637 [Kousuke Saruta] Merge 
branch 'master' of git://git.apache.org/spark into metrics-structure-improvement e4a4593 [Kousuke Saruta] tmp 3e098d8 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 4603a39 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement fa7175b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 15f88a3 [Kousuke Saruta] Modified MetricsSystem#buildRegistryName because conf.get does not return null when correspondin entry is absent 6f7dcd4 [Kousuke Saruta] Modified constructor of DAGSchedulerSource and BlockManagerSource because the instance of SparkContext is no longer used 6fc5560 [Kousuke Saruta] Modified sourceName of ExecutorSource, DAGSchedulerSource and BlockManagerSource 4e057c9 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into metrics-structure-improvement 85ffc02 [Kousuke Saruta] Revert "Modified sourceName of ExecutorSource, DAGSchedulerSource and BlockManagerSource" 868e326 [Kousuke Saruta] Modified MetricsSystem to set registry name with unique application-id and driver/executor-id 71609f5 [Kousuke Saruta] Modified sourceName of ExecutorSource, DAGSchedulerSource and BlockManagerSource 55debab [Kousuke Saruta] Modified SparkContext and Executor to set spark.executor.id to identifiers 4180993 [Kousuke Saruta] Modified SparkContext to retain spark.unique.app.name property in SparkConf --- .../scala/org/apache/spark/SparkContext.scala | 52 ++++--- .../scala/org/apache/spark/SparkEnv.scala | 8 +- .../apache/spark/deploy/master/Master.scala | 12 +- .../CoarseGrainedExecutorBackend.scala | 16 ++- .../org/apache/spark/executor/Executor.scala | 1 + .../spark/executor/ExecutorSource.scala | 3 +- .../spark/executor/MesosExecutorBackend.scala | 3 +- .../apache/spark/metrics/MetricsSystem.scala | 40 +++++- .../spark/scheduler/DAGSchedulerSource.scala | 4 +- 
.../scheduler/EventLoggingListener.scala | 33 +++-- .../spark/scheduler/SchedulerBackend.scala | 8 +- .../spark/scheduler/TaskScheduler.scala | 8 +- .../spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 9 +- .../mesos/CoarseMesosSchedulerBackend.scala | 11 +- .../cluster/mesos/MesosSchedulerBackend.scala | 13 +- .../spark/scheduler/local/LocalBackend.scala | 3 + .../spark/storage/BlockManagerSource.scala | 4 +- .../spark/metrics/MetricsSystemSuite.scala | 128 +++++++++++++++++- .../scheduler/EventLoggingListenerSuite.scala | 14 +- .../spark/scheduler/ReplayListenerSuite.scala | 3 +- .../streaming/NetworkReceiverSuite.scala | 14 +- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 +- .../deploy/yarn/ExecutorRunnableUtil.scala | 2 + .../spark/deploy/yarn/YarnAllocator.scala | 2 + .../cluster/YarnClientSchedulerBackend.scala | 6 +- .../cluster/YarnClusterSchedulerBackend.scala | 9 +- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 +- .../deploy/yarn/YarnAllocationHandler.scala | 2 +- 29 files changed, 331 insertions(+), 85 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 979d178c35969..97109b9f41b60 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -187,6 +187,15 @@ class SparkContext(config: SparkConf) extends Logging { val master = conf.get("spark.master") val appName = conf.get("spark.app.name") + private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) + private[spark] val eventLogDir: Option[String] = { + if (isEventLogEnabled) { + Some(conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR).stripSuffix("/")) + } else { + None + } + } + // Generate the random name for a temp folder in Tachyon // Add a timestamp as the suffix here to make it more safe val tachyonFolderName = "spark-" + 
randomUUID.toString() @@ -200,6 +209,7 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] val listenerBus = new LiveListenerBus // Create the Spark execution environment (cache, map output tracker, etc) + conf.set("spark.executor.id", "driver") private[spark] val env = SparkEnv.create( conf, "", @@ -232,19 +242,6 @@ class SparkContext(config: SparkConf) extends Logging { /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) - // Optionally log Spark events - private[spark] val eventLogger: Option[EventLoggingListener] = { - if (conf.getBoolean("spark.eventLog.enabled", false)) { - val logger = new EventLoggingListener(appName, conf, hadoopConfiguration) - logger.start() - listenerBus.addListener(logger) - Some(logger) - } else None - } - - // At this point, all relevant SparkListeners have been registered, so begin releasing events - listenerBus.start() - val startTime = System.currentTimeMillis() // Add each JAR given through the constructor @@ -309,6 +306,29 @@ class SparkContext(config: SparkConf) extends Logging { // constructor taskScheduler.start() + val applicationId: String = taskScheduler.applicationId() + conf.set("spark.app.id", applicationId) + + val metricsSystem = env.metricsSystem + + // The metrics system for Driver need to be set spark.app.id to app ID. + // So it should start after we get app ID from the task scheduler and set spark.app.id. 
+ metricsSystem.start() + + // Optionally log Spark events + private[spark] val eventLogger: Option[EventLoggingListener] = { + if (isEventLogEnabled) { + val logger = + new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration) + logger.start() + listenerBus.addListener(logger) + Some(logger) + } else None + } + + // At this point, all relevant SparkListeners have been registered, so begin releasing events + listenerBus.start() + private[spark] val cleaner: Option[ContextCleaner] = { if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { Some(new ContextCleaner(this)) @@ -411,8 +431,8 @@ class SparkContext(config: SparkConf) extends Logging { // Post init taskScheduler.postStartHook() - private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this) - private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this) + private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) + private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) private def initDriverMetrics() { SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) @@ -1278,7 +1298,7 @@ class SparkContext(config: SparkConf) extends Logging { private def postApplicationStart() { // Note: this code assumes that the task scheduler has been initialized and has contacted // the cluster manager to get an application ID (in case the cluster manager provides one). 
- listenerBus.post(SparkListenerApplicationStart(appName, taskScheduler.applicationId(), + listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId), startTime, sparkUser)) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 009ed64775844..72cac42cd2b2b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -259,11 +259,15 @@ object SparkEnv extends Logging { } val metricsSystem = if (isDriver) { + // Don't start metrics system right now for Driver. + // We need to wait for the task scheduler to give us an app ID. + // Then we can start the metrics system. MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { - MetricsSystem.createMetricsSystem("executor", conf, securityManager) + val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager) + ms.start() + ms } - metricsSystem.start() // Set the sparkFiles directory, used when downloading dependencies. 
In local mode, // this is a temporary directory; in distributed mode, this is the executor's current working diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 432b552c58cd8..f98b531316a3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -33,8 +33,8 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} -import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, - SparkHadoopUtil} +import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, + ExecutorState, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState @@ -693,16 +693,18 @@ private[spark] class Master( app.desc.appUiUrl = notFoundBasePath return false } - val fileSystem = Utils.getHadoopFileSystem(eventLogDir, + + val appEventLogDir = EventLoggingListener.getLogDirPath(eventLogDir, app.id) + val fileSystem = Utils.getHadoopFileSystem(appEventLogDir, SparkHadoopUtil.get.newConfiguration(conf)) - val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem) + val eventLogInfo = EventLoggingListener.parseLoggingInfo(appEventLogDir, fileSystem) val eventLogPaths = eventLogInfo.logPaths val compressionCodec = eventLogInfo.compressionCodec if (eventLogPaths.isEmpty) { // Event logging is enabled for this application, but no event logs are found val title = s"Application history not found (${app.id})" - var msg = s"No event logs found for application $appName in $eventLogDir." + var msg = s"No event logs found for application $appName in $appEventLogDir." 
logWarning(msg) msg += " Did you specify the correct logging directory?" msg = URLEncoder.encode(msg, "UTF-8") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 13af5b6f5812d..06061edfc0844 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -106,6 +106,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { executorId: String, hostname: String, cores: Int, + appId: String, workerUrl: Option[String]) { SignalLogger.register(log) @@ -122,7 +123,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val driver = fetcher.actorSelection(driverUrl) val timeout = AkkaUtils.askTimeout(executorConf) val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] + val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() // Create a new ActorSystem using driver's Spark properties to run the backend. 
@@ -144,16 +146,16 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { def main(args: Array[String]) { args.length match { - case x if x < 4 => + case x if x < 5 => System.err.println( // Worker url is used in spark standalone mode to enforce fate-sharing with worker "Usage: CoarseGrainedExecutorBackend " + - " []") + " [] ") System.exit(1) - case 4 => - run(args(0), args(1), args(2), args(3).toInt, None) - case x if x > 4 => - run(args(0), args(1), args(2), args(3).toInt, Some(args(4))) + case 5 => + run(args(0), args(1), args(2), args(3).toInt, args(4), None) + case x if x > 5 => + run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5))) } } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index d7211ae465902..9bbfcdc4a0b6e 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -74,6 +74,7 @@ private[spark] class Executor( val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) + conf.set("spark.executor.id", "executor." 
+ executorId) private val env = { if (!isLocal) { val _env = SparkEnv.create(conf, executorId, slaveHostname, 0, diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index d6721586566c2..c4d73622c4727 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -37,8 +37,7 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) override val metricRegistry = new MetricRegistry() - // TODO: It would be nice to pass the application name here - override val sourceName = "executor.%s".format(executorId) + override val sourceName = "executor" // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index a42c8b43bbf7f..bca0b152268ad 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -52,7 +52,8 @@ private[spark] class MesosExecutorBackend slaveInfo: SlaveInfo) { logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) this.driver = driver - val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) + val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++ + Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue)) executor = new Executor( executorInfo.getExecutorId.getValue, slaveInfo.getHostname, diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 
fd316a89a1a10..5dd67b0cbf683 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -83,10 +83,10 @@ private[spark] class MetricsSystem private ( def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array()) metricsConfig.initialize() - registerSources() - registerSinks() def start() { + registerSources() + registerSinks() sinks.foreach(_.start) } @@ -98,10 +98,39 @@ private[spark] class MetricsSystem private ( sinks.foreach(_.report()) } + /** + * Build a name that uniquely identifies each metric source. + * The name is structured as follows: ... + * If either ID is not available, this defaults to just using . + * + * @param source Metric source to be named by this method. + * @return An unique metric name for each combination of + * application, executor/driver and metric source. + */ + def buildRegistryName(source: Source): String = { + val appId = conf.getOption("spark.app.id") + val executorId = conf.getOption("spark.executor.id") + val defaultName = MetricRegistry.name(source.sourceName) + + if (instance == "driver" || instance == "executor") { + if (appId.isDefined && executorId.isDefined) { + MetricRegistry.name(appId.get, executorId.get, source.sourceName) + } else { + // Only Driver and Executor are set spark.app.id and spark.executor.id. + // For instance, Master and Worker are not related to a specific application. + val warningMsg = s"Using default name $defaultName for source because %s is not set." 
+ if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) } + if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) } + defaultName + } + } else { defaultName } + } + def registerSource(source: Source) { sources += source try { - registry.register(source.sourceName, source.metricRegistry) + val regName = buildRegistryName(source) + registry.register(regName, source.metricRegistry) } catch { case e: IllegalArgumentException => logInfo("Metrics already registered", e) } @@ -109,8 +138,9 @@ private[spark] class MetricsSystem private ( def removeSource(source: Source) { sources -= source + val regName = buildRegistryName(source) registry.removeMatching(new MetricFilter { - def matches(name: String, metric: Metric): Boolean = name.startsWith(source.sourceName) + def matches(name: String, metric: Metric): Boolean = name.startsWith(regName) }) } @@ -125,7 +155,7 @@ private[spark] class MetricsSystem private ( val source = Class.forName(classPath).newInstance() registerSource(source.asInstanceOf[Source]) } catch { - case e: Exception => logError("Source class " + classPath + " cannot be instantialized", e) + case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 94944399b134a..12668b6c0988e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -22,10 +22,10 @@ import com.codahale.metrics.{Gauge,MetricRegistry} import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: SparkContext) +private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() - 
override val sourceName = "%s.DAGScheduler".format(sc.appName) + override val sourceName = "DAGScheduler" metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.failedStages.size diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 64b32ae0edaac..100c9ba9b7809 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -43,38 +43,29 @@ import org.apache.spark.util.{FileLogger, JsonProtocol, Utils} * spark.eventLog.buffer.kb - Buffer size to use when writing to output streams */ private[spark] class EventLoggingListener( - appName: String, + appId: String, + logBaseDir: String, sparkConf: SparkConf, hadoopConf: Configuration) extends SparkListener with Logging { import EventLoggingListener._ - def this(appName: String, sparkConf: SparkConf) = - this(appName, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) + def this(appId: String, logBaseDir: String, sparkConf: SparkConf) = + this(appId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 - private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/") - private val name = appName.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_") - .toLowerCase + "-" + System.currentTimeMillis - val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") - + val logDir = EventLoggingListener.getLogDirPath(logBaseDir, appId) + val 
logDirName: String = logDir.split("/").last protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize, shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS)) // For testing. Keep track of all JSON serialized events that have been logged. private[scheduler] val loggedEvents = new ArrayBuffer[JValue] - /** - * Return only the unique application directory without the base directory. - */ - def getApplicationLogDir(): String = { - name - } - /** * Begin logging events. * If compression is used, log a file that indicates which compression library is used. @@ -184,6 +175,18 @@ private[spark] object EventLoggingListener extends Logging { } else "" } + /** + * Return a file-system-safe path to the log directory for the given application. + * + * @param logBaseDir A base directory for the path to the log directory for given application. + * @param appId A unique app ID. + * @return A path which consists of file-system-safe characters. + */ + def getLogDirPath(logBaseDir: String, appId: String): String = { + val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase + Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") + } + /** * Parse the event logging information associated with the logs in the given directory. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index a0be8307eff27..992c477493d8e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -23,6 +23,8 @@ package org.apache.spark.scheduler * machines become available and can launch tasks on them. 
*/ private[spark] trait SchedulerBackend { + private val appId = "spark-application-" + System.currentTimeMillis + def start(): Unit def stop(): Unit def reviveOffers(): Unit @@ -33,10 +35,10 @@ private[spark] trait SchedulerBackend { def isReady(): Boolean = true /** - * The application ID associated with the job, if any. + * Get an application ID associated with the job. * - * @return The application ID, or None if the backend does not provide an ID. + * @return An application ID */ - def applicationId(): Option[String] = None + def applicationId(): String = appId } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 1c1ce666eab0f..a129a434c9a1a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -31,6 +31,8 @@ import org.apache.spark.storage.BlockManagerId */ private[spark] trait TaskScheduler { + private val appId = "spark-application-" + System.currentTimeMillis + def rootPool: Pool def schedulingMode: SchedulingMode @@ -66,10 +68,10 @@ private[spark] trait TaskScheduler { blockManagerId: BlockManagerId): Boolean /** - * The application ID associated with the job, if any. + * Get an application ID associated with the job. * - * @return The application ID, or None if the backend does not provide an ID. 
+ * @return An application ID */ - def applicationId(): Option[String] = None + def applicationId(): String = appId } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 633e892554c50..4dc550413c13c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -492,7 +492,7 @@ private[spark] class TaskSchedulerImpl( } } - override def applicationId(): Option[String] = backend.applicationId() + override def applicationId(): String = backend.applicationId() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 5c5ecc8434d78..ed209d195ec9d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -68,9 +68,8 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts) val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") - val eventLogDir = sc.eventLogger.map(_.logDir) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, eventLogDir) + appUIAddress, sc.eventLogDir) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() @@ -129,7 +128,11 @@ private[spark] class SparkDeploySchedulerBackend( totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } - override def applicationId(): Option[String] = Option(appId) + override def applicationId(): String = + Option(appId).getOrElse { + logWarning("Application ID is not initialized yet.") + 
super.applicationId + } private def waitForRegistration() = { registrationLock.synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 3161f1ee9fa8a..90828578cd88f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -76,6 +76,8 @@ private[spark] class CoarseMesosSchedulerBackend( var nextMesosTaskId = 0 + @volatile var appId: String = _ + def newMesosTaskId(): Int = { val id = nextMesosTaskId nextMesosTaskId += 1 @@ -167,7 +169,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - logInfo("Registered as framework ID " + frameworkId.getValue) + appId = frameworkId.getValue + logInfo("Registered as framework ID " + appId) registeredLock.synchronized { isRegistered = true registeredLock.notifyAll() @@ -313,4 +316,10 @@ private[spark] class CoarseMesosSchedulerBackend( slaveLost(d, s) } + override def applicationId(): String = + Option(appId).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 4c49aa074ebc0..b11786368e661 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -30,7 +30,7 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} 
import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils /** @@ -62,6 +62,8 @@ private[spark] class MesosSchedulerBackend( var classLoader: ClassLoader = null + @volatile var appId: String = _ + override def start() { synchronized { classLoader = Thread.currentThread.getContextClassLoader @@ -177,7 +179,8 @@ private[spark] class MesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { val oldClassLoader = setClassLoader() try { - logInfo("Registered as framework ID " + frameworkId.getValue) + appId = frameworkId.getValue + logInfo("Registered as framework ID " + appId) registeredLock.synchronized { isRegistered = true registeredLock.notifyAll() @@ -372,4 +375,10 @@ private[spark] class MesosSchedulerBackend( // TODO: query Mesos for number of cores override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + override def applicationId(): String = + Option(appId).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 9ea25c2bc7090..58b78f041cd85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -88,6 +88,7 @@ private[spark] class LocalActor( private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) extends SchedulerBackend with ExecutorBackend { + private val appId = "local-" + System.currentTimeMillis var localActor: ActorRef = null override def start() { @@ -115,4 +116,6 @@ private[spark] class 
LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: localActor ! StatusUpdate(taskId, state, serializedData) } + override def applicationId(): String = appId + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 49fea6d9e2a76..8569c6f3cbbc3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -22,10 +22,10 @@ import com.codahale.metrics.{Gauge,MetricRegistry} import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source -private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: SparkContext) +private[spark] class BlockManagerSource(val blockManager: BlockManager) extends Source { override val metricRegistry = new MetricRegistry() - override val sourceName = "%s.BlockManager".format(sc.appName) + override val sourceName = "BlockManager" metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index e42b181194727..3925f0ccbdbf0 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.metrics -import org.apache.spark.metrics.source.Source import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.MasterSource +import org.apache.spark.metrics.source.Source -import scala.collection.mutable.ArrayBuffer +import com.codahale.metrics.MetricRegistry +import scala.collection.mutable.ArrayBuffer class MetricsSystemSuite extends FunSuite with BeforeAndAfter with 
PrivateMethodTester{ var filePath: String = _ @@ -39,6 +40,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod test("MetricsSystem with default config") { val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) + metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]]('sources) val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) @@ -49,6 +51,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod test("MetricsSystem with sources add") { val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) + metricsSystem.start() val sources = PrivateMethod[ArrayBuffer[Source]]('sources) val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) @@ -60,4 +63,125 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod metricsSystem.registerSource(source) assert(metricsSystem.invokePrivate(sources()).length === 1) } + + test("MetricsSystem with Driver instance") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + val executorId = "driver" + conf.set("spark.app.id", appId) + conf.set("spark.executor.id", executorId) + + val instanceName = "driver" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === s"$appId.$executorId.${source.sourceName}") + } + + test("MetricsSystem with Driver instance and spark.app.id is not set") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val executorId = "driver" + conf.set("spark.executor.id", executorId) + + val instanceName = "driver" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = 
driverMetricsSystem.buildRegistryName(source) + assert(metricName === source.sourceName) + } + + test("MetricsSystem with Driver instance and spark.executor.id is not set") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + conf.set("spark.app.id", appId) + + val instanceName = "driver" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === source.sourceName) + } + + test("MetricsSystem with Executor instance") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + val executorId = "executor.1" + conf.set("spark.app.id", appId) + conf.set("spark.executor.id", executorId) + + val instanceName = "executor" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === s"$appId.$executorId.${source.sourceName}") + } + + test("MetricsSystem with Executor instance and spark.app.id is not set") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val executorId = "executor.1" + conf.set("spark.executor.id", executorId) + + val instanceName = "executor" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === source.sourceName) + } + + test("MetricsSystem with Executor instance and spark.executor.id is not set") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + conf.set("spark.app.id", appId) + + val 
instanceName = "executor" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + assert(metricName === source.sourceName) + } + + test("MetricsSystem with instance which is neither Driver nor Executor") { + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val appId = "testId" + val executorId = "dummyExecutorId" + conf.set("spark.app.id", appId) + conf.set("spark.executor.id", executorId) + + val instanceName = "testInstance" + val driverMetricsSystem = MetricsSystem.createMetricsSystem(instanceName, conf, securityMgr) + + val metricName = driverMetricsSystem.buildRegistryName(source) + + // Even if spark.app.id and spark.executor.id are set, they are not used for the metric name. + assert(metricName != s"$appId.$executorId.${source.sourceName}") + assert(metricName === source.sourceName) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index e5315bc93e217..3efa85431876b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -169,7 +169,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { // Verify logging directory exists val conf = getLoggingConf(logDirPath, compressionCodec) - val eventLogger = new EventLoggingListener("test", conf) + val logBaseDir = conf.get("spark.eventLog.dir") + val appId = EventLoggingListenerSuite.getUniqueApplicationId + val eventLogger = new EventLoggingListener(appId, logBaseDir, conf) eventLogger.start() val logPath = new Path(eventLogger.logDir) assert(fileSystem.exists(logPath)) @@ -209,7 +211,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { // 
Verify that all information is correctly parsed before stop() val conf = getLoggingConf(logDirPath, compressionCodec) - val eventLogger = new EventLoggingListener("test", conf) + val logBaseDir = conf.get("spark.eventLog.dir") + val appId = EventLoggingListenerSuite.getUniqueApplicationId + val eventLogger = new EventLoggingListener(appId, logBaseDir, conf) eventLogger.start() var eventLoggingInfo = EventLoggingListener.parseLoggingInfo(eventLogger.logDir, fileSystem) assertInfoCorrect(eventLoggingInfo, loggerStopped = false) @@ -228,7 +232,9 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { */ private def testEventLogging(compressionCodec: Option[String] = None) { val conf = getLoggingConf(logDirPath, compressionCodec) - val eventLogger = new EventLoggingListener("test", conf) + val logBaseDir = conf.get("spark.eventLog.dir") + val appId = EventLoggingListenerSuite.getUniqueApplicationId + val eventLogger = new EventLoggingListener(appId, logBaseDir, conf) val listenerBus = new LiveListenerBus val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey") @@ -408,4 +414,6 @@ object EventLoggingListenerSuite { } conf } + + def getUniqueApplicationId = "test-" + System.currentTimeMillis } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 7ab351d1b4d24..48114feee6233 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -155,7 +155,8 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { * This child listener inherits only the event buffering functionality, but does not actually * log the events. 
*/ - private class EventMonster(conf: SparkConf) extends EventLoggingListener("test", conf) { + private class EventMonster(conf: SparkConf) + extends EventLoggingListener("test", "testdir", conf) { logger.close() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala index 99c8d13231aac..eb6e88cf5520d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.nio.ByteBuffer +import java.util.concurrent.Semaphore import scala.collection.mutable.ArrayBuffer @@ -36,6 +37,7 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { val receiver = new FakeReceiver val executor = new FakeReceiverSupervisor(receiver) + val executorStarted = new Semaphore(0) assert(executor.isAllEmpty) @@ -43,6 +45,7 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { val executingThread = new Thread() { override def run() { executor.start() + executorStarted.release(1) executor.awaitTermination() } } @@ -57,6 +60,9 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { } } + // Ensure executor is started + executorStarted.acquire() + // Verify that receiver was started assert(receiver.onStartCalled) assert(executor.isReceiverStarted) @@ -186,10 +192,10 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. 
*/ class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - var otherThread: Thread = null - var receiving = false - var onStartCalled = false - var onStopCalled = false + @volatile var otherThread: Thread = null + @volatile var receiving = false + @volatile var onStartCalled = false + @volatile var onStopCalled = false def onStart() { otherThread = new Thread() { diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 10cbeb8b94325..229b7a09f456b 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -47,6 +47,7 @@ class ExecutorRunnable( hostname: String, executorMemory: Int, executorCores: Int, + appAttemptId: String, securityMgr: SecurityManager) extends Runnable with ExecutorRunnableUtil with Logging { @@ -83,7 +84,7 @@ class ExecutorRunnable( ctx.setContainerTokens(ByteBuffer.wrap(dob.getData())) val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, - localResources) + appAttemptId, localResources) logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index d7a7175d5e578..5cb4753de2e84 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -43,6 +43,7 @@ trait ExecutorRunnableUtil extends Logging { hostname: String, executorMemory: Int, executorCores: Int, + appId: String, localResources: HashMap[String, LocalResource]): List[String] = { // Extra options for the JVM val javaOpts = ListBuffer[String]() @@ -114,6 +115,7 @@ trait 
ExecutorRunnableUtil extends Logging { slaveId.toString, hostname.toString, executorCores.toString, + appId, "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4f4f1d2aaaade..e1af8d5a74cb1 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -57,6 +57,7 @@ object AllocationType extends Enumeration { private[yarn] abstract class YarnAllocator( conf: Configuration, sparkConf: SparkConf, + appAttemptId: ApplicationAttemptId, args: ApplicationMasterArguments, preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) @@ -295,6 +296,7 @@ private[yarn] abstract class YarnAllocator( executorHostname, executorMemory, executorCores, + appAttemptId.getApplicationId.toString, securityMgr) launcherPool.execute(executorRunnable) } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 200a30899290b..6bb4b82316ad4 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -155,6 +155,10 @@ private[spark] class YarnClientSchedulerBackend( totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } - override def applicationId(): Option[String] = Option(appId).map(_.toString()) + override def applicationId(): String = + Option(appId).map(_.toString).getOrElse { + logWarning("Application ID is not initialized yet.") + super.applicationId + } } diff --git 
a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 39436d0999663..3a186cfeb4eeb 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -48,6 +48,13 @@ private[spark] class YarnClusterSchedulerBackend( totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } - override def applicationId(): Option[String] = sc.getConf.getOption("spark.yarn.app.id") + override def applicationId(): String = + // In YARN Cluster mode, spark.yarn.app.id is expect to be set + // before user application is launched. + // So, if spark.yarn.app.id is not set, it is something wrong. + sc.getConf.getOption("spark.yarn.app.id").getOrElse { + logError("Application ID is not set.") + super.applicationId + } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 833be12982e71..0b5a92d87d722 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -47,6 +47,7 @@ class ExecutorRunnable( hostname: String, executorMemory: Int, executorCores: Int, + appId: String, securityMgr: SecurityManager) extends Runnable with ExecutorRunnableUtil with Logging { @@ -80,7 +81,7 @@ class ExecutorRunnable( ctx.setTokens(ByteBuffer.wrap(dob.getData())) val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, - localResources) + appId, localResources) logInfo(s"Setting up executor with environment: $env") logInfo("Setting up executor with commands: " + commands) diff --git 
a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index e44a8db41b97e..2bbf5d7db8668 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -41,7 +41,7 @@ private[yarn] class YarnAllocationHandler( args: ApplicationMasterArguments, preferredNodes: collection.Map[String, collection.Set[SplitInfo]], securityMgr: SecurityManager) - extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { + extends YarnAllocator(conf, sparkConf, appAttemptId, args, preferredNodes, securityMgr) { override protected def releaseContainer(container: Container) = { amClient.releaseAssignedContainer(container.getId()) From cf1d32e3e1071829b152d4b597bf0a0d7a5629a2 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 3 Oct 2014 14:22:11 -0700 Subject: [PATCH 36/36] [SPARK-1860] More conservative app directory cleanup. First contribution to the project, so apologize for any significant errors. This PR addresses [SPARK-1860]. The application directories are now cleaned up in a more conservative manner. Previously, app-* directories were cleaned up if the directory's timestamp was older than a given time. However, the timestamp on a directory does not reflect the modification times of the files in that directory. Therefore, app-* directories were wiped out even if the files inside them were created recently and possibly being used by Executor tasks. The solution is to change the cleanup logic to inspect all files within the app-* directory and only eliminate the app-* directory if all files in the directory are stale. Author: mcheah Closes #2609 from mccheah/worker-better-app-dir-cleanup and squashes the following commits: 87b5d03 [mcheah] [SPARK-1860] Using more string interpolation. Better error logging. 
802473e [mcheah] [SPARK-1860] Cleaning up the logs generated when cleaning directories. e0a1f2e [mcheah] [SPARK-1860] Fixing broken unit test. 77a9de0 [mcheah] [SPARK-1860] More conservative app directory cleanup. --- .../spark/deploy/worker/ExecutorRunner.scala | 8 +--- .../apache/spark/deploy/worker/Worker.scala | 37 ++++++++++++++++--- .../scala/org/apache/spark/util/Utils.scala | 21 +++++++---- .../org/apache/spark/util/UtilsSuite.scala | 23 +++++++++--- 4 files changed, 62 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 00a43673e5cd3..71650cd773bcf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -42,7 +42,7 @@ private[spark] class ExecutorRunner( val workerId: String, val host: String, val sparkHome: File, - val workDir: File, + val executorDir: File, val workerUrl: String, val conf: SparkConf, var state: ExecutorState.Value) @@ -130,12 +130,6 @@ private[spark] class ExecutorRunner( */ def fetchAndRunExecutor() { try { - // Create the executor's working directory - val executorDir = new File(workDir, appId + "/" + execId) - if (!executorDir.mkdirs()) { - throw new IOException("Failed to create directory " + executorDir) - } - // Launch the process val command = getCommandSeq logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0c454e4138c96..3b13f43a1868c 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -18,15 +18,18 @@ package org.apache.spark.deploy.worker import java.io.File +import java.io.IOException import java.text.SimpleDateFormat 
import java.util.Date +import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.concurrent.duration._ import scala.language.postfixOps import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import org.apache.commons.io.FileUtils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} @@ -191,6 +194,7 @@ private[spark] class Worker( changeMaster(masterUrl, masterWebUiUrl) context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) if (CLEANUP_ENABLED) { + logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) } @@ -201,10 +205,23 @@ private[spark] class Worker( case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor val cleanupFuture = concurrent.future { - logInfo("Cleaning up oldest application directories in " + workDir + " ...") - Utils.findOldFiles(workDir, APP_DATA_RETENTION_SECS) - .foreach(Utils.deleteRecursively) + val appDirs = workDir.listFiles() + if (appDirs == null) { + throw new IOException("ERROR: Failed to list files in " + appDirs) + } + appDirs.filter { dir => + // the directory is used by an application - check that the application is not running + // when cleaning up + val appIdFromDir = dir.getName + val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + dir.isDirectory && !isAppStillRunning && + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + }.foreach { dir => + logInfo(s"Removing directory: ${dir.getPath}") + Utils.deleteRecursively(dir) + } } + cleanupFuture onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) @@ -233,8 +250,15 @@ private[spark] class Worker( } 
else { try { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) + + // Create the executor's working directory + val executorDir = new File(workDir, appId + "/" + execId) + if (!executorDir.mkdirs()) { + throw new IOException("Failed to create directory " + executorDir) + } + val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, sparkHome, workDir, akkaUrl, conf, ExecutorState.LOADING) + self, workerId, host, sparkHome, executorDir, akkaUrl, conf, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -242,12 +266,13 @@ private[spark] class Worker( master ! ExecutorStateChanged(appId, execId, manager.state, None, None) } catch { case e: Exception => { - logError("Failed to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) + logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) if (executors.contains(appId + "/" + execId)) { executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None) + master ! 
ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None) } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9399ddab76331..a67124140f9da 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -35,6 +35,8 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.log4j.PropertyConfigurator @@ -705,17 +707,20 @@ private[spark] object Utils extends Logging { } /** - * Finds all the files in a directory whose last modified time is older than cutoff seconds. - * @param dir must be the path to a directory, or IllegalArgumentException is thrown - * @param cutoff measured in seconds. Files older than this are returned. + * Determines if a directory contains any files newer than cutoff seconds. + * + * @param dir must be the path to a directory, or IllegalArgumentException is thrown + * @param cutoff measured in seconds. Returns true if there are any files in dir newer than this. 
*/ - def findOldFiles(dir: File, cutoff: Long): Seq[File] = { + def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = { val currentTimeMillis = System.currentTimeMillis - if (dir.isDirectory) { - val files = listFilesSafely(dir) - files.filter { file => file.lastModified < (currentTimeMillis - cutoff * 1000) } + if (!dir.isDirectory) { + throw new IllegalArgumentException (dir + " is not a directory!") } else { - throw new IllegalArgumentException(dir + " is not a directory!") + val files = FileUtils.listFilesAndDirs(dir, TrueFileFilter.TRUE, TrueFileFilter.TRUE) + val cutoffTimeInMillis = (currentTimeMillis - (cutoff * 1000)) + val newFiles = files.filter { _.lastModified > cutoffTimeInMillis } + newFiles.nonEmpty } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 70d423ba8a04d..e63d9d085e385 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -189,17 +189,28 @@ class UtilsSuite extends FunSuite { assert(Utils.getIteratorSize(iterator) === 5L) } - test("findOldFiles") { + test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir() val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories val child2: File = Utils.createTempDir(parent.getCanonicalPath) - // set the last modified time of child1 to 10 secs old - child1.setLastModified(System.currentTimeMillis() - (1000 * 10)) + val child3: File = Utils.createTempDir(child1.getCanonicalPath) + // set the last modified time of child1 to 30 secs old + child1.setLastModified(System.currentTimeMillis() - (1000 * 30)) - val result = Utils.findOldFiles(parent, 5) // find files older than 5 secs - assert(result.size.equals(1)) - assert(result(0).getCanonicalPath.equals(child1.getCanonicalPath)) + // 
although child1 is old, child2 is still new so return true + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + child2.setLastModified(System.currentTimeMillis - (1000 * 30)) + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + parent.setLastModified(System.currentTimeMillis - (1000 * 30)) + // although parent and its immediate children are new, child3 is still old + // we expect a full recursive search for new files. + assert(Utils.doesDirectoryContainAnyNewFiles(parent, 5)) + + child3.setLastModified(System.currentTimeMillis - (1000 * 30)) + assert(!Utils.doesDirectoryContainAnyNewFiles(parent, 5)) } test("resolveURI") {