diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 61d09d73e17cb..a5bb2f15bdcc2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -18,9 +18,12 @@ package org.apache.spark.scheduler import java.io.NotSerializableException +import java.nio.ByteBuffer +import java.util import java.util.Properties import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} import scala.concurrent.Await import scala.concurrent.duration._ @@ -39,7 +42,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ -import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} +import org.apache.spark.util.{CallSite, Clock, RDDTrace, SerializationHelper, SystemClock, Utils} import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** @@ -789,6 +792,44 @@ class DAGScheduler( } } + /** + * Helper function to check whether an RDD and its dependencies are serializable. + * + * This hook is exposed here primarily for testing purposes. + * + * Note: This function is defined here rather than in SerializationHelper because it needs the + * DAGScheduler's closureSerializer; passing that serializer into the RDDWalker graph + * traversal was undesirable. + * + * @param rdd - RDD to attempt to serialize + * @return Array[RDDTrace] - + * Return an array of traces, one per RDD, each holding an Either that indicates + * whether serialization succeeded. Each entry represents the RDD or a dependency of the RDD + * Success: ByteBuffer - The serialized RDD + * Failure: String - The reason for the failure. + * + */ + private[spark] def tryToSerializeRddDeps(rdd: RDD[_]): Array[RDDTrace] = { + SerializationHelper.tryToSerializeRddAndDeps(closureSerializer, rdd) + } + + + /** + * Returns nicely formatted text representing the trace of the failed serialization + * + * Note: This is defined here since it uses the closure serializer. Although the better place for + * the serializer would be in the SerializationHelper, the Helper is not guaranteed to run in a + * single thread, unlike the DAGScheduler. + * + * @param rdd - The top-level reference that we are attempting to serialize + * @return the failure trace if serialization failed, or a success message otherwise + */ + def traceBrokenRdd(rdd: RDD[_]): String = { + SerializationHelper.tryToSerializeRdd(closureSerializer, rdd) + .fold(l => l, r => "Successfully serialized " + rdd.toString) + } + /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") @@ -827,9 +868,11 @@ class DAGScheduler( // might modify state of objects referenced in their closures. This is necessary in Hadoop // where the JobConf/Configuration object is not thread-safe. var taskBinary: Broadcast[Array[Byte]] = null + try { - // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). + // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). // For ResultTask, serialize and broadcast (rdd, func).
+ val taskBinaryBytes: Array[Byte] = if (stage.isShuffleMap) { closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() @@ -840,10 +883,18 @@ class DAGScheduler( } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => + SerializationHelper + .tryToSerializeRdd(closureSerializer, stage.rdd) + .fold(l => logDebug(l), r => {}) + abortStage(stage, "Task not serializable: " + e.toString) runningStages -= stage return case NonFatal(e) => + SerializationHelper + .tryToSerializeRdd(closureSerializer, stage.rdd) + .fold(l => logDebug(l), r => {}) + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") runningStages -= stage return diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 4667850917151..a1b2beec45160 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -30,7 +30,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.TaskState.TaskState -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{Clock, SerializationHelper, SystemClock, Utils} /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -459,17 +459,29 @@ private[spark] class TaskSetManager( } // Serialize and return the task val startTime = clock.getTime() + val serializedTask: ByteBuffer = try { + // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here + // we assume the task can be serialized without exceptions. + Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) } catch { // If the task cannot be serialized, then there's no point to re-attempt the task, - // as it will always fail. So just abort the whole task-set. + // as it will always fail. So just abort the whole task-set and print a serialization + // trace to help identify the failure point. case NonFatal(e) => + SerializationHelper.tryToSerialize(ser, task).fold ( + l => logDebug("Un-serializable reference trace for " + + task.toString + ":\n" + l), + r => {} + ) + val msg = s"Failed to serialize task $taskId, not attempting to retry it." logError(msg, e) abort(s"$msg Exception during serialization: $e") throw new TaskNotSerializableException(e) } + if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && !emittedTaskSizeWarning) { emittedTaskSizeWarning = true diff --git a/core/src/main/scala/org/apache/spark/util/ObjectWalker.scala b/core/src/main/scala/org/apache/spark/util/ObjectWalker.scala new file mode 100644 index 0000000000000..f4c674c5e17bf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ObjectWalker.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.lang.reflect.{Modifier, Field} + +import scala.collection.mutable + + +/** + * This utility permits traversing a generic object's reference graph. This is useful for debugging + * serialization errors. See SPARK-3694. + * + * This code is based on code written by Josh Rosen found here: + * https://gist.github.com/JoshRosen/d6a8972c99992e97d040 + */ +private[spark] object ObjectWalker { + def isTransient(field: Field): Boolean = Modifier.isTransient(field.getModifiers) + def isStatic(field: Field): Boolean = Modifier.isStatic(field.getModifiers) + def isPrimitive(field: Field): Boolean = field.getType.isPrimitive + + /** + * Traverse the graph representing all references between the provided root object, its + * members, and their references in turn. + * + * What we want to be able to do is readily identify un-serializable components AND the path + * to those components. To do this, record every reference reached during a breadth-first + * traversal; the path to an un-serializable reference can then be recovered by re-walking + * the graph from that reference. + * + * @param rootObj - The root object for which to generate the reference graph + * @return a new LinkedList containing the references reached during the traversal of the + * reference graph + */ + def buildRefGraph(rootObj: AnyRef): mutable.LinkedList[AnyRef] = { + val visitedRefs = mutable.Set[AnyRef]() + val toVisit = new mutable.Queue[AnyRef]() + var results = mutable.LinkedList[AnyRef]() + + toVisit += rootObj + + while (toVisit.nonEmpty) { + val obj: AnyRef = toVisit.dequeue() + // Process each reference at most once + + if (!visitedRefs.contains(obj)) { + results = mutable.LinkedList(obj).append(results) + visitedRefs.add(obj) + + // Extract all the fields from the object that would be serialized. Transient and + // static references are not serialized and primitive variables will always be serializable + // and will not contain further references. + for (field <- getFieldsToTest(obj)) { + // Extract the field's value and queue it for traversal + val originalAccessibility = field.isAccessible + field.setAccessible(true) + val fieldObj = field.get(obj) + field.setAccessible(originalAccessibility) + + if (fieldObj != null) { + toVisit += fieldObj + } + } + } + } + results + } + + /** + * Get the serializable fields from an object reference + * @param obj - Reference to the object for which to generate a serialization trace + * @return a new Set containing the serializable fields of the object + */ + def getFieldsToTest(obj: AnyRef): mutable.Set[Field] = { + getAllFields(obj.getClass) + .filterNot(isStatic) + .filterNot(isTransient) + .filterNot(isPrimitive) + } + + /** + * Get all fields (including private ones) from this class and its superclasses.
+ * @param cls - The class from which to retrieve fields + * @return a new mutable.Set containing the fields of the class + */ + private def getAllFields(cls: Class[_]): mutable.Set[Field] = { + val fields = mutable.Set[Field]() + var _cls: Class[_] = cls + while (_cls != null) { + fields ++= _cls.getDeclaredFields + fields ++= _cls.getFields + _cls = _cls.getSuperclass + } + + fields + } +} diff --git a/core/src/main/scala/org/apache/spark/util/RDDWalker.scala b/core/src/main/scala/org/apache/spark/util/RDDWalker.scala new file mode 100644 index 0000000000000..c8b96f0cd1ab8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RDDWalker.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.language.existentials + +import org.apache.spark.rdd.RDD + +/** + * This utility permits traversing an RDD's dependency graph. This is + * accomplished by walking the object graph linking these RDDs. This is useful for debugging + * internal RDD references. See SPARK-3694. + */ +private[spark] object RDDWalker { + /** + * Traverse the dependencies of the RDD and store them within an Array along with their depths, + * then return this data structure for the caller to process. + * + * @param rddToWalk - The RDD to traverse along with its dependencies + * @return Array[(RDD[_], Int)] - An array of (RDD, depth) pairs generated by the traversal + */ + def walk(rddToWalk: RDD[_]): Array[(RDD[_], Int)] = { + + val walkQueue = new mutable.Queue[(RDD[_], Int)]() + val visited = mutable.Set[RDD[_]]() + + // Keep track of both the RDD and its depth in the traversal graph. + val results = new ArrayBuffer[(RDD[_], Int)]() + // Implement as a queue to perform a BFS + walkQueue += ((rddToWalk, 0)) + + while (walkQueue.nonEmpty) { + // Pop from the queue + val (rddToProcess: RDD[_], depth: Int) = walkQueue.dequeue() + if (!visited.contains(rddToProcess)) { + visited.add(rddToProcess) + rddToProcess.dependencies.foreach(s => walkQueue += ((s.rdd, depth + 1))) + results.append((rddToProcess, depth)) + } + } + + results.toArray + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SerializationHelper.scala b/core/src/main/scala/org/apache/spark/util/SerializationHelper.scala new file mode 100644 index 0000000000000..1bdaf2162b846 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializationHelper.scala @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.NotSerializableException +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.collection.mutable.HashMap +import scala.util.control.NonFatal + +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.Task +import org.apache.spark.serializer.SerializerInstance + +/** + * This enumeration defines values used to standardize debugging output + */ +private[spark] object SerializationState extends Enumeration { + type SerializationState = String + val Failed = "Failed to serialize parent." + val FailedDeps = "Failed to serialize dependencies." + val Success = "Success" +} + +private[spark] case class RDDTrace(rdd: RDD[_], + depth: Int, + result: SerializationHelper.SerializedRef) + +/** + * This class is designed to encapsulate some utilities to facilitate debugging serialization + * problems in the DAGScheduler and the TaskSetManager. See SPARK-3694. + */ +private[spark] object SerializationHelper { + type PathToRef = mutable.LinkedList[AnyRef] + type BrokenRef = (AnyRef, PathToRef) + type SerializedRef = Either[String, ByteBuffer] + + /** + * Check whether a reference is serializable. + * + * If any dependency of a reference is un-serializable, a NotSerializableException will be + * thrown and we can then generate a serialization trace to identify the problem reference. + * + * The serialization trace is returned on the Left side + * + * @param closureSerializer - An instance of a serializer (single-threaded) that will be used + * @param ref - The top-level reference that we are attempting to serialize + * @return SerializedRef - the serialized bytes (Right) if serialization is successful, else + * a String (Left) that clarifies why serialization failed. + */ + def tryToSerialize(closureSerializer: SerializerInstance, + ref: AnyRef): SerializedRef = { + val result: SerializedRef = try { + Right(closureSerializer.serialize(ref)) + } catch { + case e: NotSerializableException => Left(getSerializationTrace(closureSerializer, ref)) + case NonFatal(e) => Left(getSerializationTrace(closureSerializer, ref)) + } + + result + } + + /** + * Check whether the serialization of the RDD or its dependencies was successful.
+ * + * @param closureSerializer - An instance of a serializer (single-threaded) that will be used + * @param serialized - Results of attempting to serialize the RDD and its dependencies + * @return the serialized parent RDD if successful + * @throws java.io.NotSerializableException if the RDD or its dependencies didn't serialize + */ + @throws(classOf[NotSerializableException]) + def tryToSerializeRddAndDeps(closureSerializer: SerializerInstance, + serialized: Array[RDDTrace]): ByteBuffer = { + if (serialized.exists(trace => trace.result.isLeft)) { + throw new NotSerializableException("Failed to serialize dependencies.") + } + + // If we get here we know that this (the serialization of the parent RDD) was successful + serialized(0).result.right.get + } + + /** + * When debugging RDD serialization failures, generate the trace differently. + * This is because when RDDs have nested un-serializable dependencies, the reference graph becomes + * much harder to trace. Thus, generate a reference trace only for the un-serializable RDDs and + * their parents - not their ancestors. We can still see ancestry from the initially logged + * output. + * + * @param closureSerializer - An instance of a serializer (single-threaded) that will be used + * @param rdd - RDD to attempt to serialize + * @return SerializedRef - the serialized RDD (Right) on success, or the formatted failure + * trace (Left) otherwise + */ + def tryToSerializeRdd(closureSerializer: SerializerInstance, + rdd: RDD[_]): SerializedRef = { + val serialized: Array[RDDTrace] = tryToSerializeRddAndDeps(closureSerializer, rdd) + + def handleException: Left[String, Nothing] = { + var failedString = "" + + // For convenience, first output a trace by depth of whether each dependency serialized + serialized.foreach { trace => + val out = ("Depth " + trace.depth + ": " + + trace.rdd.toString + " - " + + trace.result.fold(l => l, r => SerializationState.Success)) + failedString += out + "\n" + } + + // Next, print a specific reference trace for each un-serializable RDD + serialized.foreach { trace => + trace.result.fold(l => { + failedString += ("" + getSerializationTrace(closureSerializer, trace.rdd) + "\n") + }, r => {}) + } + Left(failedString) + } + + val result: SerializedRef = try { + Right(tryToSerializeRddAndDeps(closureSerializer, serialized)) + } catch { + case e: NotSerializableException => handleException + case NonFatal(e) => handleException + } + + result + } + + /** + * Attempt to serialize an RDD and its dependencies, providing a result on a per-RDD basis. + * + * We do this because, for RDDs with nested un-serializable dependencies, it + * becomes challenging to read the serialization trace and identify failures. This approach lets us + * print out only the RDDs that actually failed. + * + * @param closureSerializer - An instance of a serializer (single-threaded) that will be used + * @param rdd - RDD to attempt to serialize + * @return new Array[RDDTrace] where each entry represents one of the RDDs in the tree of the + * parent RDD's dependencies. Each entry provides a reference to the RDD, its depth in the + * tree, and the result of serialization.
+ */ + def tryToSerializeRddAndDeps(closureSerializer: SerializerInstance, + rdd: RDD[_]): Array[RDDTrace] = { + // Walk the RDD so that we can display a trace on a per-dependency basis + val traversal: Array[(RDD[_], Int)] = RDDWalker.walk(rdd) + + def handleException(curRdd: RDD[_]): Left[String, Nothing] = { + Left(handleFailedRdd(closureSerializer, curRdd)) + } + + // Attempt to serialize each dependency of the RDD (track depth information to facilitate + // debugging). + val serialized = traversal.map { + case (curRdd, depth) => + val result: SerializedRef = try { + Right(closureSerializer.serialize(curRdd)) + } catch { + case e: NotSerializableException => handleException(curRdd) + case NonFatal(e) => handleException(curRdd) + } + + RDDTrace(curRdd, depth, result) + } + + serialized + } + + /** + * Helper function to separate an un-serializable parent RDD from un-serializable dependencies + * + * @param closureSerializer - An instance of a serializer (single-threaded) that will be used + * @param rdd - RDD to attempt to serialize + * @return String - a SerializationState value that clarifies why the serialization + * failed. + */ + private def handleFailedRdd(closureSerializer: SerializerInstance, + rdd: RDD[_]): String = { + if (rdd.dependencies.nonEmpty) { + try { + rdd.dependencies.foreach(dep => closureSerializer.serialize(dep: AnyRef)) + + // By default return a parent failure since we know that the parent already failed + SerializationState.Failed + } catch { + // If instead, however, the dependencies ALSO fail to serialize then the subsequent stage + // of evaluation will help identify which of the dependencies has failed + case e: NotSerializableException => SerializationState.FailedDeps + case NonFatal(e) => SerializationState.FailedDeps + } + } + else { + SerializationState.Failed + } + } + + /** + * When an RDD is identified as un-serializable, use the generic ObjectWalker utility to debug + * the references of that RDD and generate a set of paths to broken references + * + * @param closureSerializer - An instance of a serializer (single-threaded) that will be used + * @param ref - The reference known to be un-serializable + * @return a Set of (AnyRef, LinkedList) - tuples of each un-serializable reference and the + * path to that reference + */ + private def getPathsToBrokenRefs(closureSerializer: SerializerInstance, + ref: AnyRef): mutable.Set[BrokenRef] = { + val refGraph: mutable.LinkedList[AnyRef] = ObjectWalker.buildRefGraph(ref) + val brokenRefs = mutable.Set[BrokenRef]() + + refGraph.foreach { + case ref: AnyRef => + try { + closureSerializer.serialize(ref) + } catch { + case e: NotSerializableException => brokenRefs.add((ref, ObjectWalker.buildRefGraph(ref))) + case NonFatal(e) => brokenRefs.add((ref, ObjectWalker.buildRefGraph(ref))) + } + } + + brokenRefs + } + + /** + * Returns nicely formatted text representing the trace of the failed serialization + * + * @param closureSerializer - An instance of a serializer (single-threaded) that will be used + * @param ref - The top-level reference that we are attempting to serialize + * @return a formatted trace of the broken references and the paths to them + */ + def getSerializationTrace(closureSerializer: SerializerInstance, + ref: AnyRef): String = { + var trace = "Un-serializable reference trace for " + ref.toString + ":\n" + trace += brokenRefsToString(getPathsToBrokenRefs(closureSerializer, ref)) + trace + } + + def refString(ref: AnyRef): String = { + val refCode = System.identityHashCode(ref) + "Ref " + refCode + " (" + ref.toString + ")" + } + + /** + * Given a set of references and
the paths to those references (as a dependency tree), return + * a cleanly formatted string showing these paths. + * + * @param brokenRefPath - a set of tuples, each pairing an un-serializable reference with the + * path to that reference + */ + private def brokenRefsToString(brokenRefPath: mutable.Set[BrokenRef]): String = { + var trace = "**********************\n" + + brokenRefPath.foreach(s => trace += brokenRefToString(s) + "**********************\n") + trace + } + + /** + * Given a reference and a path to that reference (as a dependency tree), return a cleanly + * formatted string showing this path. + * @param brokenRefPath - a tuple of the un-serializable reference and the path to that reference + */ + private def brokenRefToString(brokenRefPath: (AnyRef, mutable.LinkedList[AnyRef])): String = { + val ref = brokenRefPath._1 + val path = brokenRefPath._2 + + var trace = ref + ":\n" + path.foreach(s => { + trace += "--- " + refString(s) + "\n" + }) + + trace + } + + /** + * Provide a string representation of the task and its dependencies (in terms of added files + * and jars that must be shipped with the task) for debugging purposes. + * @param task - The task to serialize + * @param addedFiles - The file dependencies + * @param addedJars - The JAR dependencies + * @return String - The task and dependencies as a string + */ + private def taskDebugString(task: Task[_], + addedFiles: HashMap[String, Long], + addedJars: HashMap[String, Long]): String = { + val taskStr = "[" + task.toString + "] \n" + val strPrefix = s"-- " + val nl = s"\n" + val fileTitle = s"File dependencies:$nl" + val jarTitle = s"Jar dependencies:$nl" + + val fileStr = addedFiles.keys.map(file => s"$strPrefix $file").reduce(_ + nl + _) + nl + val jarStr = addedJars.keys.map(jar => s"$strPrefix $jar").reduce(_ + nl + _) + nl + + s"$taskStr $nl $fileTitle $fileStr $jarTitle $jarStr" + } +} + diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index d30eb10bbe947..90e6c0f7a3a8c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -30,7 +30,8 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} -import org.apache.spark.util.CallSite +import org.apache.spark.util.{CallSite, RDDTrace, SerializationState} + import org.apache.spark.executor.TaskMetrics class BuggyDAGEventProcessActor extends Actor { @@ -244,6 +245,111 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F runEvent(JobCancelled(jobId)) } + test("Serialization trace for un-serializable task") { + val unserializableRdd = new MyRDD(sc, 1, Nil) { + class UnserializableClass + val unserializable = new UnserializableClass + } + + val trace: Array[RDDTrace] = scheduler.tryToSerializeRddDeps(unserializableRdd) + + assert(trace.length == 1) + assert(trace(0).result.isLeft) // Failed to serialize + } + + test("Serialization trace for un-serializable task with serializable dependencies") { + // The trace should show that only the top-level RDD fails to serialize while its + // dependencies succeed + + val baseRdd = new MyRDD(sc, 1, Nil) + val midRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) + val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(midRdd))) { + class UnserializableClass + val unserializable = new
UnserializableClass + } + + // Generate results array as (Success/Failure (Boolean), ResultString (String)) + val results = Array((false, SerializationState.Failed), + (true, SerializationState.Success), + (true, SerializationState.Success)) + + val zipped: Array[(RDDTrace, Int)] = scheduler.tryToSerializeRddDeps(finalRdd).zipWithIndex + zipped.foreach { + case (trace: RDDTrace, idx: Int) => + trace.result match { + case Right(r) => assert(results(idx)._1) // Success + case Left(l) => assert(results(idx)._2.equals(l)) // Match failure strings + } + } + } + + test("Serialization trace for serializable task and nested unserializable dependency") { + // The trace should show which nested dependency is unserializable + + val baseRdd = new MyRDD(sc, 1, Nil) { + class UnserializableClass + val unserializable = new UnserializableClass + } + + val midRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) + val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(midRdd))) + + // Generate results array as (Success/Failure (Boolean), ResultString (String)) + val results = Array((false, SerializationState.FailedDeps), + (false, SerializationState.FailedDeps), + (false, SerializationState.Failed)) + + val zipped: Array[(RDDTrace, Int)] = scheduler.tryToSerializeRddDeps(finalRdd).zipWithIndex + zipped.foreach { + case (trace: RDDTrace, idx: Int) => + trace.result match { + case Right(r) => assert(results(idx)._1) // Success + case Left(l) => assert(results(idx)._2.equals(l)) // Match failure strings + } + } + + } + + test("Serialization trace for serializable task with sandwiched unserializable dependency") { + // The trace should show which nested dependency is unserializable + + val baseRdd = new MyRDD(sc, 1, Nil) + val midRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) { + class UnserializableClass + val unserializable = new UnserializableClass + } + val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(midRdd))) + + // Generate results array as (Success/Failure (Boolean), ResultString (String)) + val results = Array((false, SerializationState.FailedDeps), + (false, SerializationState.Failed), + (true, SerializationState.Success)) + + val zipped: Array[(RDDTrace, Int)] = scheduler.tryToSerializeRddDeps(finalRdd).zipWithIndex + zipped.foreach { + case (trace: RDDTrace, idx: Int) => + trace.result match { + case Right(r) => assert(results(idx)._1) // Success + case Left(l) => assert(results(idx)._2.equals(l)) // Match failure strings + } + } + } + + test("Serialization trace for serializable task and nested dependencies") { + // Because serializing an RDD also serializes its dependencies, serializing the serializable + // "finalRdd" should succeed and the trace should report every RDD in the lineage as + // serializable.
+ + val baseRdd = new MyRDD(sc, 1, Nil) + val midRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) + val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(midRdd))) + + val zipped: Array[(RDDTrace, Int)] = scheduler.tryToSerializeRddDeps(finalRdd).zipWithIndex + + zipped.foreach { + case (trace: RDDTrace, idx: Int) => assert(trace.result.isRight) + } + } + test("[SPARK-3353] parent stage should have lower stage id") { sparkListener.stageByOrderOfExecution.clear() sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count() @@ -539,7 +645,6 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F assertDataStructuresEmpty } - /** * Makes sure that failures of stage used by multiple jobs are correctly handled. * diff --git a/docs/configuration.md b/docs/configuration.md index f292bfbb7dcd6..b71713db51f98 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -557,6 +557,14 @@ Apart from these, the following properties are also available, and may be useful +<tr> +  <td><code>spark.serializer.debug</code></td> +  <td>false</td> +  <td> +    To view the dependency graph for an RDD or the file dependencies for a task, set this option +    to true. Doing so will display a text-based graph of these dependencies along with a +    serialization trace that identifies which components of an RDD failed to serialize. +  </td> +</tr> <tr>   <td><code>spark.kryo.referenceTracking</code></td>   <td>true</td>
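
Reviewer note: below is a minimal, hand-run sketch of how the new helper could be exercised. Everything other than JavaSerializer and SerializationHelper is an illustrative assumption (the SerializationTraceSketch object and the Holder/Inner classes are hypothetical), and because the helper is private[spark] the snippet is assumed to be compiled under the org.apache.spark package, for example inside a test.

package org.apache.spark

import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.SerializationHelper

object SerializationTraceSketch {
  // A serializable wrapper holding an un-serializable field, mimicking the RDD subclasses
  // used in the DAGSchedulerSuite tests above.
  class Inner
  class Holder extends Serializable { val inner = new Inner }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // The DAGScheduler and TaskSetManager use their own closure serializer; a plain
    // JavaSerializer instance stands in for it in this sketch.
    val ser = new JavaSerializer(conf).newInstance()

    // Left(trace) carries the formatted reference trace; Right(bytes) means serialization worked.
    SerializationHelper.tryToSerialize(ser, new Holder).fold(
      trace => println(trace),
      _ => println("Serialized cleanly"))
  }
}

The same fold pattern applies to tryToSerializeRdd for an RDD lineage, which is how the DAGScheduler changes above log a trace when task serialization fails.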