[SPARK-6980] Resolved conflicts after master merge
BryanCutler committed Jun 3, 2015
1 parent c07d05c commit 235919b
Showing 3 changed files with 104 additions and 15 deletions.
core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala

@@ -28,9 +28,11 @@
 import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address}
 import akka.pattern.{ask => akkaAsk}
 import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
+import com.google.common.util.concurrent.MoreExecutors

 import org.apache.spark.{SparkException, Logging, SparkConf}
 import org.apache.spark.rpc._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils}
+import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils}

 /**
  * A RpcEnv implementation based on Akka.
@@ -293,8 +295,8 @@
   }

   override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
-    import scala.concurrent.ExecutionContext.Implicits.global
     actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
+      // The function will run in the calling thread, so it should be short and never block.
       case msg @ AkkaMessage(message, reply) =>
         if (reply) {
           logError(s"Receive $msg but the sender cannot reply")
@@ -304,7 +306,7 @@
         }
       case AkkaFailure(e) =>
         Future.failed(e)
-    }.mapTo[T]
+    }(ThreadUtils.sameThread).mapTo[T]
   }

   override def toString: String = s"${getClass.getSimpleName}($actorRef)"
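Note: the change above drops the implicit global ExecutionContext so the flatMap callback runs via ThreadUtils.sameThread, on the thread that completes the Akka future. ThreadUtils itself is not part of this diff; a minimal sketch of such a same-thread context, assuming it is backed by Guava's MoreExecutors (imported above), would be:

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}

import com.google.common.util.concurrent.MoreExecutors

object SameThreadSketch {
  // Sketch only: an ExecutionContext that runs each callback on the thread
  // that completes the future, with no thread hand-off. Anything scheduled
  // on it must be short and must never block.
  val sameThread: ExecutionContextExecutor =
    ExecutionContext.fromExecutorService(MoreExecutors.sameThreadExecutor())
}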
core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala

@@ -17,13 +17,14 @@

 package org.apache.spark.storage

-import scala.concurrent.Future
-import scala.concurrent.ExecutionContext.Implicits.global
+import scala.collection.Iterable
+import scala.collection.generic.CanBuildFrom
+import scala.concurrent.{Await, Future}

 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.{Logging, SparkConf, SparkException}
 import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.RpcUtils
+import org.apache.spark.util.{ThreadUtils, RpcUtils}

 private[spark]
 class BlockManagerMaster(
@@ -102,8 +103,8 @@ class BlockManagerMaster(
     val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId))
     future.onFailure {
       case e: Exception =>
-        logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}")
-    }
+        logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e)
+    }(ThreadUtils.sameThread)
     if (blocking) {
       timeout.awaitResult(future)
     }
@@ -114,8 +115,8 @@
     val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
     future.onFailure {
       case e: Exception =>
-        logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}")
-    }
+        logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e)
+    }(ThreadUtils.sameThread)
     if (blocking) {
       timeout.awaitResult(future)
     }
@@ -128,8 +129,8 @@
     future.onFailure {
       case e: Exception =>
         logWarning(s"Failed to remove broadcast $broadcastId" +
-          s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}")
-    }
+          s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e)
+    }(ThreadUtils.sameThread)
     if (blocking) {
       timeout.awaitResult(future)
     }
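The three hunks above share one pattern: Future.onFailure takes its ExecutionContext as an implicit parameter, and the change supplies ThreadUtils.sameThread explicitly instead of importing the global context. A standalone sketch of that pattern (the logOnFailure helper and Int payload are illustrative, not from the commit):

import scala.concurrent.{ExecutionContext, Future}

def logOnFailure(future: Future[Seq[Int]], sameThread: ExecutionContext): Unit = {
  // Pass the context explicitly instead of relying on an imported implicit;
  // with a same-thread context the callback runs on whichever thread
  // completes the future, so it must stay cheap and non-blocking.
  future.onFailure {
    case e: Exception => println(s"request failed - ${e.getMessage}")
  }(sameThread)
}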
@@ -169,11 +170,17 @@
     val response = driverEndpoint.
       askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
     val (blockManagerIds, futures) = response.unzip
-    val result = timeout.awaitResult(Future.sequence(futures))
-    if (result == null) {
+    implicit val sameThread = ThreadUtils.sameThread
+    val cbf =
+      implicitly[
+        CanBuildFrom[Iterable[Future[Option[BlockStatus]]],
+          Option[BlockStatus],
+          Iterable[Option[BlockStatus]]]]
+    val blockStatus = timeout.awaitResult(
+      Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread))
+    if (blockStatus == null) {
       throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
     }
-    val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]]
     blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) =>
       status.map { s => (blockManagerId, s) }
     }.toMap
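In the last hunk, passing the ExecutionContext to Future.sequence explicitly means its other implicit, the CanBuildFrom, must be summoned and passed explicitly as well. The same call shape in isolation, with a placeholder Int element type (names here are illustrative):

import scala.collection.Iterable
import scala.collection.generic.CanBuildFrom
import scala.concurrent.{ExecutionContext, Future}

def sequenceSameThread(
    futures: Iterable[Future[Int]],
    sameThread: ExecutionContext): Future[Iterable[Int]] = {
  // Future.sequence normally resolves both of its implicit parameters; once
  // one of them (the ExecutionContext) is given explicitly, the CanBuildFrom
  // must be spelled out too (Scala 2.10/2.11 collections API).
  val cbf = implicitly[CanBuildFrom[Iterable[Future[Int]], Int, Iterable[Int]]]
  Future.sequence[Int, Iterable](futures)(cbf, sameThread)
}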
core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala

@@ -17,9 +17,21 @@

 package org.apache.spark.rpc.akka

+import java.util.concurrent.TimeoutException
+
+import scala.concurrent.Await
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.duration._
+import scala.util.{Success, Failure}
+import scala.language.postfixOps
+
+import akka.actor.{ActorSystem, Actor, ActorRef, Props, Address}
+import akka.pattern.ask
+
 import org.apache.spark.rpc._
 import org.apache.spark.{SecurityManager, SparkConf}


 class AkkaRpcEnvSuite extends RpcEnvSuite {

   override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
@@ -47,4 +59,72 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
     }
   }

+  test("Future failure with RpcTimeout") {
+
+    class EchoActor extends Actor {
+      def receive: Receive = {
+        case msg =>
+          Thread.sleep(500)
+          sender() ! msg
+      }
+    }
+
+    val system = ActorSystem("EchoSystem")
+    val echoActor = system.actorOf(Props(new EchoActor), name = "echoA")
+
+    val timeout = new RpcTimeout(50 millis, "spark.rpc.short.timeout")
+
+    val fut = echoActor.ask("hello")(1000 millis).mapTo[String].recover {
+      case te: TimeoutException => throw timeout.amend(te)
+    }
+
+    fut.onFailure {
+      case te: TimeoutException => println("failed with timeout exception")
+    }
+
+    fut.onComplete {
+      case Success(str) => println("future success")
+      case Failure(ex) => println("future failure")
+    }
+
+    println("sleeping")
+    Thread.sleep(50)
+    println("Future complete: " + fut.isCompleted.toString() + ", " + fut.value.toString())
+
+    println("Caught TimeoutException: " +
+      intercept[TimeoutException] {
+        //timeout.awaitResult(fut) // prints RpcTimeout description twice
+        Await.result(fut, 10 millis)
+      }.getMessage()
+    )
+
+    /*
+    val ref = env.setupEndpoint("test_future", new RpcEndpoint {
+      override val rpcEnv = env
+      override def receive = {
+        case _ =>
+      }
+    })
+    val conf = new SparkConf()
+    val newRpcEnv = new AkkaRpcEnvFactory().create(
+      RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf)))
+    try {
+      val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_future")
+      val akkaActorRef = newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef
+      val timeout = new RpcTimeout(1 millis, "spark.rpc.short.timeout")
+      val fut = akkaActorRef.ask("hello")(timeout.duration).mapTo[String]
+      Thread.sleep(500)
+      println("Future complete: " + fut.isCompleted.toString() + ", " + fut.value.toString())
+    } finally {
+      newRpcEnv.shutdown()
+    }
+    */
+
+
+  }

 }
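The new test exercises RpcTimeout.amend, which is not part of this diff. A plausible sketch of the behavior the test relies on, rewrapping a TimeoutException so its message names the controlling configuration property (an assumption about intent, not the committed implementation):

import java.util.concurrent.TimeoutException

import scala.concurrent.duration.FiniteDuration

// Sketch only: the real RpcTimeout lives elsewhere in this patch series.
class RpcTimeoutSketch(val duration: FiniteDuration, timeoutProp: String) {
  // Return a TimeoutException whose message also names the conf entry that
  // controls this timeout, preserving the original stack trace.
  def amend(te: TimeoutException): TimeoutException = {
    val amended = new TimeoutException(
      s"${te.getMessage} This timeout is controlled by $timeoutProp")
    amended.setStackTrace(te.getStackTrace)
    amended
  }
}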
