From 7d769b8d87197211f6bd7c91335f9712fc4da949 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 10 Nov 2020 13:07:37 -0500 Subject: [PATCH] Allow actors to call actors on the same worker (#4225) --- distributed/actor.py | 32 +++++++++++++++++++++++++++++--- distributed/tests/test_actor.py | 33 ++++++++++++++++++++++++++++++++- distributed/worker.py | 1 + 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index dc49571d1d..54e9000bda 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -1,11 +1,12 @@ import asyncio import functools +from inspect import iscoroutinefunction import threading from queue import Queue from .client import Future, default_client from .protocol import to_serialize -from .utils import sync +from .utils import thread_state, sync from .utils_comm import WrappedKey from .worker import get_worker @@ -118,13 +119,33 @@ def __dir__(self): return sorted(o) def __getattr__(self, key): - attr = getattr(self._cls, key) if self._future and self._future.status not in ("finished", "pending"): raise ValueError( "Worker holding Actor was lost. Status: " + self._future.status ) + if ( + self._worker + and self._worker.address == self._address + and getattr(thread_state, "actor", False) + ): + # actor calls actor on same worker + actor = self._worker.actors[self.key] + attr = getattr(actor, key) + + if iscoroutinefunction(attr): + return attr + + elif callable(attr): + return lambda *args, **kwargs: ActorFuture( + None, None, result=attr(*args, **kwargs) + ) + else: + return attr + + attr = getattr(self._cls, key) + if callable(attr): @functools.wraps(attr) @@ -206,9 +227,14 @@ class ActorFuture: Actor """ - def __init__(self, q, io_loop): + def __init__(self, q, io_loop, result=None): self.q = q self.io_loop = io_loop + if result: + self._cached_result = result + + def __await__(self): + return self.result() def result(self, timeout=None): try: diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index 89233eaca2..26421d0385 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -6,7 +6,7 @@ import dask from distributed import Actor, ActorFuture, Client, Future, wait, Nanny -from distributed.utils_test import gen_cluster +from distributed.utils_test import cluster, gen_cluster from distributed.utils_test import client, cluster_fixture, loop # noqa: F401 from distributed.metrics import time @@ -21,11 +21,25 @@ def increment(self): self.n += 1 return self.n + async def ainc(self): + self.n += 1 + return self.n + def add(self, x): self.n += x return self.n +class UsesCounter: + # An actor whose method argument is another actor + + def do_inc(self, ac): + return ac.increment().result() + + async def ado_inc(self, ac): + return await ac.ainc() + + class List: L = [] @@ -550,3 +564,20 @@ async def wait(self): await waiter.set() await c.gather(futures) + + +def test_one_thread_deadlock(): + with cluster(nworkers=2) as (cl, w): + client = Client(cl["address"]) + ac = client.submit(Counter, actor=True).result() + ac2 = client.submit(UsesCounter, actor=True, workers=[ac._address]).result() + + assert ac2.do_inc(ac).result() == 1 + + +@gen_cluster(client=True) +async def test_async_deadlock(client, s, a, b): + ac = await client.submit(Counter, actor=True) + ac2 = await client.submit(UsesCounter, actor=True, workers=[ac._address]) + + assert (await ac2.ado_inc(ac)) == 1 diff --git a/distributed/worker.py b/distributed/worker.py index b9bc9b6337..d247fb4211 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3416,6 +3416,7 @@ def apply_function_actor( thread_state.execution_state = execution_state thread_state.key = key + thread_state.actor = True result = function(*args, **kwargs)