diff --git a/.rat-excludes b/.rat-excludes
index a2b5665a0be26..8954330bd10a7 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -39,4 +39,6 @@ work
 .*\.q
 golden
 test.out/*
-.*iml
\ No newline at end of file
+.*iml
+python/metastore/service.properties
+python/metastore/db.lck
diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
index 9c37fadb78d2f..69144e3e657bf 100644
--- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
@@ -28,9 +28,9 @@ class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializ
 class TestMessage(val targetId: String) extends Message[String] with Serializable
 
 class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
-  
+
   var sc: SparkContext = _
-  
+
   after {
     if (sc != null) {
       sc.stop()
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
index fa75842047c6a..23f5fdd43631b 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
@@ -24,4 +24,4 @@
  */
 public interface FlatMapFunction<T, R> extends Serializable {
   public Iterable<R> call(T t) throws Exception;
-}
\ No newline at end of file
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
index d1fdec072443d..c48e92f535ff5 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
@@ -24,4 +24,4 @@
  */
 public interface FlatMapFunction2<T1, T2, R> extends Serializable {
   public Iterable<R> call(T1 t1, T2 t2) throws Exception;
-}
\ No newline at end of file
+}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index fe54c34ffb1da..599c3ac9b57c0 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -78,3 +78,12 @@ table.sortable thead {
   background-repeat: repeat-x;
   filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0);
 }
+
+span.kill-link {
+  margin-right: 2px;
+  color: gray;
+}
+
+span.kill-link a {
+  color: gray;
+}
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index 3d7692ea8a49e..a6e300d345786 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -24,13 +24,13 @@ import com.google.common.io.Files
 import org.apache.spark.util.Utils
 
 private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging {
-  
+
   var baseDir : File = null
   var fileDir : File = null
   var jarDir : File = null
   var httpServer : HttpServer = null
   var serverUri : String = null
-  
+
   def initialize() {
     baseDir = Utils.createTempDir()
     fileDir = new File(baseDir, "files")
@@ -43,24 +43,24 @@ private[spark] class HttpFileServer(securityManager: SecurityManager) extends Lo
     serverUri = httpServer.uri
     logDebug("HTTP file server started at: " + serverUri)
   }
-  
+
   def stop() {
     httpServer.stop()
   }
-  
+
   def addFile(file:
File) : String = { addFileToDir(file, fileDir) serverUri + "/files/" + file.getName } - + def addJar(file: File) : String = { addFileToDir(file, jarDir) serverUri + "/jars/" + file.getName } - + def addFileToDir(file: File, dir: File) : String = { Files.copy(file, new File(dir, file.getName)) dir + "/" + file.getName } - + } diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index cb5df25fa48df..7e9b517f901a2 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -83,19 +83,19 @@ private[spark] class HttpServer(resourceBase: File, securityManager: SecurityMan } } - /** + /** * Setup Jetty to the HashLoginService using a single user with our * shared secret. Configure it to use DIGEST-MD5 authentication so that the password * isn't passed in plaintext. */ private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = { val constraint = new Constraint() - // use DIGEST-MD5 as the authentication mechanism + // use DIGEST-MD5 as the authentication mechanism constraint.setName(Constraint.__DIGEST_AUTH) constraint.setRoles(Array("user")) constraint.setAuthenticate(true) constraint.setDataConstraint(Constraint.DC_NONE) - + val cm = new ConstraintMapping() cm.setConstraint(constraint) cm.setPathSpec("/*") diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala index 87914a061f5d7..27892dbd2a0bc 100644 --- a/core/src/main/scala/org/apache/spark/Partition.scala +++ b/core/src/main/scala/org/apache/spark/Partition.scala @@ -25,7 +25,7 @@ trait Partition extends Serializable { * Get the split's index within its parent RDD */ def index: Int - + // A better default implementation of HashCode override def hashCode(): Int = index } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2237ee3bb7aad..b52f2d4f416b2 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -25,93 +25,93 @@ import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil -/** - * Spark class responsible for security. - * +/** + * Spark class responsible for security. + * * In general this class should be instantiated by the SparkEnv and most components - * should access it from that. There are some cases where the SparkEnv hasn't been + * should access it from that. There are some cases where the SparkEnv hasn't been * initialized yet and this class must be instantiated directly. - * + * * Spark currently supports authentication via a shared secret. * Authentication can be configured to be on via the 'spark.authenticate' configuration - * parameter. This parameter controls whether the Spark communication protocols do + * parameter. This parameter controls whether the Spark communication protocols do * authentication using the shared secret. This authentication is a basic handshake to * make sure both sides have the same shared secret and are allowed to communicate. - * If the shared secret is not identical they will not be allowed to communicate. - * - * The Spark UI can also be secured by using javax servlet filters. A user may want to - * secure the UI if it has data that other users should not be allowed to see. 
The javax - * servlet filter specified by the user can authenticate the user and then once the user - * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * If the shared secret is not identical they will not be allowed to communicate. + * + * The Spark UI can also be secured by using javax servlet filters. A user may want to + * secure the UI if it has data that other users should not be allowed to see. The javax + * servlet filter specified by the user can authenticate the user and then once the user + * is logged in, Spark can compare that user versus the view acls to make sure they are + * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' * control the behavior of the acls. Note that the person who started the application * always has view access to the UI. * * Spark does not currently support encryption after authentication. - * + * * At this point spark has multiple communication protocols that need to be secured and * different underlying mechanisms are used depending on the protocol: * - * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. - * Akka remoting allows you to specify a secure cookie that will be exchanged - * and ensured to be identical in the connection handshake between the client - * and the server. If they are not identical then the client will be refused - * to connect to the server. There is no control of the underlying - * authentication mechanism so its not clear if the password is passed in + * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. + * Akka remoting allows you to specify a secure cookie that will be exchanged + * and ensured to be identical in the connection handshake between the client + * and the server. If they are not identical then the client will be refused + * to connect to the server. There is no control of the underlying + * authentication mechanism so its not clear if the password is passed in * plaintext or uses DIGEST-MD5 or some other mechanism. * Akka also has an option to turn on SSL, this option is not currently supported * but we could add a configuration option in the future. - * - * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty - * for the HttpServer. Jetty supports multiple authentication mechanisms - - * Basic, Digest, Form, Spengo, etc. It also supports multiple different login + * + * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty + * for the HttpServer. Jetty supports multiple authentication mechanisms - + * Basic, Digest, Form, Spengo, etc. It also supports multiple different login * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService - * to authenticate using DIGEST-MD5 via a single user and the shared secret. + * to authenticate using DIGEST-MD5 via a single user and the shared secret. * Since we are using DIGEST-MD5, the shared secret is not passed on the wire * in plaintext. * We currently do not support SSL (https), but Jetty can be configured to use it * so we could add a configuration option for this in the future. - * + * * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. - * Any clients must specify the user and password. There is a default + * Any clients must specify the user and password. 
There is a default * Authenticator installed in the SecurityManager to how it does the authentication * and in this case gets the user name and password from the request. * - * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously - * exchange messages. For this we use the Java SASL - * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 + * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * exchange messages. For this we use the Java SASL + * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 * as the authentication mechanism. This means the shared secret is not passed * over the wire in plaintext. * Note that SASL is pluggable as to what mechanism it uses. We currently use * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. * Spark currently supports "auth" for the quality of protection, which means * the connection is not supporting integrity or privacy protection (encryption) - * after authentication. SASL also supports "auth-int" and "auth-conf" which + * after authentication. SASL also supports "auth-int" and "auth-conf" which * SPARK could be support in the future to allow the user to specify the quality - * of protection they want. If we support those, the messages will also have to + * of protection they want. If we support those, the messages will also have to * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. - * - * Since the connectionManager does asynchronous messages passing, the SASL + * + * Since the connectionManager does asynchronous messages passing, the SASL * authentication is a bit more complex. A ConnectionManager can be both a client * and a Server, so for a particular connection is has to determine what to do. - * A ConnectionId was added to be able to track connections and is used to + * A ConnectionId was added to be able to track connections and is used to * match up incoming messages with connections waiting for authentication. * If its acting as a client and trying to send a message to another ConnectionManager, * it blocks the thread calling sendMessage until the SASL negotiation has occurred. * The ConnectionManager tracks all the sendingConnections using the ConnectionId * and waits for the response from the server and does the handshake. * - * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters + * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work * properly. For non-Yarn deployments, users can write a filter to go through a * companies normal login service. If an authentication filter is in place then the * SparkUI can be configured to check the logged in user against the list of users who * have view acls to see if that user is authorized. - * The filters can also be used for many different purposes. For instance filters + * The filters can also be used for many different purposes. For instance filters * could be used for logging, encryption, or compression. - * + * * The exact mechanisms used to generate/distributed the shared secret is deployment specific. - * + * * For Yarn deployments, the secret is automatically generated using the Akka remote * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed * around via the Hadoop RPC mechanism. 
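// Illustrative sketch, not part of this patch: on a non-YARN deployment the
// settings described above are supplied through SparkConf before the
// SparkContext is created (the secret value below is a placeholder):
//
//   val conf = new SparkConf()
//     .set("spark.authenticate", "true")              // enable shared-secret auth
//     .set("spark.authenticate.secret", "my-secret")  // must be identical on all nodes
//     .set("spark.ui.acls.enable", "true")            // enforce UI view acls
//     .set("spark.ui.view.acls", "alice,bob")         // users allowed to view the UI
//   val sc = new SparkContext(conf)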
Hadoop RPC can be configured to support different levels @@ -121,7 +121,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use * filters to do authentication. That authentication then happens via the ResourceManager Proxy * and Spark will use that to do authorization against the view acls. - * + * * For other Spark deployments, the shared secret must be specified via the * spark.authenticate.secret config. * All the nodes (Master and Workers) and the applications need to have the same shared secret. @@ -152,7 +152,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. - // This is needed by the HTTP client fetching from the HttpServer. Put here so its + // This is needed by the HTTP client fetching from the HttpServer. Put here so its // only set once. if (authOn) { Authenticator.setDefault( @@ -214,12 +214,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { def uiAclsEnabled(): Boolean = uiAclsOn /** - * Checks the given user against the view acl list to see if they have + * Checks the given user against the view acl list to see if they have * authorization to view the UI. If the UI acls must are disabled * via spark.ui.acls.enable, all users have view access. - * + * * @param user to see if is authorized - * @return true is the user has permission, otherwise false + * @return true is the user has permission, otherwise false */ def checkUIViewPermissions(user: String): Boolean = { if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 28923a1d8c340..a764c174d562c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1137,6 +1137,16 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelAllJobs() } + /** Cancel a given job if it's scheduled or running */ + private[spark] def cancelJob(jobId: Int) { + dagScheduler.cancelJob(jobId) + } + + /** Cancel a given stage and all jobs associated with it */ + private[spark] def cancelStage(stageId: Int) { + dagScheduler.cancelStage(stageId) + } + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index d34e47e8cac22..4351ed74b67fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -20,5 +20,5 @@ package org.apache.spark class SparkException(message: String, cause: Throwable) extends Exception(message, cause) { - def this(message: String) = this(message, null) + def this(message: String) = this(message, null) } diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index b92ea01a877f7..f6703986bdf11 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -42,7 +42,7 @@ class 
SparkHadoopWriter(@transient jobConf: JobConf) private val now = new Date() private val conf = new SerializableWritable(jobConf) - + private var jobID = 0 private var splitID = 0 private var attemptID = 0 @@ -58,8 +58,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) def preSetup() { setIDs(0, 0, 0) HadoopRDD.addLocalConfiguration("", 0, 0, 0, conf.value) - - val jCtxt = getJobContext() + + val jCtxt = getJobContext() getOutputCommitter().setupJob(jCtxt) } @@ -74,7 +74,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) val numfmt = NumberFormat.getInstance() numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) - + val outputName = "part-" + numfmt.format(splitID) val path = FileOutputFormat.getOutputPath(conf.value) val fs: FileSystem = { @@ -85,7 +85,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } } - getOutputCommitter().setupTask(getTaskContext()) + getOutputCommitter().setupTask(getTaskContext()) writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL) } @@ -103,18 +103,18 @@ class SparkHadoopWriter(@transient jobConf: JobConf) def commit() { val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() + val cmtr = getOutputCommitter() if (cmtr.needsTaskCommit(taCtxt)) { try { cmtr.commitTask(taCtxt) logInfo (taID + ": Committed") } catch { - case e: IOException => { + case e: IOException => { logError("Error committing the output of task: " + taID.value, e) cmtr.abortTask(taCtxt) throw e } - } + } } else { logWarning ("No need to commit output of task: " + taID.value) } @@ -144,7 +144,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } private def getJobContext(): JobContext = { - if (jobContext == null) { + if (jobContext == null) { jobContext = newJobContext(conf.value, jID.value) } jobContext @@ -175,7 +175,7 @@ object SparkHadoopWriter { val jobtrackerID = formatter.format(time) new JobID(jobtrackerID, id) } - + def createPathFromString(path: String, conf: JobConf): Path = { if (path == null) { throw new IllegalArgumentException("Output path is null") diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala index a2a871cbd3c31..5b14c4291d91a 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -44,12 +44,12 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg * configurable in the future. */ private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, new SparkSaslClientCallbackHandler(securityMgr)) /** * Used to initiate SASL handshake with server. - * @return response to challenge if needed + * @return response to challenge if needed */ def firstToken(): Array[Byte] = { synchronized { @@ -86,7 +86,7 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg } /** - * Disposes of any system resources or security-sensitive information the + * Disposes of any system resources or security-sensitive information the * SaslClient might be using. 
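// Rough sketch of the DIGEST-MD5 handshake these classes implement, not part of
// this patch. Only firstToken/response/dispose appear in this diff; the shape of
// the negotiation loop and the completion check are assumptions:
//
//   val saslClient = new SparkSaslClient(securityMgr)   // connecting side
//   val saslServer = new SparkSaslServer(securityMgr)   // accepting side
//   val firstToken = saslClient.firstToken()            // initial SASL token
//   val challenge = saslServer.response(firstToken)     // server's DIGEST-MD5 challenge
//   // ...tokens continue to be exchanged until both sides report completion...
//   saslClient.dispose()                                // drop security-sensitive state
//   saslServer.dispose()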
*/ def dispose() { @@ -110,7 +110,7 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends CallbackHandler { - private val userName: String = + private val userName: String = SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) private val secretKey = securityMgr.getSecretKey() private val userPassword: Array[Char] = @@ -138,7 +138,7 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg rc.setText(rc.getDefaultText()) } case cb: RealmChoiceCallback => {} - case cb: Callback => throw + case cb: Callback => throw new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") } } diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala index 11fcb2ae3a5c5..6161a6fb7ae85 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -64,7 +64,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi } /** - * Disposes of any system resources or security-sensitive information the + * Disposes of any system resources or security-sensitive information the * SaslServer might be using. */ def dispose() { @@ -88,7 +88,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) extends CallbackHandler { - private val userName: String = + private val userName: String = SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) override def handle(callbacks: Array[Callback]) { @@ -123,7 +123,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi ac.setAuthorizedID(authzid) } } - case cb: Callback => throw + case cb: Callback => throw new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") } } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 4597595a838e3..f3f59e47c3e98 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -31,7 +31,7 @@ import com.google.common.io.Files * projects. * * TODO: See if we can move this to the test codebase by specifying - * test dependencies between projects. + * test dependencies between projects. */ private[spark] object TestUtils { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 2b32546c6854d..2659274c5e98e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -158,7 +158,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } def receiveBroadcast(): Boolean = { - // Receive meta-info about the size of broadcast data, + // Receive meta-info about the size of broadcast data, // the number of chunks it is divided into, etc. 
val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index c07838f798799..5da9615c9e9af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -43,7 +43,7 @@ private[spark] class ClientArguments(args: Array[String]) { // kill parameters var driverId: String = "" - + parse(args.toList) def parse(args: List[String]): Unit = args match { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 7f7372746f92e..180c853ce3096 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -23,7 +23,7 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} -private[spark] class IndexPage(parent: HistoryServer) extends WebUIPage("") { +private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { val appRows = parent.appIdToInfo.values.toSeq.sortBy { app => -app.lastUpdated } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index f495edcf1c6af..cf64700f9098c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -96,7 +96,7 @@ class HistoryServer( * this UI with the event logs in the provided base directory. */ def initialize() { - attachPage(new IndexPage(this)) + attachPage(new HistoryPage(this)) attachHandler(createStaticHandler(STATIC_RESOURCE_DIR, "/static")) logCheckingThread.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 30c2e4b1563d8..7ca3b08a28728 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -31,7 +31,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class IndexPage(parent: MasterWebUI) extends WebUIPage("") { +private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private val master = parent.masterActorRef private val timeout = parent.timeout diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 939cf2ea9a678..a18b39fc95d64 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -38,7 +38,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) /** Initialize all components of the server. 
*/ def initialize() { attachPage(new ApplicationPage(this)) - attachPage(new IndexPage(this)) + attachPage(new MasterPage(this)) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) master.masterMetricsSystem.getServletHandlers.foreach(attachHandler) master.applicationMetricsSystem.getServletHandlers.foreach(attachHandler) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index d35d5be73ff97..3836bf219ed3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -32,8 +32,8 @@ private[spark] class WorkerArguments(args: Array[String]) { var memory = inferDefaultMemory() var masters: Array[String] = null var workDir: String = null - - // Check for settings in environment variables + + // Check for settings in environment variables if (System.getenv("SPARK_WORKER_PORT") != null) { port = System.getenv("SPARK_WORKER_PORT").toInt } @@ -49,7 +49,7 @@ private[spark] class WorkerArguments(args: Array[String]) { if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } - + parse(args.toList) def parse(args: List[String]): Unit = args match { @@ -78,7 +78,7 @@ private[spark] class WorkerArguments(args: Array[String]) { case ("--work-dir" | "-d") :: value :: tail => workDir = value parse(tail) - + case "--webui-port" :: IntParam(value) :: tail => webUiPort = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 15b79872bc556..d4513118ced05 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -31,7 +31,7 @@ import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class IndexPage(parent: WorkerWebUI) extends WebUIPage("") { +private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { val workerActor = parent.worker.self val worker = parent.worker val timeout = parent.timeout @@ -137,7 +137,7 @@ private[spark] class IndexPage(parent: WorkerWebUI) extends WebUIPage("") { .format(executor.appId, executor.execId)}>stdout stderr - + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 34b5acd2f9b64..0ad2edba2227f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -45,7 +45,7 @@ class WorkerWebUI( def initialize() { val logPage = new LogPage(this) attachPage(logPage) - attachPage(new IndexPage(this)) + attachPage(new WorkerPage(this)) attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) attachHandler(createServletHandler("/log", (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 16887d8892b31..6327ac01663f6 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ 
b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -53,7 +53,7 @@ private[spark] class CoarseGrainedExecutorBackend( case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? - executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties, + executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties, false) case RegisterExecutorFailed(message) => diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index ceff3a067d72a..38be2c58b333f 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -34,7 +34,7 @@ object ExecutorExitCode { logging the exception. */ val UNCAUGHT_EXCEPTION_TWICE = 51 - /** The default uncaught exception handler was reached, and the uncaught exception was an + /** The default uncaught exception handler was reached, and the uncaught exception was an OutOfMemoryError. */ val OOM = 52 @@ -43,10 +43,10 @@ object ExecutorExitCode { /** TachyonStore failed to initialize after many attempts. */ val TACHYON_STORE_FAILED_TO_INITIALIZE = 54 - + /** TachyonStore failed to create a local temporary directory after many attempts. */ val TACHYON_STORE_FAILED_TO_CREATE_DIR = 55 - + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -57,7 +57,7 @@ object ExecutorExitCode { case TACHYON_STORE_FAILED_TO_INITIALIZE => "TachyonStore failed to initialize." case TACHYON_STORE_FAILED_TO_CREATE_DIR => "TachyonStore failed to create a local temporary directory." 
- case _ => + case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { " (died from signal " + (exitCode - 128) + "?)" diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala index 208e77073fd03..218ed7b5d2d39 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala @@ -38,7 +38,7 @@ private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: Class override def addURL(url: URL) { super.addURL(url) } - override def findClass(name: String): Class[_] = { + override def findClass(name: String): Class[_] = { super.findClass(name) } } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 42c1200926fea..542dce65366b2 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -45,7 +45,7 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis case Some(s) => TimeUnit.valueOf(s.toUpperCase()) case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) } - + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) val pollDir = Option(property.getProperty(CSV_KEY_DIR)) match { diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 2f7576c53b482..3ffaaab23d0f5 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -248,14 +248,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } } - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which + // outbox is used as a lock - ensure that it is always used as a leaf (since methods which // lock it are invoked in context of other locks) private val outbox = new Outbox() /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we + This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly + different purpose. This flag is to see if we need to force reregister for write even when we do not have any pending bytes to write to socket. 
- This can happen due to a race between adding pending buffers, and checking for existing of + This can happen due to a race between adding pending buffers, and checking for existing of data as detailed in https://github.com/mesos/spark/pull/791 */ private var needForceReregister = false diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala index ffaab677d411a..d579c165a1917 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala @@ -18,7 +18,7 @@ package org.apache.spark.network private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId } private[spark] object ConnectionId { @@ -26,9 +26,9 @@ private[spark] object ConnectionId { def createConnectionIdFromString(connectionIdString: String): ConnectionId = { val res = connectionIdString.split("_").map(_.trim()) if (res.size != 3) { - throw new Exception("Error converting ConnectionId string: " + connectionIdString + + throw new Exception("Error converting ConnectionId string: " + connectionIdString + " to a ConnectionId Object") } new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt) - } + } } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index bdf586351ac14..cfee41c61362e 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -79,7 +79,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, private val serverChannel = ServerSocketChannel.open() // used to track the SendingConnections waiting to do SASL negotiation - private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] + private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] with SynchronizedMap[ConnectionId, SendingConnection] private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] @@ -141,7 +141,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } finally { writeRunnableStarted.synchronized { writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() + val needReregister = register || conn.resetForceReregister() if (needReregister && conn.changeInterestForWrite()) { conn.registerInterest() } @@ -509,7 +509,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, private def handleClientAuthentication( waitingConn: SendingConnection, - securityMsg: SecurityMessage, + securityMsg: SecurityMessage, connectionId : ConnectionId) { if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed for id: " + waitingConn.connectionId) @@ -530,7 +530,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } return } - var securityMsgResp = SecurityMessage.fromResponse(replyToken, + var securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString()) var message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security message") @@ -546,7 +546,7 @@ private[spark] class 
ConnectionManager(port: Int, conf: SparkConf, } private def handleServerAuthentication( - connection: Connection, + connection: Connection, securityMsg: SecurityMessage, connectionId: ConnectionId) { if (!connection.isSaslComplete()) { @@ -561,7 +561,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId) + logDebug("Server sasl completed: " + connection.connectionId) } else { logDebug("Server sasl not completed: " + connection.connectionId) } @@ -571,7 +571,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, var message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security Message") sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) - } + } } catch { case e: Exception => { logError("Error in server auth negotiation: " + e) @@ -581,7 +581,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } } else { - logDebug("connection already established for this connection id: " + connection.connectionId) + logDebug("connection already established for this connection id: " + connection.connectionId) } } @@ -609,8 +609,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, return true } else { if (!conn.isSaslComplete()) { - // We could handle this better and tell the client we need to do authentication - // negotiation, but for now just ignore them. + // We could handle this better and tell the client we need to do authentication + // negotiation, but for now just ignore them. logError("message sent that is not security negotiation message on connection " + "not authenticated yet, ignoring it!!") return true @@ -709,11 +709,11 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } } else { - logDebug("Sasl already established ") + logDebug("Sasl already established ") } } - // allow us to add messages to the inbox for doing sasl negotiating + // allow us to add messages to the inbox for doing sasl negotiating private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) @@ -772,7 +772,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, if (((clock.getTime() - startTime) >= (authTimeout * 1000)) && (!connection.isSaslComplete())) { // took to long to authenticate the connection, something probably went wrong - throw new Exception("Took to long for authentication to " + connectionManagerId + + throw new Exception("Took to long for authentication to " + connectionManagerId + ", waited " + authTimeout + "seconds, failing.") } } @@ -794,7 +794,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } case None => { - logError("no messageStatus for failed message id: " + message.id) + logError("no messageStatus for failed message id: " + message.id) } } } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala index 9d9b9dbdd5331..4894ecd41f6eb 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala @@ -37,11 +37,11 @@ private[spark] object ConnectionManagerTest extends 
Logging{ "[size of msg in MB (integer)] [count] [await time in seconds)] ") System.exit(1) } - + if (args(0).startsWith("local")) { println("This runs only on a mesos cluster") } - + val sc = new SparkContext(args(0), "ConnectionManagerTest") val slavesFile = Source.fromFile(args(1)) val slaves = slavesFile.mkString.split("\n") @@ -50,7 +50,7 @@ private[spark] object ConnectionManagerTest extends Logging{ /* println("Slaves") */ /* slaves.foreach(println) */ val tasknum = if (args.length > 2) args(2).toInt else slaves.length - val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 + val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 val count = if (args.length > 4) args(4).toInt else 3 val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " + @@ -64,16 +64,16 @@ private[spark] object ConnectionManagerTest extends Logging{ (0 until count).foreach(i => { val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + val thisConnManagerId = connManager.id + connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { logInfo("Received [" + msg + "] from [" + id + "]") None }) val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip - - val startTime = System.currentTimeMillis + + val startTime = System.currentTimeMillis val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) @@ -84,7 +84,7 @@ private[spark] object ConnectionManagerTest extends Logging{ val results = futures.map(f => Await.result(f, awaitTime)) val finishTime = System.currentTimeMillis Thread.sleep(5000) - + val mb = size * results.size / 1024.0 / 1024.0 val ms = finishTime - startTime val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * @@ -92,11 +92,11 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo(resultStr) resultStr }).collect() - - println("---------------------") - println("Run " + i) + + println("---------------------") + println("Run " + i) resultStrs.foreach(println) - println("---------------------") + println("---------------------") }) } } diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 2b41c403b2e0a..9dc51e0d401f8 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.network import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object ReceiverTest { def main(args: Array[String]) { diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala index 0d9f743b3624b..a1dfc4094cca7 100644 --- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala @@ -26,33 +26,33 @@ import org.apache.spark._ import org.apache.spark.network._ /** - * SecurityMessage is class that contains the connectionId and 
sasl token + * SecurityMessage is class that contains the connectionId and sasl token * used in SASL negotiation. SecurityMessage has routines for converting * it to and from a BufferMessage so that it can be sent by the ConnectionManager * and easily consumed by users when received. * The api was modeled after BlockMessage. * - * The connectionId is the connectionId of the client side. Since + * The connectionId is the connectionId of the client side. Since * message passing is asynchronous and its possible for the server side (receiving) - * to get multiple different types of messages on the same connection the connectionId - * is used to know which connnection the security message is intended for. - * + * to get multiple different types of messages on the same connection the connectionId + * is used to know which connnection the security message is intended for. + * * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side * is acting as a client and connecting to node_1. SASL negotiation has to occur - * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. - * node_1 receives the message from node_0 but before it can process it and send a response, - * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 - * and sends a security message of its own to authenticate as a client. Now node_0 gets - * the message and it needs to decide if this message is in response to it being a client - * (from the first send) or if its just node_1 trying to connect to it to send data. This + * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. + * node_1 receives the message from node_0 but before it can process it and send a response, + * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 + * and sends a security message of its own to authenticate as a client. Now node_0 gets + * the message and it needs to decide if this message is in response to it being a client + * (from the first send) or if its just node_1 trying to connect to it to send data. This * is where the connectionId field is used. node_0 can lookup the connectionId to see if * it is in response to it being a client or if its in response to someone sending other data. - * + * * The format of a SecurityMessage as its sent is: * - Length of the ConnectionId - * - ConnectionId + * - ConnectionId * - Length of the token - * - Token + * - Token */ private[spark] class SecurityMessage() extends Logging { @@ -61,13 +61,13 @@ private[spark] class SecurityMessage() extends Logging { def set(byteArr: Array[Byte], newconnectionId: String) { if (byteArr == null) { - token = new Array[Byte](0) + token = new Array[Byte](0) } else { token = byteArr } connectionId = newconnectionId } - + /** * Read the given buffer and set the members of this class. */ @@ -91,17 +91,17 @@ private[spark] class SecurityMessage() extends Logging { buffer.clear() set(buffer) } - + def getConnectionId: String = { return connectionId } - + def getToken: Array[Byte] = { return token } - + /** - * Create a BufferMessage that can be sent by the ConnectionManager containing + * Create a BufferMessage that can be sent by the ConnectionManager containing * the security information from this class. 
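// Illustrative round trip, not part of this patch. saslToken and connId are
// placeholders, and fromBufferMessage is the companion parser described further
// below (its exact name is an assumption here):
//
//   val secMsg = new SecurityMessage()
//   secMsg.set(saslToken, connId.toString)     // token bytes + client-side connectionId
//   val wire = secMsg.toBufferMessage           // length-prefixed buffer, isSecurityNeg = true
//   val parsed = SecurityMessage.fromBufferMessage(wire)
//   // parsed.getConnectionId and parsed.getToken now match what was set above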
* @return BufferMessage */ @@ -110,12 +110,12 @@ private[spark] class SecurityMessage() extends Logging { val buffers = new ArrayBuffer[ByteBuffer]() // 4 bytes for the length of the connectionId - // connectionId is of type char so multiple the length by 2 to get number of bytes + // connectionId is of type char so multiple the length by 2 to get number of bytes // 4 bytes for the length of token // token is a byte buffer so just take the length var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) buffer.putInt(connectionId.length()) - connectionId.foreach((x: Char) => buffer.putChar(x)) + connectionId.foreach((x: Char) => buffer.putChar(x)) buffer.putInt(token.length) if (token.length > 0) { @@ -123,7 +123,7 @@ private[spark] class SecurityMessage() extends Logging { } buffer.flip() buffers += buffer - + var message = Message.createBufferMessage(buffers) logDebug("message total size is : " + message.size) message.isSecurityNeg = true @@ -136,7 +136,7 @@ private[spark] class SecurityMessage() extends Logging { } private[spark] object SecurityMessage { - + /** * Convert the given BufferMessage to a SecurityMessage by parsing the contents * of the BufferMessage and populating the SecurityMessage fields. diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index 4164e81d3a8ae..136c1912045aa 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -36,8 +36,8 @@ private[spark] class FileHeader ( if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) } else { - throw new Exception("too long header " + buf.readableBytes) - logInfo("too long header") + throw new Exception("too long header " + buf.readableBytes) + logInfo("too long header") } buf } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index eade07fbcbe37..cadd0c7ed19ba 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -44,7 +44,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { } } - /** + /** * Set a handler to be called when this PartialResult completes. Only one completion handler * is supported per PartialResult. */ @@ -60,7 +60,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { return this } - /** + /** * Set a handler to be called if this PartialResult's job fails. Only one failure handler * is supported per PartialResult. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 2306c9736b334..9ca971c8a4c27 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -52,7 +52,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** Compute the standard deviation of this RDD's elements. */ def stdev(): Double = stats().stdev - /** + /** * Compute the sample standard deviation of this RDD's elements (which corrects for bias in * estimating the standard deviation by dividing by N-1 instead of N). */ @@ -123,13 +123,13 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * e.g. 
for the array * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 - * + * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. - * buckets array must be at least two elements + * buckets array must be at least two elements * All NaN entries are treated the same. If you have a NaN bucket it must be * the maximum value of the last position and all NaN entries will be counted * in that bucket. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index a84357b38414e..0c2cd7a24783b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -33,7 +33,7 @@ class PartitionerAwareUnionRDDPartition( val idx: Int ) extends Partition { var parents = rdds.map(_.partitions(idx)).toArray - + override val index = idx override def hashCode(): Int = idx diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index affda13df6531..c1001227151a5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -31,11 +31,11 @@ private[spark] class ApplicationEventListener extends SparkListener { def applicationStarted = startTime != -1 - def applicationFinished = endTime != -1 + def applicationCompleted = endTime != -1 def applicationDuration: Long = { val difference = endTime - startTime - if (applicationStarted && applicationFinished && difference > 0) difference else -1L + if (applicationStarted && applicationCompleted && difference > 0) difference else -1L } override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c41d6d75a1d49..c6cbf14e20069 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -511,6 +511,13 @@ class DAGScheduler( eventProcessActor ! AllJobsCancelled } + /** + * Cancel all jobs associated with a running or scheduled stage. + */ + def cancelStage(stageId: Int) { + eventProcessActor ! StageCancelled(stageId) + } + /** * Process one event retrieved from the event processing actor. * @@ -551,6 +558,9 @@ class DAGScheduler( submitStage(finalStage) } + case StageCancelled(stageId) => + handleStageCancellation(stageId) + case JobCancelled(jobId) => handleJobCancellation(jobId) @@ -560,11 +570,13 @@ class DAGScheduler( val activeInGroup = activeJobs.filter(activeJob => groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) val jobIds = activeInGroup.map(_.jobId) - jobIds.foreach(handleJobCancellation) + jobIds.foreach(jobId => handleJobCancellation(jobId, + "as part of cancelled job group %s".format(groupId))) case AllJobsCancelled => // Cancel all running jobs. 
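// Illustrative sketch, not part of this patch: the end-to-end path this change
// adds for killing a stage from the UI (the kill-link handler itself is
// hypothetical; the methods and event below are the ones introduced here):
//
//   sc.cancelStage(3)   // private[spark] helper added to SparkContext
//   //   -> dagScheduler.cancelStage(3)   posts StageCancelled(3) to the event actor
//   //   -> handleStageCancellation(3)    looks up stageIdToJobIds(3) and cancels each
//   //      of those jobs with the reason "because Stage 3 was cancelled"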
- runningStages.map(_.jobId).foreach(handleJobCancellation) + runningStages.map(_.jobId).foreach(jobId => handleJobCancellation(jobId, + "as part of cancellation of all jobs")) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... @@ -991,11 +1003,23 @@ class DAGScheduler( } } - private def handleJobCancellation(jobId: Int) { + private def handleStageCancellation(stageId: Int) { + if (stageIdToJobIds.contains(stageId)) { + val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray + jobsThatUseStage.foreach(jobId => { + handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId)) + }) + } else { + logInfo("No active jobs to kill for Stage " + stageId) + } + } + + private def handleJobCancellation(jobId: Int, reason: String = "") { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { - failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled", None) + failJobAndIndependentStages(jobIdToActiveJob(jobId), + "Job %d cancelled %s".format(jobId, reason), None) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 04c53d468465a..7367c08b5d324 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -44,6 +44,8 @@ private[scheduler] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent +private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent + private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent @@ -54,7 +56,7 @@ private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent private[scheduler] -case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent private[scheduler] case class CompletionEvent( task: Task[_], diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 76f3e327d60b8..545fa453b7ccf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -1,107 +1,107 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.scheduler - -import java.util.concurrent.LinkedBlockingQueue - -import org.apache.spark.Logging - -/** - * Asynchronously passes SparkListenerEvents to registered SparkListeners. - * - * Until start() is called, all posted events are only buffered. Only after this listener bus - * has started will events be actually propagated to all attached listeners. This listener bus - * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). - */ -private[spark] class LiveListenerBus extends SparkListenerBus with Logging { - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - private var started = false - private val listenerThread = new Thread("SparkListenerBus") { - setDaemon(true) - override def run() { - while (true) { - val event = eventQueue.take - if (event == SparkListenerShutdown) { - // Get out of the while loop and shutdown the daemon thread - return - } - postToAll(event) - } - } - } - - // Exposed for testing - @volatile private[spark] var stopCalled = false - - /** - * Start sending events to attached listeners. - * - * This first sends out all buffered events posted before this listener bus has started, then - * listens for any additional events asynchronously while the listener bus is still running. - * This should only be called once. - */ - def start() { - if (started) { - throw new IllegalStateException("Listener bus already started!") - } - listenerThread.start() - started = true - } - - def post(event: SparkListenerEvent) { - val eventAdded = eventQueue.offer(event) - if (!eventAdded && !queueFullErrorMessageLogged) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with the " + - "rate at which tasks are being started by the scheduler.") - queueFullErrorMessageLogged = true - } - } - - /** - * Waits until there are no more events in the queue, or until the specified time has elapsed. - * Used for testing only. Returns true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. - */ - def waitUntilEmpty(timeoutMillis: Int): Boolean = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!eventQueue.isEmpty) { - if (System.currentTimeMillis > finishTime) { - return false - } - /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify - * add overhead in the general case. */ - Thread.sleep(10) - } - true - } - - def stop() { - stopCalled = true - if (!started) { - throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!") - } - post(SparkListenerShutdown) - listenerThread.join() - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.concurrent.LinkedBlockingQueue + +import org.apache.spark.Logging + +/** + * Asynchronously passes SparkListenerEvents to registered SparkListeners. + * + * Until start() is called, all posted events are only buffered. Only after this listener bus + * has started will events be actually propagated to all attached listeners. This listener bus + * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). + */ +private[spark] class LiveListenerBus extends SparkListenerBus with Logging { + + /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than + * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + private var queueFullErrorMessageLogged = false + private var started = false + private val listenerThread = new Thread("SparkListenerBus") { + setDaemon(true) + override def run() { + while (true) { + val event = eventQueue.take + if (event == SparkListenerShutdown) { + // Get out of the while loop and shutdown the daemon thread + return + } + postToAll(event) + } + } + } + + // Exposed for testing + @volatile private[spark] var stopCalled = false + + /** + * Start sending events to attached listeners. + * + * This first sends out all buffered events posted before this listener bus has started, then + * listens for any additional events asynchronously while the listener bus is still running. + * This should only be called once. + */ + def start() { + if (started) { + throw new IllegalStateException("Listener bus already started!") + } + listenerThread.start() + started = true + } + + def post(event: SparkListenerEvent) { + val eventAdded = eventQueue.offer(event) + if (!eventAdded && !queueFullErrorMessageLogged) { + logError("Dropping SparkListenerEvent because no remaining room in event queue. " + + "This likely means one of the SparkListeners is too slow and cannot keep up with the " + + "rate at which tasks are being started by the scheduler.") + queueFullErrorMessageLogged = true + } + } + + /** + * Waits until there are no more events in the queue, or until the specified time has elapsed. + * Used for testing only. Returns true if the queue has emptied and false is the specified time + * elapsed before the queue emptied. + */ + def waitUntilEmpty(timeoutMillis: Int): Boolean = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!eventQueue.isEmpty) { + if (System.currentTimeMillis > finishTime) { + return false + } + /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify + * add overhead in the general case. 
*/ + Thread.sleep(10) + } + true + } + + def stop() { + stopCalled = true + if (!started) { + throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!") + } + post(SparkListenerShutdown) + listenerThread.join() + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index c478e685641d7..06b041e1fd9a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -194,12 +194,10 @@ private[spark] class CoarseMesosSchedulerBackend( .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", sc.executorMemory)) .build() - d.launchTasks(Collections.singletonList(offer.getId), - Collections.singletonList(task), - filters) + d.launchTasks(offer.getId, Collections.singletonList(task), filters) } else { // Filter it out - d.declineOffer(offer.getId, filters) + d.launchTasks(offer.getId, Collections.emptyList[MesosTaskInfo](), filters) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index f878ae338fc95..dfdcafe19fb93 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -223,7 +223,7 @@ private[spark] class MesosSchedulerBackend( // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? for (i <- 0 until offers.size) { - d.launchTasks(Collections.singletonList(offers(i).getId), mesosTasks(i), filters) + d.launchTasks(offers(i).getId, mesosTasks(i), filters) } } } finally { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 2fbbda5b76c74..ace9cd51c96b7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -240,7 +240,7 @@ object BlockFetcherIterator { override def numRemoteBlocks: Int = numRemote override def fetchWaitTime: Long = _fetchWaitTime override def remoteBytesRead: Long = _remoteBytesRead - + // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue // as they arrive. 
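The LiveListenerBus file above changes only in whitespace and line endings, but it is the component that buffers posted events and drains them asynchronously to registered listeners, so a short usage sketch may help. The LoggingListener class and app name below are invented for illustration; only sc.addSparkListener and the standard SparkListener callbacks are assumed.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerStageCompleted}

// Events are queued on the LiveListenerBus and delivered on its daemon thread,
// so these callbacks run asynchronously with respect to the job that produced them.
class LoggingListener extends SparkListener {
  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
    println("Stage " + stageCompleted.stageInfo.stageId + " completed")
  }
  override def onJobEnd(jobEnd: SparkListenerJobEnd) {
    println("Job " + jobEnd.jobId + " ended with result " + jobEnd.jobResult)
  }
}

object ListenerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("listener-sketch"))
    sc.addSparkListener(new LoggingListener)
    sc.parallelize(1 to 1000, 4).map(_ * 2).count()
    sc.stop()
  }
}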
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a2a729130091f..df9bb4044e37a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -388,7 +388,7 @@ private[spark] class BlockManager( logDebug("Block " + blockId + " not found in memory") } } - + // Look for the block in Tachyon if (level.useOffHeap) { logDebug("Getting block " + blockId + " from tachyon") @@ -1031,7 +1031,7 @@ private[spark] class BlockManager( memoryStore.clear() diskStore.clear() if (tachyonInitialized) { - tachyonStore.clear() + tachyonStore.clear() } metadataCleaner.cancel() broadcastCleaner.cancel() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala index 7168ae18c2615..337b45b727dec 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -37,7 +37,7 @@ private[spark] class BlockMessage() { private var id: BlockId = null private var data: ByteBuffer = null private var level: StorageLevel = null - + def set(getBlock: GetBlock) { typ = BlockMessage.TYPE_GET_BLOCK id = getBlock.id @@ -75,13 +75,13 @@ private[spark] class BlockMessage() { idBuilder += buffer.getChar() } id = BlockId(idBuilder.toString) - + if (typ == BlockMessage.TYPE_PUT_BLOCK) { val booleanInt = buffer.getInt() val replication = buffer.getInt() level = StorageLevel(booleanInt, replication) - + val dataLength = buffer.getInt() data = ByteBuffer.allocate(dataLength) if (dataLength != buffer.remaining) { @@ -108,12 +108,12 @@ private[spark] class BlockMessage() { buffer.clear() set(buffer) } - + def getType: Int = typ def getId: BlockId = id def getData: ByteBuffer = data def getLevel: StorageLevel = level - + def toBufferMessage: BufferMessage = { val startTime = System.currentTimeMillis val buffers = new ArrayBuffer[ByteBuffer]() @@ -127,7 +127,7 @@ private[spark] class BlockMessage() { buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) buffer.flip() buffers += buffer - + buffer = ByteBuffer.allocate(4).putInt(data.remaining) buffer.flip() buffers += buffer @@ -140,7 +140,7 @@ private[spark] class BlockMessage() { buffers += data } - + /* println() println("BlockMessage: ") @@ -158,7 +158,7 @@ private[spark] class BlockMessage() { } override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + + "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } @@ -168,7 +168,7 @@ private[spark] object BlockMessage { val TYPE_GET_BLOCK: Int = 1 val TYPE_GOT_BLOCK: Int = 2 val TYPE_PUT_BLOCK: Int = 3 - + def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { val newBlockMessage = new BlockMessage() newBlockMessage.set(bufferMessage) @@ -192,7 +192,7 @@ private[spark] object BlockMessage { newBlockMessage.set(gotBlock) newBlockMessage } - + def fromPutBlock(putBlock: PutBlock): BlockMessage = { val newBlockMessage = new BlockMessage() newBlockMessage.set(putBlock) @@ -206,7 +206,7 @@ private[spark] object BlockMessage { val bMsg = B.toBufferMessage val C = new BlockMessage() C.set(bMsg) - + println(B.getId + " " + B.getLevel) println(C.getId + " " + C.getLevel) } diff --git 
a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala index dc62b1efaa7d4..973d85c0a9b3a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -27,16 +27,16 @@ import org.apache.spark.network._ private[spark] class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { - + def this(bm: BlockMessage) = this(Array(bm)) def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - def apply(i: Int) = blockMessages(i) + def apply(i: Int) = blockMessages(i) def iterator = blockMessages.iterator - def length = blockMessages.length + def length = blockMessages.length def set(bufferMessage: BufferMessage) { val startTime = System.currentTimeMillis @@ -62,15 +62,15 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) logDebug("Trying to convert buffer " + newBuffer + " to block message") val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage + newBlockMessages += newBlockMessage buffer.position(buffer.position() + size) } val finishTime = System.currentTimeMillis logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages + this.blockMessages = newBlockMessages } - + def toBufferMessage: BufferMessage = { val buffers = new ArrayBuffer[ByteBuffer]() @@ -83,7 +83,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) buffers ++= bufferMessage.buffers logDebug("Added " + bufferMessage) }) - + logDebug("Buffer list:") buffers.foreach((x: ByteBuffer) => logDebug("" + x)) /* @@ -103,13 +103,13 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) } private[spark] object BlockMessageArray { - + def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() newBlockMessageArray.set(bufferMessage) newBlockMessageArray } - + def main(args: Array[String]) { val blockMessages = (0 until 10).map { i => @@ -124,10 +124,10 @@ private[spark] object BlockMessageArray { } val blockMessageArray = new BlockMessageArray(blockMessages) println("Block message array created") - + val bufferMessage = blockMessageArray.toBufferMessage println("Converted to buffer message") - + val totalSize = bufferMessage.size val newBuffer = ByteBuffer.allocate(totalSize) newBuffer.clear() @@ -137,7 +137,7 @@ private[spark] object BlockMessageArray { buffer.rewind() }) newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) + val newBufferMessage = Message.createBufferMessage(newBuffer) println("Copied to new buffer message, size = " + newBufferMessage.size) val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) @@ -147,7 +147,7 @@ private[spark] object BlockMessageArray { case BlockMessage.TYPE_PUT_BLOCK => { val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) println(pB) - } + } case BlockMessage.TYPE_GET_BLOCK => { val gB = new GetBlock(blockMessage.getId) println(gB) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 07255aa366a6d..7ed371326855d 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -42,24 +42,22 @@ class StorageStatus( def memRemaining : Long = maxMem - memUsed() - def rddBlocks = blocks.flatMap { - case (rdd: RDDBlockId, status) => Some(rdd, status) - case _ => None - } + def rddBlocks = blocks.collect { case (rdd: RDDBlockId, status) => (rdd, status) } } @DeveloperApi private[spark] class RDDInfo( - val id: Int, - val name: String, - val numPartitions: Int, - val storageLevel: StorageLevel) extends Ordered[RDDInfo] { + val id: Int, + val name: String, + val numPartitions: Int, + val storageLevel: StorageLevel) + extends Ordered[RDDInfo] { var numCachedPartitions = 0 var memSize = 0L var diskSize = 0L - var tachyonSize= 0L + var tachyonSize = 0L override def toString = { import Utils.bytesToString diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 3609073cecf7c..62a4e3d0f6a42 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -104,10 +104,12 @@ private[spark] object JettyUtils extends Logging { def createRedirectHandler( srcPath: String, destPath: String, + beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = ""): ServletContextHandler = { val prefixedDestPath = attachPrefix(basePath, destPath) val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { + beforeRedirect(request) // Make sure we don't end up with "//" in the middle val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString response.sendRedirect(newUrl) @@ -137,7 +139,7 @@ private[spark] object JettyUtils extends Logging { private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) { val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim()) filters.foreach { - case filter : String => + case filter : String => if (!filter.isEmpty) { logInfo("Adding filter: " + filter) val holder : FilterHolder = new FilterHolder() @@ -152,7 +154,7 @@ private[spark] object JettyUtils extends Logging { if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) } } - val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index bca4c3c42d27f..2fef1a635427c 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -54,12 +54,15 @@ private[spark] class SparkUI( /** Initialize all components of the server. 
*/ def initialize() { listenerBus.addListener(storageStatusListener) - attachTab(new JobProgressTab(this)) + val jobProgressTab = new JobProgressTab(this) + attachTab(jobProgressTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/stages", basePath)) + attachHandler(createRedirectHandler("/", "/stages", basePath = basePath)) + attachHandler( + createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest)) if (live) { sc.env.metricsSystem.getServletHandlers.foreach(attachHandler) } @@ -81,7 +84,12 @@ private[spark] class SparkUI( logInfo("Stopped Spark web UI at %s".format(appUIAddress)) } - private[spark] def appUIAddress = "http://" + publicHostName + ":" + boundPort + /** + * Return the application UI host:port. This does not include the scheme (http://). + */ + private[spark] def appUIHostPort = publicHostName + ":" + boundPort + + private[spark] def appUIAddress = s"http://$appUIHostPort" } private[spark] object SparkUI { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index e9b83922a4b1a..6a2d652528d8a 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -154,17 +154,8 @@ private[spark] object UIUtils extends Logging { type="text/css" /> {appName} - {title} - - + -
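The redirect registered above wires /stages/stage/kill to the stages page through jobProgressTab.handleKillRequest, so a stage can be killed with a plain HTTP GET. A rough sketch follows; the localhost:4040 address and the stage id are assumptions for illustration, and the id/terminate parameter names come from the handler that appears later in this diff. The programmatic equivalent introduced by this patch is sc.cancelStage(stageId).

import java.net.URL
import scala.io.Source

object KillStageSketch {
  def main(args: Array[String]): Unit = {
    // Assumes a driver UI on the default port and an active stage with id 3.
    // The handler attempts the kill and then redirects back to /stages.
    val killUrl = new URL("http://localhost:4040/stages/stage/kill?id=3&terminate=true")
    val page = Source.fromInputStream(killUrl.openConnection().getInputStream).mkString
    println("Kill request issued; redirected page was " + page.length + " characters")
  }
}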
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index 70578c3eb87c8..b347eb1b83c1f 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -23,7 +23,7 @@ import scala.xml.Node import org.apache.spark.ui.{UIUtils, WebUIPage} -private[ui] class IndexPage(parent: EnvironmentTab) extends WebUIPage("") { +private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { private val appName = parent.appName private val basePath = parent.basePath private val listener = parent.listener diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 7797057fa1aa9..03b46e1bd59af 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -25,7 +25,7 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends WebUITab(parent, "envi val basePath = parent.basePath val listener = new EnvironmentListener - attachPage(new IndexPage(this)) + attachPage(new EnvironmentPage(this)) parent.registerListener(listener) } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 56c3887923758..c1e69f6cdaffb 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -24,7 +24,7 @@ import scala.xml.Node import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[ui] class IndexPage(parent: ExecutorsTab) extends WebUIPage("") { +private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { private val appName = parent.appName private val basePath = parent.basePath private val listener = parent.listener diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index e9ec18a3e74af..5678bf34ac730 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -29,7 +29,7 @@ private[ui] class ExecutorsTab(parent: SparkUI) extends WebUITab(parent, "execut val basePath = parent.basePath val listener = new ExecutorsListener(parent.storageStatusListener) - attachPage(new IndexPage(this)) + attachPage(new ExecutorsPage(this)) parent.registerListener(listener) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 18559f732d2a3..0db4afa701b41 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -222,7 +222,7 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener { override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { - environmentUpdate + schedulingMode = environmentUpdate .environmentDetails("Spark Properties").toMap .get("spark.scheduler.mode") .map(SchedulingMode.withName) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 12c82796349c9..34ff2ac34a7ca 
100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.Schedulable import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ -private[ui] class IndexPage(parent: JobProgressTab) extends WebUIPage("") { +private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { private val appName = parent.appName private val basePath = parent.basePath private val live = parent.live @@ -40,7 +40,8 @@ private[ui] class IndexPage(parent: JobProgressTab) extends WebUIPage("") { val failedStages = listener.failedStages.reverse.toSeq val now = System.currentTimeMillis - val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) + val activeStagesTable = + new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent, parent.killEnabled) val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index da9de035f89f1..3308c8c8a3d37 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.jobs +import javax.servlet.http.HttpServletRequest + import org.apache.spark.SparkConf import org.apache.spark.scheduler.SchedulingMode import org.apache.spark.ui.{SparkUI, WebUITab} @@ -28,12 +30,27 @@ private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stag val live = parent.live val sc = parent.sc val conf = if (live) sc.conf else new SparkConf + val killEnabled = conf.getBoolean("spark.ui.killEnabled", true) val listener = new JobProgressListener(conf) - attachPage(new IndexPage(this)) + attachPage(new JobProgressPage(this)) attachPage(new StagePage(this)) attachPage(new PoolPage(this)) parent.registerListener(listener) def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + + def handleKillRequest(request: HttpServletRequest) = { + if (killEnabled) { + val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean + val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt + if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { + sc.cancelStage(stageId) + } + // Do a quick pause here to give Spark time to kill the stage so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. 
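Because handleKillRequest is guarded by the new spark.ui.killEnabled flag (read above with a default of true), deployments that do not want the kill endpoint exposed can disable it per application. A small sketch with an invented app name:

import org.apache.spark.{SparkConf, SparkContext}

object NoKillUISketch {
  def main(args: Array[String]): Unit = {
    // With killEnabled = false the UI neither renders kill links nor honours kill requests.
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("no-kill-ui")
      .set("spark.ui.killEnabled", "false")
    val sc = new SparkContext(conf)
    // ... run jobs as usual ...
    sc.stop()
  }
}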
+ Thread.sleep(100) + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 5cc1fcd10a08d..8c5b1f55fd2dc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -27,7 +27,11 @@ import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished stages */ -private[ui] class StageTable(stages: Seq[StageInfo], parent: JobProgressTab) { +private[ui] class StageTable( + stages: Seq[StageInfo], + parent: JobProgressTab, + killEnabled: Boolean = false) { + private val basePath = parent.basePath private val listener = parent.listener private lazy val isFairScheduler = parent.isFairScheduler @@ -71,15 +75,28 @@ private[ui] class StageTable(stages: Seq[StageInfo], parent: JobProgressTab) {
} - /** Render an HTML row that represents a stage */ - private def stageRow(s: StageInfo): Seq[Node] = { - val poolName = listener.stageIdToPool.get(s.stageId) + private def makeDescription(s: StageInfo): Seq[Node] = { + // scalastyle:off + val killLink = if (killEnabled) { + + (
kill) + + } + // scalastyle:on + val nameLink = {s.name} - val description = listener.stageIdToDescription.get(s.stageId) - .map(d =>
{d}
{nameLink}
).getOrElse(nameLink) + + listener.stageIdToDescription.get(s.stageId) + .map(d =>
{d}
{nameLink} {killLink}
) + .getOrElse(
{killLink}{nameLink}
) + } + + /** Render an HTML row that represents a stage */ + private def stageRow(s: StageInfo): Seq[Node] = { + val poolName = listener.stageIdToPool.get(s.stageId) val submissionTime = s.submissionTime match { case Some(t) => UIUtils.formatDate(new Date(t)) case None => "Unknown" @@ -118,7 +135,7 @@ private[ui] class StageTable(stages: Seq[StageInfo], parent: JobProgressTab) { }} - {description} + {makeDescription(s)} {submissionTime} {formattedDuration} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index c5cfee777aab5..b66edd91f56c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -26,7 +26,7 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ -private[ui] class IndexPage(parent: StorageTab) extends WebUIPage("") { +private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { private val appName = parent.appName private val basePath = parent.basePath private val listener = parent.listener diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 0886950c3f8e6..56429f6c07fcd 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -29,7 +29,7 @@ private[ui] class StorageTab(parent: SparkUI) extends WebUITab(parent, "storage" val basePath = parent.basePath val listener = new StorageListener(parent.storageStatusListener) - attachPage(new IndexPage(this)) + attachPage(new StoragePage(this)) attachPage(new RddPage(this)) parent.registerListener(listener) } diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index cdbbc65292188..2d05e09b10948 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -45,7 +45,7 @@ private[spark] object ClosureCleaner extends Logging { private def isClosure(cls: Class[_]): Boolean = { cls.getName.contains("$anonfun$") } - + // Get a list of the classes of the outer objects of a given closure object, obj; // the outer objects are defined as any closures that obj is nested within, plus // possibly the class that the outermost closure is in, if any. We stop searching @@ -63,7 +63,7 @@ private[spark] object ClosureCleaner extends Logging { } Nil } - + // Get a list of the outer objects for a given closure object. 
private def getOuterObjects(obj: AnyRef): List[AnyRef] = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { @@ -76,7 +76,7 @@ private[spark] object ClosureCleaner extends Logging { } Nil } - + private def getInnerClasses(obj: AnyRef): List[Class[_]] = { val seen = Set[Class[_]](obj.getClass) var stack = List[Class[_]](obj.getClass) @@ -92,7 +92,7 @@ private[spark] object ClosureCleaner extends Logging { } return (seen - obj.getClass).toList } - + private def createNullValue(cls: Class[_]): AnyRef = { if (cls.isPrimitive) { new java.lang.Byte(0: Byte) // Should be convertible to any primitive type @@ -100,13 +100,13 @@ private[spark] object ClosureCleaner extends Logging { null } } - + def clean(func: AnyRef) { // TODO: cache outerClasses / innerClasses / accessedFields val outerClasses = getOuterClasses(func) val innerClasses = getInnerClasses(func) val outerObjects = getOuterObjects(func) - + val accessedFields = Map[Class[_], Set[String]]() for (cls <- outerClasses) accessedFields(cls) = Set[String]() @@ -143,7 +143,7 @@ private[spark] object ClosureCleaner extends Logging { field.set(outer, value) } } - + if (outer != null) { // logInfo("2: Setting $outer on " + func.getClass + " to " + outer); val field = func.getClass.getDeclaredField("$outer") @@ -151,7 +151,7 @@ private[spark] object ClosureCleaner extends Logging { field.set(func, outer) } } - + private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { // logInfo("Creating a " + cls + " with outer = " + outer) if (!inInterpreter) { @@ -192,7 +192,7 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor } } } - + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { // Check for calls a getter method for a variable in an interpreter wrapper object. 
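The getOuterClasses / getOuterObjects / clean machinery above exists so that a closure shipped to executors does not drag along enclosing objects it never uses. As a hedged illustration of the pattern it guards against (the Multiplier class is invented for this sketch and is not part of the patch):

import org.apache.spark.rdd.RDD

class Multiplier(factor: Int) {            // note: not Serializable
  def scale(rdd: RDD[Int]): RDD[Int] = {
    // Referencing `factor` directly inside rdd.map would make the closure capture
    // `this` (the whole Multiplier) as an outer object, which is exactly what the
    // outer-object walk above tries to trim away. Copying the value into a local
    // keeps the closure down to a single captured Int.
    val localFactor = factor
    rdd.map(_ * localFactor)
  }
}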
@@ -209,12 +209,12 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { var myName: String = null - + override def visit(version: Int, access: Int, name: String, sig: String, superName: String, interfaces: Array[String]) { myName = name } - + override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { new MethodVisitor(ASM4) { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index d990fd49ef834..465835ea7fe29 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -88,30 +88,27 @@ private[spark] object JsonProtocol { def taskStartToJson(taskStart: SparkListenerTaskStart): JValue = { val taskInfo = taskStart.taskInfo - val taskInfoJson = if (taskInfo != null) taskInfoToJson(taskInfo) else JNothing ("Event" -> Utils.getFormattedClassName(taskStart)) ~ ("Stage ID" -> taskStart.stageId) ~ - ("Task Info" -> taskInfoJson) + ("Task Info" -> taskInfoToJson(taskInfo)) } def taskGettingResultToJson(taskGettingResult: SparkListenerTaskGettingResult): JValue = { val taskInfo = taskGettingResult.taskInfo - val taskInfoJson = if (taskInfo != null) taskInfoToJson(taskInfo) else JNothing ("Event" -> Utils.getFormattedClassName(taskGettingResult)) ~ - ("Task Info" -> taskInfoJson) + ("Task Info" -> taskInfoToJson(taskInfo)) } def taskEndToJson(taskEnd: SparkListenerTaskEnd): JValue = { val taskEndReason = taskEndReasonToJson(taskEnd.reason) val taskInfo = taskEnd.taskInfo - val taskInfoJson = if (taskInfo != null) taskInfoToJson(taskInfo) else JNothing val taskMetrics = taskEnd.taskMetrics val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing ("Event" -> Utils.getFormattedClassName(taskEnd)) ~ ("Stage ID" -> taskEnd.stageId) ~ ("Task Type" -> taskEnd.taskType) ~ ("Task End Reason" -> taskEndReason) ~ - ("Task Info" -> taskInfoJson) ~ + ("Task Info" -> taskInfoToJson(taskInfo)) ~ ("Task Metrics" -> taskMetricsJson) } @@ -505,6 +502,9 @@ private[spark] object JsonProtocol { } def taskMetricsFromJson(json: JValue): TaskMetrics = { + if (json == JNothing) { + return TaskMetrics.empty + } val metrics = new TaskMetrics metrics.hostname = (json \ "Host Name").extract[String] metrics.executorDeserializeTime = (json \ "Executor Deserialize Time").extract[Long] @@ -611,7 +611,7 @@ private[spark] object JsonProtocol { val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel) rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize - rddInfo.tachyonSize = tachyonSize + rddInfo.tachyonSize = tachyonSize rddInfo.diskSize = diskSize rddInfo } diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala index 8266e5e495efc..e5c732a5a559b 100644 --- a/core/src/main/scala/org/apache/spark/util/NextIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala @@ -19,7 +19,7 @@ package org.apache.spark.util /** Provides a basic/boilerplate Iterator implementation. 
*/ private[spark] abstract class NextIterator[U] extends Iterator[U] { - + private var gotNext = false private var nextValue: U = _ private var closed = false @@ -34,7 +34,7 @@ private[spark] abstract class NextIterator[U] extends Iterator[U] { * This convention is required because `null` may be a valid value, * and using `Option` seems like it might create unnecessary Some/None * instances, given some iterators might be called in a tight loop. - * + * * @return U, or set 'finished' when done */ protected def getNext(): U diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala index 732748a7ff82b..d80eed455c427 100644 --- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -62,10 +62,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { if (n == 0) { mu = other.mu m2 = other.m2 - n = other.n + n = other.n maxValue = other.maxValue minValue = other.minValue - } else if (other.n != 0) { + } else if (other.n != 0) { val delta = other.mu - mu if (other.n * 10 < n) { mu = mu + (delta * other.n) / (n + other.n) diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index 3c8f94a416c65..1a647fa1c9d84 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -136,7 +136,7 @@ object Vector { def ones(length: Int) = Vector(length, _ => 1) /** - * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers + * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided. */ def random(length: Int, random: Random = new XORShiftRandom()) = diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 8a4cdea2fa7b1..7f220383f9f8b 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -25,28 +25,28 @@ import scala.util.hashing.MurmurHash3 import org.apache.spark.util.Utils.timeIt /** - * This class implements a XORShift random number generator algorithm + * This class implements a XORShift random number generator algorithm * Source: * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. * @see Paper * This implementation is approximately 3.5 times faster than * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due - * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class + * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG * for each thread. */ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { - + def this() = this(System.nanoTime) private var seed = XORShiftRandom.hashSeed(init) // we need to just override next - this will be called by nextInt, nextDouble, // nextGaussian, nextLong, etc. 
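For reference, the xorshift update applied by the next(bits) override below (shift by 21, 35, and 4, then mask off the requested bits) can be sketched in isolation as follows; this standalone object is illustrative only and skips the seed hashing that XORShiftRandom itself performs.

object XorShiftSketch {
  // One xorshift step: scramble the 64-bit state, then keep the low `bits` bits.
  def nextBits(state: Long, bits: Int): (Long, Int) = {
    var s = state ^ (state << 21)
    s ^= (s >>> 35)
    s ^= (s << 4)
    (s, (s & ((1L << bits) - 1)).toInt)
  }

  def main(args: Array[String]): Unit = {
    var state = System.nanoTime()
    val samples = (1 to 5).map { _ =>
      val (nextState, value) = nextBits(state, 31)
      state = nextState
      value
    }
    println(samples.mkString(", "))
  }
}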
- override protected def next(bits: Int): Int = { + override protected def next(bits: Int): Int = { var nextSeed = seed ^ (seed << 21) nextSeed ^= (nextSeed >>> 35) - nextSeed ^= (nextSeed << 4) + nextSeed ^= (nextSeed << 4) seed = nextSeed (nextSeed & ((1L << bits) -1)).asInstanceOf[Int] } @@ -89,7 +89,7 @@ private[spark] object XORShiftRandom { val million = 1e6.toInt val javaRand = new JavaRandom(seed) val xorRand = new XORShiftRandom(seed) - + // this is just to warm up the JIT - we're not timing anything timeIt(1e6.toInt) { javaRand.nextInt() @@ -97,9 +97,9 @@ private[spark] object XORShiftRandom { } val iters = timeIt(numIters)(_) - + /* Return results as a map instead of just printing to screen - in case the user wants to do something with them */ + in case the user wants to do something with them */ Map("javaTime" -> iters {javaRand.nextInt()}, "xorTime" -> iters {xorRand.nextInt()}) diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala index c5f24c66ce0c1..c645e4cbe8132 100644 --- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -37,7 +37,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val securityManager = new SecurityManager(conf); val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) @@ -54,14 +54,14 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) } actorSystem.shutdown() @@ -75,7 +75,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val securityManager = new SecurityManager(conf); val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) @@ -91,7 +91,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { badconf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(badconf); - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new 
MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( @@ -127,7 +127,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val securityManager = new SecurityManager(conf); val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) @@ -180,7 +180,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val securityManager = new SecurityManager(conf); val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) @@ -204,8 +204,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) } actorSystem.shutdown() diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 9cbdfc54a3dc8..7f59bdcce4cc7 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -39,7 +39,7 @@ class DriverSuite extends FunSuite with Timeouts { failAfter(60 seconds) { Utils.executeAndGetOutput( Seq("./bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), - new File(sparkHome), + new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) } } diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index aee9ab9091dac..d651fbbac4e97 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -45,7 +45,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val pw = new PrintWriter(textFile) pw.println("100") pw.close() - + val jarFile = new File(tmpDir, "test.jar") val jarStream = new FileOutputStream(jarFile) val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) @@ -53,7 +53,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val jarEntry = new JarEntry(textFile.getName) jar.putNextEntry(jarEntry) - + val in = new FileInputStream(textFile) val buffer = new Array[Byte](10240) var nRead = 0 diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 01af94077144a..b9b668d3cc62a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -106,7 +106,7 @@ class FileSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val 
outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) nums.saveAsSequenceFile(outputDir) // Try reading the output back as a SequenceFile val output = sc.sequenceFile[IntWritable, Text](outputDir) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 0b5ed6d77034b..5e538d6fab2a1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -45,4 +45,4 @@ class WorkerWatcherSuite extends FunSuite { actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) assert(!actorRef.underlyingActor.isShutDown) } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 09e35bfc8f85f..e89b296d41026 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -42,7 +42,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { override def beforeAll() { sc = new SparkContext("local", "test") - + // Set the block size of local file system to test whether files are split right or not. sc.hadoopConfiguration.setLong("fs.local.block.size", 32) } diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index a4381a8b974df..4df36558b6d4b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -34,14 +34,14 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(1).mkString(",") === "2") assert(slices(2).mkString(",") === "3") } - + test("one slice") { val data = Array(1, 2, 3) val slices = ParallelCollectionRDD.slice(data, 1) assert(slices.size === 1) assert(slices(0).mkString(",") === "1,2,3") } - + test("equal slices") { val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) val slices = ParallelCollectionRDD.slice(data, 3) @@ -50,7 +50,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(1).mkString(",") === "4,5,6") assert(slices(2).mkString(",") === "7,8,9") } - + test("non-equal slices") { val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) val slices = ParallelCollectionRDD.slice(data, 3) @@ -77,14 +77,14 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(1).mkString(",") === (33 to 66).mkString(",")) assert(slices(2).mkString(",") === (67 to 100).mkString(",")) } - + test("empty data") { val data = new Array[Int](0) val slices = ParallelCollectionRDD.slice(data, 5) assert(slices.size === 5) for (slice <- slices) assert(slice.size === 0) } - + test("zero slices") { val data = Array(1, 2, 3) intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) } @@ -94,7 +94,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = Array(1, 2, 3) intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) } } - + test("exclusive ranges sliced into ranges") 
{ val data = 1 until 100 val slices = ParallelCollectionRDD.slice(data, 3) @@ -102,7 +102,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[Range])) } - + test("inclusive ranges sliced into ranges") { val data = 1 to 100 val slices = ParallelCollectionRDD.slice(data, 3) @@ -124,7 +124,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(range.step === 1, "slice " + i + " step") } } - + test("random array tests") { val gen = for { d <- arbitrary[List[Int]] @@ -141,7 +141,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { } check(prop) } - + test("random exclusive range tests") { val gen = for { a <- Gen.choose(-100, 100) @@ -177,7 +177,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { } check(prop) } - + test("exclusive ranges of longs") { val data = 1L until 100L val slices = ParallelCollectionRDD.slice(data, 3) @@ -185,7 +185,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } - + test("inclusive ranges of longs") { val data = 1L to 100L val slices = ParallelCollectionRDD.slice(data, 3) @@ -193,7 +193,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } - + test("exclusive ranges of doubles") { val data = 1.0 until 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) @@ -201,7 +201,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } - + test("inclusive ranges of doubles") { val data = 1.0 to 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 25973348a7837..1901330d8b188 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -274,37 +274,42 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("coalesced RDDs with locality, large scale (10K partitions)") { // large scale experiment import collection.mutable - val rnd = scala.util.Random val partitions = 10000 val numMachines = 50 val machines = mutable.ListBuffer[String]() - (1 to numMachines).foreach(machines += "m"+_) - - val blocks = (1 to partitions).map(i => - { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } ) - - val data2 = sc.makeRDD(blocks) - val coalesced2 = data2.coalesce(numMachines*2) - - // test that you get over 90% locality in each group - val minLocality = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.0)((perc, loc) => math.min(perc,loc)) - assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.0).toInt + "%") - - // test that the groups are load balanced with 100 +/- 20 elements in each - val maxImbalance = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) - .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev)) - assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) - - val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test 
*current* pref locs - val coalesced3 = data3.coalesce(numMachines*2) - val minLocality2 = coalesced3.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.0)((perc, loc) => math.min(perc,loc)) - assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + - (minLocality2*100.0).toInt + "%") + (1 to numMachines).foreach(machines += "m" + _) + val rnd = scala.util.Random + for (seed <- 1 to 5) { + rnd.setSeed(seed) + + val blocks = (1 to partitions).map { i => + (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) + } + + val data2 = sc.makeRDD(blocks) + val coalesced2 = data2.coalesce(numMachines * 2) + + // test that you get over 90% locality in each group + val minLocality = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.0)((perc, loc) => math.min(perc, loc)) + assert(minLocality >= 0.90, "Expected 90% locality but got " + + (minLocality * 100.0).toInt + "%") + + // test that the groups are load balanced with 100 +/- 20 elements in each + val maxImbalance = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) + .foldLeft(0)((dev, curr) => math.max(math.abs(100 - curr), dev)) + assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) + + val data3 = sc.makeRDD(blocks).map(i => i * 2) // derived RDD to test *current* pref locs + val coalesced3 = data3.coalesce(numMachines * 2) + val minLocality2 = coalesced3.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.0)((perc, loc) => math.min(perc, loc)) + assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + + (minLocality2 * 100.0).toInt + "%") + } } test("zipped RDDs") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a74724d785ad3..db4df1d1212ff 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -290,7 +290,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val rdd = makeRdd(1, Nil) val jobId = submit(rdd, Array(0)) cancel(jobId) - assert(failure.getMessage === s"Job $jobId cancelled") + assert(failure.getMessage === s"Job $jobId cancelled ") assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(0)) assert(sparkListener.failedStages.size === 1) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index dc704e07a81de..4cdccdda6f72e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -216,7 +216,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc test("onTaskGettingResult() called when result fetched remotely") { val listener = new SaveTaskEvents sc.addSparkListener(listener) - + // Make a task whose result is larger than the akka frame size System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = @@ -236,7 +236,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc test("onTaskGettingResult() not called when result sent directly") { val listener = new SaveTaskEvents 
sc.addSparkListener(listener) - + // Make a task whose result is larger than the akka frame size val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } assert(result === 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 356e28dd19bc5..2fb750d9ee378 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -264,7 +264,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin test("Scheduler does not always schedule tasks on the same workers") { sc = new SparkContext("local", "TaskSchedulerImplSuite") - val taskScheduler = new TaskSchedulerImpl(sc) + val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. val dagScheduler = new DAGScheduler(sc, taskScheduler) { diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 0332a2a0539ee..b85c483ca2a08 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -98,8 +98,8 @@ class UISuite extends FunSuite { val server = new Server(startPort) Try { server.start() } match { - case Success(s) => - case Failure(e) => + case Success(s) => + case Failure(e) => // Either case server port is busy hence setup for test complete } val serverInfo1 = JettyUtils.startJettyServer( @@ -125,4 +125,18 @@ class UISuite extends FunSuite { case Failure(e) => } } + + test("verify appUIAddress contains the scheme") { + withSpark(new SparkContext("local", "test")) { sc => + val uiAddress = sc.ui.appUIAddress + assert(uiAddress.equals("http://" + sc.ui.appUIHostPort)) + } + } + + test("verify appUIAddress contains the port") { + withSpark(new SparkContext("local", "test")) { sc => + val splitUIAddress = sc.ui.appUIAddress.split(':') + assert(splitUIAddress(2).toInt == sc.ui.boundPort) + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 439e5644e20a3..d7e48e633e0ee 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -69,7 +69,7 @@ object TestObject { class TestClass extends Serializable { var x = 5 - + def getX = x def run(): Int = { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index f75297a02dc8b..16470bb7bf60d 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -523,8 +523,8 @@ class JsonProtocolSuite extends FunSuite { 700,"Fetch Wait Time":900,"Remote Bytes Read":1000},"Shuffle Write Metrics": {"Shuffle Bytes Written":1200,"Shuffle Write Time":1500},"Updated Blocks": [{"Block ID":{"Type":"RDDBlockId","RDD ID":0,"Split Index":0},"Status": - {"Storage Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false, - "Replication":2},"Memory Size":0,"Disk Size":0,"Tachyon Size":0}}]}} + {"Storage Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false, + 
"Deserialized":false,"Replication":2},"Memory Size":0,"Disk Size":0,"Tachyon Size":0}}]}} """ private val jobStartJsonString = diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index e1446cbc90bdb..32d74d0500b72 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -32,7 +32,7 @@ class NextIteratorSuite extends FunSuite with ShouldMatchers { i.hasNext should be === false intercept[NoSuchElementException] { i.next() } } - + test("two iterations") { val i = new StubIterator(Buffer(1, 2)) i.hasNext should be === true @@ -70,7 +70,7 @@ class NextIteratorSuite extends FunSuite with ShouldMatchers { class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { var closeCalled = 0 - + override def getNext() = { if (ints.size == 0) { finished = true diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 757476efdb789..39199a1a17ccd 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -29,12 +29,12 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt } - + /* - * This test is based on a chi-squared test for randomness. The values are hard-coded + * This test is based on a chi-squared test for randomness. The values are hard-coded * so as not to create Spark's dependency on apache.commons.math3 just to call one * method for calculating the exact p-value for a given number of random numbers - * and bins. In case one would want to move to a full-fledged test based on + * and bins. In case one would want to move to a full-fledged test based on * apache.commons.math3, the relevant class is here: * org.apache.commons.math3.stat.inference.ChiSquareTest */ @@ -49,19 +49,19 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { // populate bins based on modulus of the random number times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1} - /* since the seed is deterministic, until the algorithm is changed, we know the result will be - * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, - * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) - * significance level. However, should the RNG implementation change, the test should still - * pass at the same significance level. The chi-squared test done in R gave the following + /* since the seed is deterministic, until the algorithm is changed, we know the result will be + * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) + * significance level. However, should the RNG implementation change, the test should still + * pass at the same significance level. 
The chi-squared test done in R gave the following * results: * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, * 10000790, 10002286, 9998699)) * Chi-squared test for given probabilities - * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, + * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, * 10002286, 9998699) * X-squared = 11.975, df = 9, p-value = 0.2147 - * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million + * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million * random numbers * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared * is greater than or equal to that number. diff --git a/docs/_config.yml b/docs/_config.yml index bd5ed6c9220d2..d585b8c5ea763 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -7,6 +7,6 @@ SPARK_VERSION: 1.0.0-SNAPSHOT SPARK_VERSION_SHORT: 1.0.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" -MESOS_VERSION: 0.17.0 +MESOS_VERSION: 0.13.0 SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/configuration.md b/docs/configuration.md index 9c602402f0635..f3bfd036f4164 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -190,6 +190,13 @@ Apart from these, the following properties are also available, and may be useful user that started the Spark job has view access. + + spark.ui.killEnabled + true + + Allows stages and corresponding jobs to be killed from the web ui. + + spark.shuffle.compress true diff --git a/docs/index.md b/docs/index.md index 7a13fa9a9a2b6..89ec5b05488a9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -67,8 +67,6 @@ In addition, if you wish to run Spark on [YARN](running-on-yarn.html), set Note that on Windows, you need to set the environment variables on separate lines, e.g., `set SPARK_HADOOP_VERSION=1.2.1`. -For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to build Spark and publish it locally. See [Launching Spark on YARN](running-on-yarn.html). This is needed because Hadoop 2.2 has non backwards compatible API changes. - # Where to Go from Here **Programming guides:** diff --git a/docs/tuning.md b/docs/tuning.md index 093df3187a789..cc069f0e84b9c 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -90,9 +90,10 @@ than the "raw" data inside their fields. This is due to several reasons: * Each distinct Java object has an "object header", which is about 16 bytes and contains information such as a pointer to its class. For an object with very little data in it (say one `Int` field), this can be bigger than the data. -* Java Strings have about 40 bytes of overhead over the raw string data (since they store it in an +* Java `String`s have about 40 bytes of overhead over the raw string data (since they store it in an array of `Char`s and keep extra data such as the length), and store each character - as *two* bytes due to Unicode. Thus a 10-character string can easily consume 60 bytes. + as *two* bytes due to `String`'s internal usage of UTF-16 encoding. Thus a 10-character string can + easily consume 60 bytes. * Common collection classes, such as `HashMap` and `LinkedList`, use linked data structures, where there is a "wrapper" object for each entry (e.g. `Map.Entry`). This object not only has a header, but also pointers (typically 8 bytes each) to the next object in the list. 
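The sizing rule above can be sanity-checked with simple arithmetic. The helper below is only an illustrative sketch of that rule (approxStringBytes is a made-up name, not part of this patch, and real JVM layouts vary with alignment and pointer compression):

    // Rough estimate of a Java String's footprint per the guideline in docs/tuning.md:
    // ~40 bytes of fixed overhead plus two bytes per character.
    def approxStringBytes(s: String): Int = 40 + 2 * s.length
    approxStringBytes("0123456789")  // about 60 bytes for a 10-character string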
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index d8840c94ac17c..31209a662bbe1 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -70,7 +70,7 @@ def parse_args(): "slaves across multiple (an additional $0.01/Gb for bandwidth" + "between zones applies)") parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") - parser.add_option("-v", "--spark-version", default="0.9.0", + parser.add_option("-v", "--spark-version", default="0.9.1", help="Version of Spark to use: 'X.Y.Z' or a specific git hash") parser.add_option("--spark-git-repo", default="https://github.com/apache/spark", @@ -157,7 +157,7 @@ def is_active(instance): # Return correct versions of Spark and Shark, given the supplied Spark version def get_spark_shark_version(opts): - spark_shark_map = {"0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0"} + spark_shark_map = {"0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1"} version = opts.spark_version.replace("v", "") if version not in spark_shark_map: print >> stderr, "Don't know about Spark version: %s" % version diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 41e813d48c7b8..1204cfba39f77 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -48,41 +48,41 @@ import org.apache.spark.streaming.dstream._ * @param storageLevel RDD storage level. */ -private[streaming] +private[streaming] class MQTTInputDStream[T: ClassTag]( @transient ssc_ : StreamingContext, brokerUrl: String, topic: String, storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_) with Logging { - + def getReceiver(): NetworkReceiver[T] = { new MQTTReceiver(brokerUrl, topic, storageLevel).asInstanceOf[NetworkReceiver[T]] } } -private[streaming] +private[streaming] class MQTTReceiver(brokerUrl: String, topic: String, storageLevel: StorageLevel ) extends NetworkReceiver[Any] { lazy protected val blockGenerator = new BlockGenerator(storageLevel) - + def onStop() { blockGenerator.stop() } - + def onStart() { blockGenerator.start() - // Set up persistence for messages + // Set up persistence for messages var peristance: MqttClientPersistence = new MemoryPersistence() // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance var client: MqttClient = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) - // Connect to MqttBroker + // Connect to MqttBroker client.connect() // Subscribe to Mqtt topic @@ -91,7 +91,7 @@ class MQTTReceiver(brokerUrl: String, // Callback automatically triggers as and when new message arrives on specified topic var callback: MqttCallback = new MqttCallback() { - // Handles Mqtt message + // Handles Mqtt message override def messageArrived(arg0: String, arg1: MqttMessage) { blockGenerator += new String(arg1.getPayload()) } diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 3316b6dc39d6b..843a4a7a9ad72 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -31,7 +31,7 @@ import 
org.apache.spark.storage.StorageLevel * @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials. * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is * such that this may return a sampled subset of all tweets during each interval. -* +* * If no Authorization object is provided, initializes OAuth authorization using the system * properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret. */ @@ -42,13 +42,13 @@ class TwitterInputDStream( filters: Seq[String], storageLevel: StorageLevel ) extends NetworkInputDStream[Status](ssc_) { - + private def createOAuthAuthorization(): Authorization = { new OAuthAuthorization(new ConfigurationBuilder().build()) } private val authorization = twitterAuth.getOrElse(createOAuthAuthorization()) - + override def getReceiver(): NetworkReceiver[Status] = { new TwitterReceiver(authorization, filters, storageLevel) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 377d9d6bd5e72..5635287694ee2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -172,7 +172,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali "EdgeDirection.Either instead.") } } - + /** * Join the vertices with an RDD and then apply a function from the * the vertex and RDD entry to a new vertex value. The input table diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 6386306c048fc..a467ca1ae715a 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -55,7 +55,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { } } } - + test ("filter") { withSpark { sc => val n = 5 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala new file mode 100644 index 0000000000000..7858ec602483f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.rdd.RDDFunctions._ + +/** + * Computes the area under the curve (AUC) using the trapezoidal rule. + */ +private[evaluation] object AreaUnderCurve { + + /** + * Uses the trapezoidal rule to compute the area under the line connecting the two input points. 
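+   * For example (with illustrative values, not taken from this patch), the points (1.0, 2.0) and
+   * (3.0, 4.0) give (3.0 - 1.0) * (4.0 + 2.0) / 2.0 = 6.0.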
+ * @param points two 2D points stored in Seq + */ + private def trapezoid(points: Seq[(Double, Double)]): Double = { + require(points.length == 2) + val x = points.head + val y = points.last + (y._1 - x._1) * (y._2 + x._2) / 2.0 + } + + /** + * Returns the area under the given curve. + * + * @param curve a RDD of ordered 2D points stored in pairs representing a curve + */ + def of(curve: RDD[(Double, Double)]): Double = { + curve.sliding(2).aggregate(0.0)( + seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + combOp = _ + _ + ) + } + + /** + * Returns the area under the given curve. + * + * @param curve an iterator over ordered 2D points stored in pairs representing a curve + */ + def of(curve: Iterable[(Double, Double)]): Double = { + curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)( + seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + combop = _ + _ + ) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala new file mode 100644 index 0000000000000..562663ad36b40 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +/** + * Trait for a binary classification evaluation metric computer. + */ +private[evaluation] trait BinaryClassificationMetricComputer extends Serializable { + def apply(c: BinaryConfusionMatrix): Double +} + +/** Precision. */ +private[evaluation] object Precision extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives) +} + +/** False positive rate. */ +private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numFalsePositives.toDouble / c.numNegatives +} + +/** Recall. */ +private[evaluation] object Recall extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numTruePositives.toDouble / c.numPositives +} + +/** + * F-Measure. 
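+ * With beta = 1 this reduces to the usual F1 score, 2 * precision * recall / (precision + recall);
+ * for illustration, precision = 0.5 and recall = 1.0 give 2 * 0.5 * 1.0 / 1.5 = 2.0 / 3.0.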
+ * @param beta the beta constant in F-Measure + * @see http://en.wikipedia.org/wiki/F1_score + */ +private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetricComputer { + private val beta2 = beta * beta + override def apply(c: BinaryConfusionMatrix): Double = { + val precision = Precision(c) + val recall = Recall(c) + (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala new file mode 100644 index 0000000000000..ed7b0fc943367 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +import org.apache.spark.rdd.{UnionRDD, RDD} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.evaluation.AreaUnderCurve +import org.apache.spark.Logging + +/** + * Implementation of [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]]. + * + * @param count label counter for labels with scores greater than or equal to the current score + * @param totalCount label counter for all labels + */ +private case class BinaryConfusionMatrixImpl( + count: LabelCounter, + totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable { + + /** number of true positives */ + override def numTruePositives: Long = count.numPositives + + /** number of false positives */ + override def numFalsePositives: Long = count.numNegatives + + /** number of false negatives */ + override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives + + /** number of true negatives */ + override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives + + /** number of positives */ + override def numPositives: Long = totalCount.numPositives + + /** number of negatives */ + override def numNegatives: Long = totalCount.numNegatives +} + +/** + * Evaluator for binary classification. + * + * @param scoreAndLabels an RDD of (score, label) pairs. + */ +class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) + extends Serializable with Logging { + + private lazy val ( + cumulativeCounts: RDD[(Double, LabelCounter)], + confusions: RDD[(Double, BinaryConfusionMatrix)]) = { + // Create a bin for each distinct score value, count positives and negatives within each bin, + // and then sort by score values in descending order. 
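+    // For an illustrative input (not from this patch) of (score, label) pairs
+    // (0.8, 1.0), (0.6, 1.0), (0.6, 0.0), (0.1, 0.0), the per-score counts are
+    // 0.8 -> {1+, 0-}, 0.6 -> {1+, 1-}, 0.1 -> {0+, 1-}; accumulating them in descending score
+    // order then gives, at every threshold, the number of positives and negatives scored at or
+    // above it, which is exactly what the confusion matrices below are built from.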
+ val counts = scoreAndLabels.combineByKey( + createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label, + mergeValue = (c: LabelCounter, label: Double) => c += label, + mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2 + ).sortByKey(ascending = false) + val agg = counts.values.mapPartitions({ iter => + val agg = new LabelCounter() + iter.foreach(agg += _) + Iterator(agg) + }, preservesPartitioning = true).collect() + val partitionwiseCumulativeCounts = + agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c) + val totalCount = partitionwiseCumulativeCounts.last + logInfo(s"Total counts: $totalCount") + val cumulativeCounts = counts.mapPartitionsWithIndex( + (index: Int, iter: Iterator[(Double, LabelCounter)]) => { + val cumCount = partitionwiseCumulativeCounts(index) + iter.map { case (score, c) => + cumCount += c + (score, cumCount.clone()) + } + }, preservesPartitioning = true) + cumulativeCounts.persist() + val confusions = cumulativeCounts.map { case (score, cumCount) => + (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix]) + } + (cumulativeCounts, confusions) + } + + /** Unpersist intermediate RDDs used in the computation. */ + def unpersist() { + cumulativeCounts.unpersist() + } + + /** Returns thresholds in descending order. */ + def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) + + /** + * Returns the receiver operating characteristic (ROC) curve, + * which is an RDD of (false positive rate, true positive rate) + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + */ + def roc(): RDD[(Double, Double)] = { + val rocCurve = createCurve(FalsePositiveRate, Recall) + val sc = confusions.context + val first = sc.makeRDD(Seq((0.0, 0.0)), 1) + val last = sc.makeRDD(Seq((1.0, 1.0)), 1) + new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last)) + } + + /** + * Computes the area under the receiver operating characteristic (ROC) curve. + */ + def areaUnderROC(): Double = AreaUnderCurve.of(roc()) + + /** + * Returns the precision-recall curve, which is an RDD of (recall, precision), + * NOT (precision, recall), with (0.0, 1.0) prepended to it. + * @see http://en.wikipedia.org/wiki/Precision_and_recall + */ + def pr(): RDD[(Double, Double)] = { + val prCurve = createCurve(Recall, Precision) + val sc = confusions.context + val first = sc.makeRDD(Seq((0.0, 1.0)), 1) + first.union(prCurve) + } + + /** + * Computes the area under the precision-recall curve. + */ + def areaUnderPR(): Double = AreaUnderCurve.of(pr()) + + /** + * Returns the (threshold, F-Measure) curve. + * @param beta the beta factor in F-Measure computation. + * @return an RDD of (threshold, F-Measure) pairs. + * @see http://en.wikipedia.org/wiki/F1_score + */ + def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) + + /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) + + /** Returns the (threshold, precision) curve. */ + def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) + + /** Returns the (threshold, recall) curve. */ + def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) + + /** Creates a curve of (threshold, metric). 
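For example, createCurve(Precision) yields the (threshold, precision) pairs returned by precisionByThreshold().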
*/ + private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = { + confusions.map { case (s, c) => + (s, y(c)) + } + } + + /** Creates a curve of (metricX, metricY). */ + private def createCurve( + x: BinaryClassificationMetricComputer, + y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = { + confusions.map { case (_, c) => + (x(c), y(c)) + } + } +} + +/** + * A counter for positives and negatives. + * + * @param numPositives number of positive labels + * @param numNegatives number of negative labels + */ +private class LabelCounter( + var numPositives: Long = 0L, + var numNegatives: Long = 0L) extends Serializable { + + /** Processes a label. */ + def +=(label: Double): LabelCounter = { + // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle + // -1.0 for negative as well. + if (label > 0.5) numPositives += 1L else numNegatives += 1L + this + } + + /** Merges another counter. */ + def +=(other: LabelCounter): LabelCounter = { + numPositives += other.numPositives + numNegatives += other.numNegatives + this + } + + override def clone: LabelCounter = { + new LabelCounter(numPositives, numNegatives) + } + + override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}" +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala new file mode 100644 index 0000000000000..75a75b216002a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +/** + * Trait for a binary confusion matrix. 
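+ * The counts form the usual 2x2 layout: true positives and false negatives come from actually
+ * positive examples, false positives and true negatives from actually negative ones, so the
+ * defaults below derive numPositives = numTruePositives + numFalseNegatives and
+ * numNegatives = numFalsePositives + numTrueNegatives.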
+ */ +private[evaluation] trait BinaryConfusionMatrix { + /** number of true positives */ + def numTruePositives: Long + + /** number of false positives */ + def numFalsePositives: Long + + /** number of false negatives */ + def numFalseNegatives: Long + + /** number of true negatives */ + def numTrueNegatives: Long + + /** number of positives */ + def numPositives: Long = numTruePositives + numFalseNegatives + + /** number of negatives */ + def numNegatives: Long = numFalsePositives + numTrueNegatives +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index f65f43dd3007b..0c0afcd9ec0d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import java.util -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd} +import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -27,6 +27,138 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary + +/** + * Column statistics aggregator implementing + * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]] + * together with add() and merge() function. + * A numerically stable algorithm is implemented to compute sample mean and variance: + *[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. + * Zero elements (including explicit zero values) are skipped when calling add() and merge(), + * to have time complexity O(nnz) instead of O(n) for each column. + */ +private class ColumnStatisticsAggregator(private val n: Int) + extends MultivariateStatisticalSummary with Serializable { + + private val currMean: BDV[Double] = BDV.zeros[Double](n) + private val currM2n: BDV[Double] = BDV.zeros[Double](n) + private var totalCnt = 0.0 + private val nnz: BDV[Double] = BDV.zeros[Double](n) + private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue) + private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue) + + override def mean: Vector = { + val realMean = BDV.zeros[Double](n) + var i = 0 + while (i < n) { + realMean(i) = currMean(i) * nnz(i) / totalCnt + i += 1 + } + Vectors.fromBreeze(realMean) + } + + override def variance: Vector = { + val realVariance = BDV.zeros[Double](n) + + val denominator = totalCnt - 1.0 + + // Sample variance is computed, if the denominator is less than 0, the variance is just 0. 
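+    // currMean and currM2n only accumulate over the nonzero entries, so the
+    // deltaMean * deltaMean * nnz * (totalCnt - nnz) / totalCnt term below folds the skipped zero
+    // entries back into the sum of squared deviations before dividing by (count - 1).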
+ if (denominator > 0.0) { + val deltaMean = currMean + var i = 0 + while (i < currM2n.size) { + realVariance(i) = + currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt + realVariance(i) /= denominator + i += 1 + } + } + + Vectors.fromBreeze(realVariance) + } + + override def count: Long = totalCnt.toLong + + override def numNonzeros: Vector = Vectors.fromBreeze(nnz) + + override def max: Vector = { + var i = 0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + i += 1 + } + Vectors.fromBreeze(currMax) + } + + override def min: Vector = { + var i = 0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + i += 1 + } + Vectors.fromBreeze(currMin) + } + + /** + * Aggregates a row. + */ + def add(currData: BV[Double]): this.type = { + currData.activeIterator.foreach { + case (_, 0.0) => // Skip explicit zero elements. + case (i, value) => + if (currMax(i) < value) { + currMax(i) = value + } + if (currMin(i) > value) { + currMin(i) = value + } + + val tmpPrevMean = currMean(i) + currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) + currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) + + nnz(i) += 1.0 + } + + totalCnt += 1.0 + this + } + + /** + * Merges another aggregator. + */ + def merge(other: ColumnStatisticsAggregator): this.type = { + require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.") + + totalCnt += other.totalCnt + val deltaMean = currMean - other.currMean + + var i = 0 + while (i < n) { + // merge mean together + if (other.currMean(i) != 0.0) { + currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / + (nnz(i) + other.nnz(i)) + } + // merge m2n together + if (nnz(i) + other.nnz(i) != 0.0) { + currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / + (nnz(i) + other.nnz(i)) + } + if (currMax(i) < other.currMax(i)) { + currMax(i) = other.currMax(i) + } + if (currMin(i) > other.currMin(i)) { + currMin(i) = other.currMin(i) + } + i += 1 + } + + nnz += other.nnz + this + } +} /** * :: Experimental :: @@ -182,13 +314,7 @@ class RowMatrix( combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2) ) - // Update _m if it is not set, or verify its value. - if (nRows <= 0L) { - nRows = m - } else { - require(nRows == m, - s"The number of rows $m is different from what specified or previously computed: ${nRows}.") - } + updateNumRows(m) mean :/= m.toDouble @@ -240,6 +366,19 @@ class RowMatrix( } } + /** + * Computes column-wise summary statistics. + */ + def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { + val zeroValue = new ColumnStatisticsAggregator(numCols().toInt) + val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + ) + updateNumRows(summary.count) + summary + } + /** * Multiply this matrix by a local matrix on the right. * @@ -276,6 +415,16 @@ class RowMatrix( } mat } + + /** Updates or verfires the number of rows. 
*/ + private def updateNumRows(m: Long) { + if (nRows <= 0) { + nRows = m + } else { + require(nRows == m, + s"The number of rows $m is different from what specified or previously computed: ${nRows}.") + } + } } object RowMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index e41d9bbe18c37..7f6d94571b5ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.Vector trait Optimizer extends Serializable { /** - * Solve the provided convex optimization problem. + * Solve the provided convex optimization problem. */ def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala new file mode 100644 index 0000000000000..873de871fd884 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD + +/** + * Machine learning specific RDD functions. + */ +private[mllib] +class RDDFunctions[T: ClassTag](self: RDD[T]) { + + /** + * Returns an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * window over them. The ordering is first based on the partition index and then the ordering of + * items within each partition. This is similar to sliding in Scala collections, except that it + * becomes an empty RDD if the window size is greater than the total number of items. It needs to + * trigger a Spark job if the parent RDD has more than one partition and the window size is + * greater than 1. + */ + def sliding(windowSize: Int): RDD[Seq[T]] = { + require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") + if (windowSize == 1) { + self.map(Seq(_)) + } else { + new SlidingRDD[T](self, windowSize) + } + } +} + +private[mllib] +object RDDFunctions { + + /** Implicit conversion from an RDD to RDDFunctions.
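Importing RDDFunctions._ brings sliding into scope on any RDD, e.g. the curve.sliding(2) call in AreaUnderCurve above.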
*/ + implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala new file mode 100644 index 0000000000000..dd80782c0f001 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark.{TaskContext, Partition} +import org.apache.spark.rdd.RDD + +private[mllib] +class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]) + extends Partition with Serializable { + override val index: Int = idx +} + +/** + * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * window over them. The ordering is first based on the partition index and then the ordering of + * items within each partition. This is similar to sliding in Scala collections, except that it + * becomes an empty RDD if the window size is greater than the total number of items. It needs to + * trigger a Spark job if the parent RDD has more than one partitions. To make this operation + * efficient, the number of items per partition should be larger than the window size and the + * window size should be small, e.g., 2. + * + * @param parent the parent RDD + * @param windowSize the window size, must be greater than 1 + * + * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]] + */ +private[mllib] +class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) + extends RDD[Seq[T]](parent) { + + require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") + + override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + val part = split.asInstanceOf[SlidingRDDPartition[T]] + (firstParent[T].iterator(part.prev, context) ++ part.tail) + .sliding(windowSize) + .withPartial(false) + } + + override def getPreferredLocations(split: Partition): Seq[String] = + firstParent[T].preferredLocations(split.asInstanceOf[SlidingRDDPartition[T]].prev) + + override def getPartitions: Array[Partition] = { + val parentPartitions = parent.partitions + val n = parentPartitions.size + if (n == 0) { + Array.empty + } else if (n == 1) { + Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty)) + } else { + val n1 = n - 1 + val w1 = windowSize - 1 + // Get the first w1 items of each partition, starting from the second partition. 
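+      // With windowSize = 2 (w1 = 1) this is just the first element of each later partition; the
+      // loop below keeps appending heads to a split's tail until it reaches a partition holding at
+      // least w1 items, so smaller (or empty) partitions are absorbed into the preceding split and
+      // produce no split of their own.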
+ val nextHeads = + parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true) + val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() + var i = 0 + var partitionIndex = 0 + while (i < n1) { + var j = i + val tail = mutable.ListBuffer[T]() + // Keep appending to the current tail until appended a head of size w1. + while (j < n1 && nextHeads(j).size < w1) { + tail ++= nextHeads(j) + j += 1 + } + if (j < n1) { + tail ++= nextHeads(j) + j += 1 + } + partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail) + partitionIndex += 1 + // Skip appended heads. + i = j + } + // If the head of last partition has size w1, we also need to add this partition. + if (nextHeads.last.size == w1) { + partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty) + } + partitions.toArray + } + } + + // TODO: Override methods such as aggregate, which only requires one Spark job. +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 3bd0017aa196a..d969e7aa60061 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Vectors, Vector} /** - * GeneralizedLinearModel (GLM) represents a model trained using + * GeneralizedLinearModel (GLM) represents a model trained using * GeneralizedLinearAlgorithm. GLMs consist of a weight vector and * an intercept. * @@ -38,7 +38,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double /** * Predict the result given a data point and the weights learned. - * + * * @param dataMatrix Row vector containing the features for this data point * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala new file mode 100644 index 0000000000000..f9eb343da2b82 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.mllib.linalg.Vector + +/** + * Trait for multivariate statistical summary of a data matrix. + */ +trait MultivariateStatisticalSummary { + + /** + * Sample mean vector. 
+ */ + def mean: Vector + + /** + * Sample variance vector. Should return a zero vector if the sample size is 1. + */ + def variance: Vector + + /** + * Sample size. + */ + def count: Long + + /** + * Number of nonzero elements (including explicitly presented zero values) in each column. + */ + def numNonzeros: Vector + + /** + * Maximum value of each column. + */ + def max: Vector + + /** + * Minimum value of each column. + */ + def min: Vector +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index ac2360c429e2b..901c3180eac4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.util -import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV, - squaredDistance => breezeSquaredDistance} +import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vectors /** * Helper methods to load, save and pre-process data used in ML Lib. @@ -158,58 +157,6 @@ object MLUtils { dataStr.saveAsTextFile(dir) } - /** - * Utility function to compute mean and standard deviation on a given dataset. - * - * @param data - input data set whose statistics are computed - * @param numFeatures - number of features - * @param numExamples - number of examples in input dataset - * - * @return (yMean, xColMean, xColSd) - Tuple consisting of - * yMean - mean of the labels - * xColMean - Row vector with mean for every column (or feature) of the input data - * xColSd - Row vector standard deviation for every column (or feature) of the input data. - */ - private[mllib] def computeStats( - data: RDD[LabeledPoint], - numFeatures: Int, - numExamples: Long): (Double, Vector, Vector) = { - val brzData = data.map { case LabeledPoint(label, features) => - (label, features.toBreeze) - } - val aggStats = brzData.aggregate( - (0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures)) - )( - seqOp = (c, v) => (c, v) match { - case ((n, sumLabel, sum, sumSq), (label, features)) => - features.activeIterator.foreach { case (i, x) => - sumSq(i) += x * x - } - (n + 1L, sumLabel + label, sum += features, sumSq) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) => - (n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2) - } - ) - val (nl, sumLabel, sum, sumSq) = aggStats - - require(nl > 0, "Input data is empty.") - require(nl == numExamples) - - val n = nl.toDouble - val yMean = sumLabel / n - val mean = sum / n - val std = new Array[Double](sum.length) - var i = 0 - while (i < numFeatures) { - std(i) = sumSq(i) / n - mean(i) * mean(i) - i += 1 - } - - (yMean, Vectors.fromBreeze(mean), Vectors.dense(std)) - } - /** * Returns the squared Euclidean distance between two vectors. 
The following formula will be used * if it does not introduce too much numerical error: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala new file mode 100644 index 0000000000000..1c9844f289fe0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext + +class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { + test("auc computation") { + val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) + val auc = 4.0 + assert(AreaUnderCurve.of(curve) === auc) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) == auc) + } + + test("auc of an empty curve") { + val curve = Seq.empty[(Double, Double)] + assert(AreaUnderCurve.of(curve) === 0.0) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) === 0.0) + } + + test("auc of a curve with a single point") { + val curve = Seq((1.0, 1.0)) + assert(AreaUnderCurve.of(curve) === 0.0) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) === 0.0) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala new file mode 100644 index 0000000000000..173fdaefab3da --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.evaluation.binary + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.evaluation.AreaUnderCurve + +class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { + test("binary evaluation metrics") { + val scoreAndLabels = sc.parallelize( + Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val threshold = Seq(0.8, 0.6, 0.4, 0.1) + val numTruePositives = Seq(1, 3, 3, 4) + val numFalsePositives = Seq(0, 1, 2, 3) + val numPositives = 4 + val numNegatives = 3 + val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) => + t.toDouble / (t + f) + } + val recall = numTruePositives.map(t => t.toDouble / numPositives) + val fpr = numFalsePositives.map(f => f.toDouble / numNegatives) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) + val pr = recall.zip(precision) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } + val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} + assert(metrics.thresholds().collect().toSeq === threshold) + assert(metrics.roc().collect().toSeq === rocCurve) + assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve)) + assert(metrics.pr().collect().toSeq === prCurve) + assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve)) + assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1)) + assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2)) + assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision)) + assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 71ee8e8a4f6fd..c9f9acf4c1335 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -170,4 +170,19 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { )) } } + + test("compute column summary statistics") { + for (mat <- Seq(denseMat, sparseMat)) { + val summary = mat.computeColumnSummaryStatistics() + // Run twice to make sure no internal states are changed. + for (k <- 0 to 1) { + assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch") + assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch") + assert(summary.count === m, "count mismatch.") + assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch") + assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch") + assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.") + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala new file mode 100644 index 0000000000000..3f3b10dfff35e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.rdd.RDDFunctions._ + +class RDDFunctionsSuite extends FunSuite with LocalSparkContext { + + test("sliding") { + val data = 0 until 6 + for (numPartitions <- 1 to 8) { + val rdd = sc.parallelize(data, numPartitions) + for (windowSize <- 1 to 6) { + val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList + val expected = data.sliding(windowSize).map(_.toList).toList + assert(sliding === expected) + } + assert(rdd.sliding(7).collect().isEmpty, + "Should return an empty RDD if the window size is greater than the number of items.") + } + } + + test("sliding with empty partitions") { + val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) + val rdd = sc.parallelize(data, data.length).flatMap(s => s) + assert(rdd.partitions.size === data.length) + val sliding = rdd.sliding(3) + val expected = data.flatMap(x => x).sliding(3).toList + assert(sliding.collect().toList === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index e451c350b8d88..812a8434784be 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -27,7 +27,6 @@ import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ class MLUtilsSuite extends FunSuite with LocalSparkContext { @@ -56,18 +55,6 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { } } - test("compute stats") { - val data = Seq.fill(3)(Seq( - LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)), - LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0)) - )).flatten - val rdd = sc.parallelize(data, 2) - val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6) - assert(meanLabel === 0.5) - assert(mean === Vectors.dense(2.0, 3.0, 4.0)) - assert(std === Vectors.dense(1.0, 1.0, 1.0)) - } - test("loadLibSVMData") { val lines = """ diff --git a/pom.xml b/pom.xml index 11511bcb9da52..c03bb35c99442 100644 --- a/pom.xml +++ b/pom.xml @@ -112,7 +112,7 @@ 2.10.4 2.10 - 0.17.0 + 0.13.0 org.spark-project.akka 2.2.3-shaded-protobuf 1.7.5 @@ -848,7 +848,7 @@ - + hadoop-provided @@ -893,6 +893,6 @@ - + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 694f90a83ab67..21163760e6277 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -333,7 +333,7 @@ object SparkBuild extends Build { "org.json4s" %% "json4s-jackson" % "3.2.6" excludeAll(excludeScalap), "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", - "org.apache.mesos" % "mesos" % "0.17.0", + "org.apache.mesos" % "mesos" % "0.13.0", 
"commons-net" % "commons-net" % "2.2", "net.java.dev.jets3t" % "jets3t" % "0.7.1" excludeAll(excludeCommonsLogging), "org.apache.derby" % "derby" % "10.4.2.0" % "test", diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 35e48276e3cb9..61613dbed8dce 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -29,6 +29,9 @@ # this is the equivalent of ADD_JARS add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None +if os.environ.get("SPARK_EXECUTOR_URI"): + SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) + sc = SparkContext(os.environ.get("MASTER", "local[*]"), "PySparkShell", pyFiles=add_files) print """Welcome to diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index a30dcfdcecf27..687e85ca94d3c 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -35,7 +35,7 @@ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. * Allows the user to specify if user class path should be first - */ + */ class ExecutorClassLoader(classUri: String, parent: ClassLoader, userClassPathFirst: Boolean) extends ClassLoader { val uri = new URI(classUri) @@ -94,7 +94,7 @@ class ExecutorClassLoader(classUri: String, parent: ClassLoader, case e: Exception => None } } - + def readAndTransformClass(name: String, in: InputStream): Array[Byte] = { if (name.startsWith("line") && name.endsWith("$iw$")) { // Class seems to be an interpreter "wrapper" object storing a val or var. 
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala index 8f61a5e835044..419796b68b113 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -187,7 +187,7 @@ trait SparkImports { if (currentImps contains imv) addWrapper() val objName = req.lineRep.readPath val valName = "$VAL" + newValId(); - + if(!code.toString.endsWith(".`" + imv + "`;\n")) { // Which means already imported code.append("val " + valName + " = " + objName + ".INSTANCE;\n") code.append("import " + valName + req.accessPath + ".`" + imv + "`;\n") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 17118499d0c87..1f3fab09e9566 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -28,7 +28,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { override def toString = s"CAST($child, $dataType)" type EvaluatedType = Any - + def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) { null } else { @@ -40,7 +40,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8")) case _ => nullOrCast[Any](_, _.toString) } - + // BinaryConverter def castToBinary: Any => Any = child.dataType match { case StringType => nullOrCast[String](_, _.getBytes("UTF-8")) @@ -58,7 +58,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case DoubleType => nullOrCast[Double](_, _ != 0) case FloatType => nullOrCast[Float](_, _ != 0) } - + // TimestampConverter def castToTimestamp: Any => Any = child.dataType match { case StringType => nullOrCast[String](_, s => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 8a1db8e796816..dd9332ada80dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -86,7 +86,7 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Evaluation helper function for 2 Numeric children expressions. Those expressions are supposed + * Evaluation helper function for 2 Numeric children expressions. Those expressions are supposed * to be in the same data type, and also the return type. * Either one of the expressions result is null, the evaluation result should be null. */ @@ -120,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Evaluation helper function for 2 Fractional children expressions. Those expressions are + * Evaluation helper function for 2 Fractional children expressions. Those expressions are * supposed to be in the same data type, and also the return type. * Either one of the expressions result is null, the evaluation result should be null. */ @@ -153,7 +153,7 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Evaluation helper function for 2 Integral children expressions. Those expressions are + * Evaluation helper function for 2 Integral children expressions. 
Those expressions are * supposed to be in the same data type, and also the return type. * Either one of the expressions result is null, the evaluation result should be null. */ @@ -186,12 +186,12 @@ abstract class Expression extends TreeNode[Expression] { } /** - * Evaluation helper function for 2 Comparable children expressions. Those expressions are + * Evaluation helper function for 2 Comparable children expressions. Those expressions are * supposed to be in the same data type, and the return type should be Integer: * Negative value: 1st argument less than 2nd argument * Zero: 1st argument equals 2nd argument * Positive value: 1st argument greater than 2nd argument - * + * * Either one of the expressions result is null, the evaluation result should be null. */ @inline @@ -213,7 +213,7 @@ abstract class Expression extends TreeNode[Expression] { null } else { e1.dataType match { - case i: NativeType => + case i: NativeType => f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean]( i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) case other => sys.error(s"Type $other does not support ordered operations") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index a27c71db1b999..ddc16ce87b895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -28,19 +28,19 @@ trait StringRegexExpression { self: BinaryExpression => type EvaluatedType = Any - + def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - + def nullable: Boolean = true def dataType: DataType = BooleanType - - // try cache the pattern for Literal + + // try cache the pattern for Literal private lazy val cache: Pattern = right match { case x @ Literal(value: String, StringType) => compile(value) case _ => null } - + protected def compile(str: String): Pattern = if(str == null) { null } else { @@ -49,7 +49,7 @@ trait StringRegexExpression { } protected def pattern(str: String) = if(cache == null) compile(str) else cache - + override def eval(input: Row): Any = { val l = left.eval(input) if (l == null) { @@ -73,11 +73,11 @@ trait StringRegexExpression { /** * Simple RegEx pattern matching function */ -case class Like(left: Expression, right: Expression) +case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - + def symbol = "LIKE" - + // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String) = { @@ -98,19 +98,19 @@ case class Like(left: Expression, right: Expression) sb.append(Pattern.quote(Character.toString(n))); } } - + i += 1 } - + sb.toString() } - + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() } -case class RLike(left: Expression, right: Expression) +case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - + def symbol = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 37b23ba58289c..c0a09a16ac98d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -33,7 +33,56 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
     Batch("Filter Pushdown", Once,
       CombineFilters,
       PushPredicateThroughProject,
-      PushPredicateThroughInnerJoin) :: Nil
+      PushPredicateThroughInnerJoin,
+      ColumnPruning) :: Nil
+}
+
+/**
+ * Attempts to eliminate the reading of unneeded columns from the query plan using the following
+ * transformations:
+ *
+ *  - Inserting Projections beneath the following operators:
+ *   - Aggregate
+ *   - Project <- Join
+ *  - Collapse adjacent projections, performing alias substitution.
+ */
+object ColumnPruning extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
+      // Project away references that are not needed to calculate the required aggregates.
+      a.copy(child = Project(a.references.toSeq, child))
+
+    case Project(projectList, Join(left, right, joinType, condition)) =>
+      // Collect the list of all references required either above or to evaluate the condition.
+      val allReferences: Set[Attribute] =
+        projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty)
+      /** Applies a projection when the child is producing unnecessary attributes */
+      def prunedChild(c: LogicalPlan) =
+        if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
+          Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+        } else {
+          c
+        }
+
+      Project(projectList, Join(prunedChild(left), prunedChild(right), joinType, condition))
+
+    case Project(projectList1, Project(projectList2, child)) =>
+      // Create a map of Aliases to their values from the child projection.
+      // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
+      val aliasMap = projectList2.collect {
+        case a @ Alias(e, _) => (a.toAttribute: Expression, a)
+      }.toMap
+
+      // Substitute any attributes that are produced by the child projection, so that we safely
+      // eliminate it.
+      // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
+      // TODO: Fix TransformBase to avoid the cast below.
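+      // Note: the map keys are attribute references (e.g. c), and the values are the child's
+      // full Alias expressions (e.g. Alias(a + b, c)), so aliases used inside larger
+      // expressions such as 'c + 1' above are expanded in place by the transform below.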
+ val substitutedProjection = projectList1.map(_.transform { + case a if aliasMap.contains(a) => aliasMap(a) + }).asInstanceOf[Seq[NamedExpression]] + + Project(substitutedProjection, child) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index cfc0b0c3a8d98..397473e178867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -127,7 +127,7 @@ case class Aggregate( extends UnaryNode { def output = aggregateExpressions.map(_.toAttribute) - def references = child.references + def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index cdeb01a9656f4..da34bd3a21503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -55,9 +55,9 @@ case object BooleanType extends NativeType { case object TimestampType extends NativeType { type JvmType = Timestamp - + @transient lazy val tag = typeTag[JvmType] - + val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 888a19d79f7e4..2cd0d2b0e1385 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -144,7 +144,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("abc" like "b%", false) checkEvaluation("abc" like "bc%", false) } - + test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) @@ -164,7 +164,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("RLIKE literal Regular Expression") { checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) - + checkEvaluation("fofo" rlike "^fo", true) checkEvaluation("fo\no" rlike "^fo\no$", true) checkEvaluation("Bn" rlike "^Ba*n", true) @@ -196,9 +196,9 @@ class ExpressionEvaluationSuite extends FunSuite { evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) } } - + test("data type casting") { - + val sts = "1970-01-01 00:00:01.0" val ts = Timestamp.valueOf(sts) @@ -236,7 +236,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("23" cast ShortType, 23) checkEvaluation("2012-12-11" cast DoubleType, null) checkEvaluation(Literal(123) cast IntegerType, 123) - + intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 65eae3357a21e..1cbf973c34917 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -56,4 +56,4 @@ class ScalaReflectionRelationSuite extends FunSuite { val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 93023e8dced57..ac56ff709c1c4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -59,7 +59,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) } } -private[streaming] +private[streaming] object Checkpoint extends Logging { val PREFIX = "checkpoint-" val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r @@ -79,7 +79,7 @@ object Checkpoint extends Logging { def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } - (time1 < time2) || (time1 == time2 && bk1) + (time1 < time2) || (time1 == time2 && bk1) } val path = new Path(checkpointDir) @@ -95,7 +95,7 @@ object Checkpoint extends Logging { } } else { logInfo("Checkpoint directory " + path + " does not exist") - Seq.empty + Seq.empty } } } @@ -160,7 +160,7 @@ class CheckpointWriter( }) } - // All done, print success + // All done, print success val finishTime = System.currentTimeMillis() logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") @@ -227,14 +227,14 @@ object CheckpointReader extends Logging { { val checkpointPath = new Path(checkpointDir) def fs = checkpointPath.getFileSystem(hadoopConf) - - // Try to find the checkpoint files + + // Try to find the checkpoint files val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse if (checkpointFiles.isEmpty) { return None } - // Try to read the checkpoint files in the order + // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) val compressionCodec = CompressionCodec.createCodec(conf) checkpointFiles.foreach(file => { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala index 16479a01272aa..ad4f3fdd14ad6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala @@ -20,11 +20,11 @@ package org.apache.spark.streaming private[streaming] class Interval(val beginTime: Time, val endTime: Time) { def this(beginMs: Long, endMs: Long) = this(new Time(beginMs), new Time(endMs)) - + def duration(): Duration = endTime - beginTime def + (time: Duration): Interval = { - new Interval(beginTime + time, endTime + time) + new Interval(beginTime + time, endTime + time) } def - (time: Duration): Interval = { @@ -40,9 +40,9 @@ class Interval(val beginTime: Time, val endTime: Time) { } def <= (that: Interval) = (this < that || this == that) - + def > (that: Interval) = !(this <= that) - + def >= (that: Interval) = !(this < that) override def toString = "[" + beginTime + ", " + 
endTime + "]" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala index 2678334f53844..6a6b00a778b48 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala @@ -32,7 +32,7 @@ case class Time(private val millis: Long) { def <= (that: Time): Boolean = (this.millis <= that.millis) def > (that: Time): Boolean = (this.millis > that.millis) - + def >= (that: Time): Boolean = (this.millis >= that.millis) def + (that: Duration): Time = new Time(millis + that.milliseconds) @@ -43,7 +43,7 @@ case class Time(private val millis: Long) { def floor(that: Duration): Time = { val t = that.milliseconds - val m = math.floor(this.millis / t).toLong + val m = math.floor(this.millis / t).toLong new Time(m * t) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 903e3f3c9b713..f33c0ceafdf42 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -51,7 +51,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) .map(x => (x._1, x._2.getCheckpointFile.get)) logDebug("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n")) - // Add the checkpoint files to the data to be serialized + // Add the checkpoint files to the data to be serialized if (!checkpointFiles.isEmpty) { currentCheckpointFiles.clear() currentCheckpointFiles ++= checkpointFiles diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 8a6051622e2d5..e878285f6a854 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -232,7 +232,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas } logDebug("Accepted " + path) } catch { - case fnfe: java.io.FileNotFoundException => + case fnfe: java.io.FileNotFoundException => logWarning("Error finding new files", fnfe) reset() return false diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index 97325f8ea3117..6376cff78b78a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -31,11 +31,11 @@ class QueueInputDStream[T: ClassTag]( oneAtATime: Boolean, defaultRDD: RDD[T] ) extends InputDStream[T](ssc) { - + override def start() { } - + override def stop() { } - + override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() if (oneAtATime && queue.size > 0) { @@ -55,5 +55,5 @@ class QueueInputDStream[T: ClassTag]( None } } - + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 24289b714f99e..775b6bfd065c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -32,7 +32,7 @@ class WindowedDStream[T: ClassTag]( extends DStream[T](parent.ssc) { if (!_windowDuration.isMultipleOf(parent.slideDuration)) { - throw new Exception("The window duration of windowed DStream (" + _slideDuration + ") " + + throw new Exception("The window duration of windowed DStream (" + _windowDuration + ") " + "must be a multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala index 44eb2750c6c7a..f5984d03c5342 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -47,7 +47,7 @@ object ReceiverSupervisorStrategy { * the API for pushing received data into Spark Streaming for being processed. * * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * + * * @example {{{ * class MyActor extends Actor with Receiver{ * def receive { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index c5ef2cc8c390d..39145a3ab081a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -19,34 +19,34 @@ package org.apache.spark.streaming.util private[streaming] trait Clock { - def currentTime(): Long + def currentTime(): Long def waitTillTime(targetTime: Long): Long } private[streaming] class SystemClock() extends Clock { - + val minPollTime = 25L - + def currentTime(): Long = { System.currentTimeMillis() - } - + } + def waitTillTime(targetTime: Long): Long = { var currentTime = 0L currentTime = System.currentTimeMillis() - + var waitTime = targetTime - currentTime if (waitTime <= 0) { return currentTime } - + val pollTime = { if (waitTime / 10.0 > minPollTime) { (waitTime / 10.0).toLong } else { - minPollTime - } + minPollTime + } } while (true) { @@ -55,7 +55,7 @@ class SystemClock() extends Clock { if (waitTime <= 0) { return currentTime } - val sleepTime = + val sleepTime = if (waitTime < pollTime) { waitTime } else { @@ -69,7 +69,7 @@ class SystemClock() extends Clock { private[streaming] class ManualClock() extends Clock { - + var time = 0L def currentTime() = time @@ -85,13 +85,13 @@ class ManualClock() extends Clock { this.synchronized { time += timeToAdd this.notifyAll() - } + } } def waitTillTime(targetTime: Long): Long = { this.synchronized { while (time < targetTime) { this.wait(100) - } + } } currentTime() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 07021ebb5802a..bd1df55cf70f5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -25,8 +25,8 @@ import scala.collection.JavaConversions.mapAsScalaMap private[streaming] object RawTextHelper { - /** - * Splits lines and counts the words in them using specialized object-to-long hashmap + /** + * Splits lines and counts the words in them using specialized object-to-long hashmap * (to avoid boxing-unboxing overhead of Long in 
java/scala HashMap) */ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { @@ -55,13 +55,13 @@ object RawTextHelper { map.toIterator.map{case (k, v) => (k, v)} } - /** + /** * Gets the top k words in terms of word counts. Assumes that each word exists only once * in the `data` iterator (that is, the counts have been reduced). */ def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { val taken = new Array[(String, Long)](k) - + var i = 0 var len = 0 var done = false @@ -93,7 +93,7 @@ object RawTextHelper { } taken.toIterator } - + /** * Warms up the SparkContext in master and slave by running tasks to force JIT kick in * before real workload starts. @@ -106,11 +106,11 @@ object RawTextHelper { .count() } } - - def add(v1: Long, v2: Long) = (v1 + v2) - def subtract(v1: Long, v2: Long) = (v1 - v2) + def add(v1: Long, v2: Long) = (v1 + v2) + + def subtract(v1: Long, v2: Long) = (v1 - v2) - def max(v1: Long, v2: Long) = math.max(v1, v2) + def max(v1: Long, v2: Long) = math.max(v1, v2) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index f71938ac55ccb..e016377c94c0d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -22,10 +22,10 @@ import org.apache.spark.Logging private[streaming] class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String) extends Logging { - + private val thread = new Thread("RecurringTimer - " + name) { setDaemon(true) - override def run() { loop } + override def run() { loop } } @volatile private var prevTime = -1L @@ -104,11 +104,11 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: private[streaming] object RecurringTimer { - + def main(args: Array[String]) { var lastRecurTime = 0L val period = 1000 - + def onRecur(time: Long) { val currentTime = System.currentTimeMillis() println("" + currentTime + ": " + (currentTime - lastRecurTime)) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 13fa64894b773..a0b1bbc34fa7c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1673,7 +1673,7 @@ public void testSocketTextStream() { @Test public void testSocketString() { - + class Converter implements Function<InputStream, Iterable<String>> { public Iterable<String> call(InputStream in) throws IOException { BufferedReader reader = new BufferedReader(new InputStreamReader(in)); diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 910484ed5432a..67ec95c8fc04f 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -234,7 +234,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, assert(sparkContext != null || count >= numTries) if (null != sparkContext) { - uiAddress = sparkContext.ui.appUIAddress + uiAddress = sparkContext.ui.appUIHostPort this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, resourceManager, diff --git
a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c8a4d2e647cbd..61af0f9ac5ca0 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -220,7 +220,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, assert(sparkContext != null || numTries >= maxNumTries) if (sparkContext != null) { - uiAddress = sparkContext.ui.appUIAddress + uiAddress = sparkContext.ui.appUIHostPort this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, amClient,