diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 9a42487bb37aa..9c4b8cdc646b0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -23,13 +23,12 @@ import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files - -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.{SecurityManager, SparkConf, Logging} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Manages the execution of one executor process. @@ -60,6 +59,9 @@ private[deploy] class ExecutorRunner( private var stdoutAppender: FileAppender = null private var stderrAppender: FileAppender = null + // Timeout to wait for when trying to terminate an executor. + private val EXECUTOR_TERMINATE_TIMEOUT_MS = 10 * 1000 + // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. private var shutdownHook: AnyRef = null @@ -94,8 +96,11 @@ private[deploy] class ExecutorRunner( if (stderrAppender != null) { stderrAppender.stop() } - process.destroy() - exitCode = Some(process.waitFor()) + exitCode = Utils.terminateProcess(process, EXECUTOR_TERMINATE_TIMEOUT_MS) + if (exitCode.isEmpty) { + logWarning("Failed to terminate process: " + process + + ". This process will likely be orphaned.") + } } try { worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b8ca6b07e4198..0c1f9c1ae2496 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1698,6 +1698,30 @@ private[spark] object Utils extends Logging { new File(path).getName } + /** + * Terminates a process waiting for at most the specified duration. Returns whether + * the process terminated. + */ + def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = { + try { + // Java8 added a new API which will more forcibly kill the process. Use that if available. + val destroyMethod = process.getClass().getMethod("destroyForcibly"); + destroyMethod.setAccessible(true) + destroyMethod.invoke(process) + } catch { + case NonFatal(e) => + if (!e.isInstanceOf[NoSuchMethodException]) { + logWarning("Exception when attempting to kill process", e) + } + process.destroy() + } + if (waitForProcess(process, timeoutMs)) { + Option(process.exitValue()) + } else { + None + } + } + /** * Wait for a process to terminate for at most the specified duration. * Return whether the process actually terminated after the given timeout. diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index fdb51d440eff6..7de995af512db 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,26 +17,24 @@ package org.apache.spark.util -import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols -import java.util.concurrent.TimeUnit import java.util.Locale +import java.util.concurrent.TimeUnit import scala.collection.mutable.ListBuffer import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files - +import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path - import org.apache.spark.network.util.ByteUnit -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -745,4 +743,77 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc") assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz") } + + test("Kill process") { + // Verify that we can terminate a process even if it is in a bad state. This is only run + // on UNIX since it does some OS specific things to verify the correct behavior. + if (SystemUtils.IS_OS_UNIX) { + def getPid(p: Process): Int = { + val f = p.getClass().getDeclaredField("pid") + f.setAccessible(true) + f.get(p).asInstanceOf[Int] + } + + def pidExists(pid: Int): Boolean = { + val p = Runtime.getRuntime.exec(s"kill -0 $pid") + p.waitFor() + p.exitValue() == 0 + } + + def signal(pid: Int, s: String): Unit = { + val p = Runtime.getRuntime.exec(s"kill -$s $pid") + p.waitFor() + } + + // Start up a process that runs 'sleep 10'. Terminate the process and assert it takes + // less time and the process is no longer there. + val startTimeMs = System.currentTimeMillis() + val process = new ProcessBuilder("sleep", "10").start() + val pid = getPid(process) + try { + assert(pidExists(pid)) + val terminated = Utils.terminateProcess(process, 5000) + assert(terminated.isDefined) + Utils.waitForProcess(process, 5000) + val durationMs = System.currentTimeMillis() - startTimeMs + assert(durationMs < 5000) + assert(!pidExists(pid)) + } finally { + // Forcibly kill the test process just in case. + signal(pid, "SIGKILL") + } + + val v: String = System.getProperty("java.version") + if (v >= "1.8.0") { + // Java8 added a way to forcibly terminate a process. We'll make sure that works by + // creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On + // older versions of java, this will *not* terminate. + val file = File.createTempFile("temp-file-name", ".tmp") + val cmd = + s""" + |#!/bin/bash + |trap "" SIGTERM + |sleep 10 + """.stripMargin + Files.write(cmd.getBytes(), file) + file.getAbsoluteFile.setExecutable(true) + + val process = new ProcessBuilder(file.getAbsolutePath).start() + val pid = getPid(process) + assert(pidExists(pid)) + try { + signal(pid, "SIGSTOP") + val start = System.currentTimeMillis() + val terminated = Utils.terminateProcess(process, 5000) + assert(terminated.isDefined) + Utils.waitForProcess(process, 5000) + val duration = System.currentTimeMillis() - start + assert(duration < 5000) + assert(!pidExists(pid)) + } finally { + signal(pid, "SIGKILL") + } + } + } + } }