
Commit

Merge remote-tracking branch 'upstream/master' into cv_min
viirya committed Jun 20, 2015
2 parents 16e3b2c + 1b6fe9b commit d632135
Showing 121 changed files with 2,093 additions and 230 deletions.
5 changes: 5 additions & 0 deletions R/pkg/NAMESPACE
@@ -10,6 +10,11 @@ export("sparkR.init")
export("sparkR.stop")
export("print.jobj")

# Job group lifecycle management methods
export("setJobGroup",
"clearJobGroup",
"cancelJobGroup")

exportClasses("DataFrame")

exportMethods("arrange",
44 changes: 44 additions & 0 deletions R/pkg/R/sparkR.R
@@ -278,3 +278,47 @@ sparkRHive.init <- function(jsc = NULL) {
assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
hiveCtx
}

#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a
#' different value or cleared.
#'
#' @param sc existing spark context
#' @param groupId the ID to be assigned to job groups
#' @param description description for the job group ID
#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' setJobGroup(sc, "myJobGroup", "My job group description", TRUE)
#'}

setJobGroup <- function(sc, groupId, description, interruptOnCancel) {
callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel)
}

#' Clear current job group ID and its description
#'
#' @param sc existing spark context
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' clearJobGroup(sc)
#'}

clearJobGroup <- function(sc) {
callJMethod(sc, "clearJobGroup")
}

#' Cancel active jobs for the specified group
#'
#' @param sc existing spark context
#' @param groupId the ID of the job group to be cancelled
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' cancelJobGroup(sc, "myJobGroup")
#'}

cancelJobGroup <- function(sc, groupId) {
callJMethod(sc, "cancelJobGroup", groupId)
}
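
For orientation, these wrappers delegate through callJMethod to the JVM-side job-group methods of the same names. Below is a minimal Scala sketch of the equivalent lifecycle against SparkContext directly; the app name, the local master, and the toy count job are illustrative assumptions, not part of this change.

import org.apache.spark.{SparkConf, SparkContext}

object JobGroupSketch {
  def main(args: Array[String]): Unit = {
    // Local context purely for illustration.
    val sc = new SparkContext(new SparkConf().setAppName("job-group-sketch").setMaster("local[2]"))
    sc.setJobGroup("myJobGroup", "My job group description", interruptOnCancel = true)
    sc.parallelize(1 to 100).count()   // this job runs under "myJobGroup"
    sc.cancelJobGroup("myJobGroup")    // cancel any jobs still running in the group
    sc.clearJobGroup()                 // later jobs no longer belong to the group
    sc.stop()
  }
}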
7 changes: 7 additions & 0 deletions R/pkg/inst/tests/test_context.R
@@ -48,3 +48,10 @@ test_that("rdd GC across sparkR.stop", {
count(rdd3)
count(rdd4)
})

test_that("job group functions can be called", {
sc <- sparkR.init()
setJobGroup(sc, "groupId", "job description", TRUE)
cancelJobGroup(sc, "groupId")
clearJobGroup(sc)
})
core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -94,8 +94,8 @@ class TaskMetrics extends Serializable {
*/
private var _diskBytesSpilled: Long = _
def diskBytesSpilled: Long = _diskBytesSpilled
def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value
def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value
private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value
private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value

/**
* If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read
core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -17,7 +17,7 @@

package org.apache.spark.serializer

import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField}
import java.io._
import java.lang.reflect.{Field, Method}
import java.security.AccessController

@@ -62,7 +62,7 @@ private[spark] object SerializationDebugger extends Logging {
*
* It does not yet handle writeObject override, but that shouldn't be too hard to do either.
*/
def find(obj: Any): List[String] = {
private[serializer] def find(obj: Any): List[String] = {
new SerializationDebugger().visit(obj, List.empty)
}

@@ -125,6 +125,12 @@ private[spark] object SerializationDebugger extends Logging {
return List.empty
}

/**
* Visit an externalizable object.
* Since writeExternal() can choose to add arbitrary objects at the time of serialization,
* the only way to capture all the objects it will serialize is by using a
* dummy ObjectOutput that collects all the relevant objects for further testing.
*/
private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] =
{
val fieldList = new ListObjectOutput
@@ -145,17 +151,50 @@
// An object contains multiple slots in serialization.
// Get the slots and visit fields in all of them.
val (finalObj, desc) = findObjectAndDescriptor(o)

// If the object has been replaced using writeReplace(),
// then call visit() on it again to test its type again.
if (!finalObj.eq(o)) {
return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack)
}

// Every class is associated with one or more "slots"; each slot corresponds to a class in
// this object's inheritance hierarchy. These slots are used by the ObjectOutputStream
// serialization code to recursively serialize the fields of an object and
// its parent classes. For example, consider the following classes.
//
// class ParentClass(parentField: Int)
// class ChildClass(childField: Int) extends ParentClass(1)
//
// Then serializing an object Obj of type ChildClass requires first serializing the fields
// of ParentClass (that is, parentField), and then serializing the fields of ChildClass
// (that is, childField). Correspondingly, there will be two slots related to this object:
//
// 1. ParentClass slot, which will be used to serialize parentField of Obj
// 2. ChildClass slot, which will be used to serialize childField of Obj
//
// The following code uses the description of each slot to find the fields in the
// corresponding object to visit.
//
val slotDescs = desc.getSlotDescs
var i = 0
while (i < slotDescs.length) {
val slotDesc = slotDescs(i)
if (slotDesc.hasWriteObjectMethod) {
// TODO: Handle classes that specify writeObject method.
// If the class type corresponding to the current slot has writeObject() defined,
// then it is not obvious which fields of the class will be serialized, as writeObject()
// can choose arbitrary fields for serialization. This case is handled separately.
val elem = s"writeObject data (class: ${slotDesc.getName})"
val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack)
if (childStack.nonEmpty) {
return childStack
}
} else {
// Visit all the field objects of the class corresponding to the current slot.
val fields: Array[ObjectStreamField] = slotDesc.getFields
val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
val numPrims = fields.length - objFieldValues.length
desc.getObjFieldValues(finalObj, objFieldValues)
slotDesc.getObjFieldValues(finalObj, objFieldValues)

var j = 0
while (j < objFieldValues.length) {
@@ -169,18 +208,54 @@
}
j += 1
}

}
i += 1
}
return List.empty
}

/**
* Visit a serializable object which has writeObject() defined.
* Since writeObject() can choose to add arbitrary objects at the time of serialization,
* the only way to capture all the objects it will serialize is by using a
* dummy ObjectOutputStream that collects all the relevant fields for further testing.
* This is similar to how externalizable objects are visited.
*/
private def visitSerializableWithWriteObjectMethod(
o: Object, stack: List[String]): List[String] = {
val innerObjectsCatcher = new ListObjectOutputStream
var notSerializableFound = false
try {
innerObjectsCatcher.writeObject(o)
} catch {
case io: IOException =>
notSerializableFound = true
}

// If something was not serializable, then visit the captured objects.
// Otherwise, all the captured objects are safely serializable, so no need to visit them.
// As an optimization, just add them to the visited list.
if (notSerializableFound) {
val innerObjects = innerObjectsCatcher.outputArray
var k = 0
while (k < innerObjects.length) {
val childStack = visit(innerObjects(k), stack)
if (childStack.nonEmpty) {
return childStack
}
k += 1
}
} else {
visited ++= innerObjectsCatcher.outputArray
}
return List.empty
}
}

/**
* Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles
* writeReplace in Serializable. It starts with the object itself, and keeps calling the
* writeReplace method until there is no more
* writeReplace method until there is no more.
*/
@tailrec
private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = {
@@ -220,6 +295,31 @@ private[spark] object SerializationDebugger extends Logging {
override def writeByte(i: Int): Unit = {}
}

/** An output stream that emulates /dev/null */
private class NullOutputStream extends OutputStream {
override def write(b: Int) { }
}

/**
* A dummy [[ObjectOutputStream]] that saves the list of objects written to it and returns
* them through `outputArray`. This works by using the [[ObjectOutputStream]]'s `replaceObject()`
* method, which gets called on every object only if replacing is enabled. So this subclass
* of [[ObjectOutputStream]] enables replacing, and uses replaceObject() to get the objects that
* are being serialized. The serialized bytes are ignored by sending them to a
* [[NullOutputStream]], which acts like a /dev/null.
*/
private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) {
private val output = new mutable.ArrayBuffer[Any]
this.enableReplaceObject(true)

def outputArray: Array[Any] = output.toArray

override def replaceObject(obj: Object): Object = {
output += obj
obj
}
}

/** An implicit class that allows us to call private methods of ObjectStreamClass. */
implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {
def getSlotDescs: Array[ObjectStreamClass] = {
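
To make the scenario described in the comments above concrete, here is a small self-contained Scala sketch of the kind of object graph the debugger diagnoses. The ParentClass/ChildClass names mirror the example in the comment; the Thread field is an illustrative assumption, and the sketch only shows what plain Java serialization reports, not the debugger's output.

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical classes mirroring the ParentClass/ChildClass slot example above.
class ParentClass(val parentField: Int) extends Serializable
class ChildClass(val childField: Any) extends ParentClass(1)

object SerializationDebuggerSketch {
  def main(args: Array[String]): Unit = {
    // java.lang.Thread is not serializable, so writing this object graph fails.
    val obj = new ChildClass(new Thread())
    try {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj)
    } catch {
      case e: NotSerializableException =>
        // Plain serialization only names the offending class; the debugger walks the
        // slots (ParentClass, then ChildClass) to report the field path leading to it.
        println(s"not serializable: ${e.getMessage}")
    }
  }
}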
core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi
* size, which is guaranteed to explore all spaces for each key (see
* http://en.wikipedia.org/wiki/Quadratic_probing).
*
* The map can support up to `536870912 (2 ^ 29)` elements.
* The map can support up to `375809638 (0.7 * 2 ^ 29)` elements.
*
* TODO: Cache the hash values of each key? java.util.HashMap does that.
*/
@@ -199,11 +199,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64)

/** Increase table size by 1, rehashing if necessary */
private def incrementSize() {
if (curSize == MAXIMUM_CAPACITY) {
throw new IllegalStateException(s"Can't put more that ${MAXIMUM_CAPACITY} elements")
}
curSize += 1
if (curSize > growThreshold && capacity < MAXIMUM_CAPACITY) {
if (curSize > growThreshold) {
growTable()
}
}
@@ -216,7 +213,8 @@
/** Double the table's size and re-hash everything */
protected def growTable() {
// capacity < MAXIMUM_CAPACITY (2 ^ 29) so capacity * 2 won't overflow
val newCapacity = (capacity * 2).min(MAXIMUM_CAPACITY)
val newCapacity = capacity * 2
require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements")
val newData = new Array[AnyRef](2 * newCapacity)
val newMask = newCapacity - 1
// Insert all our old values into the new array. Note that because our old keys are
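
As a quick sanity check of the new documented limit, assuming the map grows once curSize exceeds LOAD_FACTOR * capacity with LOAD_FACTOR = 0.7 (which is what the 0.7 * 2^29 figure in the updated comment suggests), the maximum element count works out as follows.

object AppendOnlyMapLimitCheck {
  def main(args: Array[String]): Unit = {
    val maximumCapacity = 1 << 29                       // 2^29 = 536870912 slots
    val loadFactor = 0.7                                // assumed growth threshold factor
    val maxElements = (loadFactor * maximumCapacity).toInt
    println(maxElements)                                // prints 375809638, matching the doc comment
  }
}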
17 changes: 12 additions & 5 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -325,7 +325,7 @@ class SparkSubmitSuite
runSparkSubmit(args)
}

ignore("includes jars passed in through --jars") {
test("includes jars passed in through --jars") {
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
@@ -340,7 +340,7 @@ class SparkSubmitSuite
}

// SPARK-7287
ignore("includes jars passed in through --packages") {
test("includes jars passed in through --packages") {
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
val dep = MavenCoordinate("my.great.dep", "mylib", "0.1")
@@ -499,9 +499,16 @@ class SparkSubmitSuite
Seq("./bin/spark-submit") ++ args,
new File(sparkHome),
Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome))
failAfter(60 seconds) { process.waitFor() }
// Ensure we still kill the process in case it timed out
process.destroy()

try {
val exitCode = failAfter(60 seconds) { process.waitFor() }
if (exitCode != 0) {
fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.")
}
} finally {
// Ensure we still kill the process in case it timed out
process.destroy()
}
}

private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = {
(diff for the remaining changed files not shown)
