From df85d310958997d3465823bab5e50e4e481b3683 Mon Sep 17 00:00:00 2001 From: Archit Kulkarni Date: Thu, 17 Feb 2022 17:29:44 -0800 Subject: [PATCH] [Serve] Make handle serializable (#22473) --- python/ray/serve/handle.py | 9 +++++++-- python/ray/serve/tests/test_handle.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 2fa687acc079..4b2f6e28893d 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -169,6 +169,11 @@ async def remote(self, *args, **kwargs): def __repr__(self): return f"{self.__class__.__name__}" f"(deployment='{self.deployment_name}')" + @classmethod + def _deserialize(cls, kwargs): + """Required for this class's __reduce__ method to be picklable.""" + return cls(**kwargs) + def __reduce__(self): serialized_data = { "controller_handle": self.controller_handle, @@ -176,7 +181,7 @@ def __reduce__(self): "handle_options": self.handle_options, "_internal_pickled_http_request": self._pickled_http_request, } - return lambda kwargs: RayServeHandle(**kwargs), (serialized_data,) + return RayServeHandle._deserialize, (serialized_data,) def __getattr__(self, name): return self.options(method_name=name) @@ -228,4 +233,4 @@ def __reduce__(self): "handle_options": self.handle_options, "_internal_pickled_http_request": self._pickled_http_request, } - return lambda kwargs: RayServeSyncHandle(**kwargs), (serialized_data,) + return RayServeSyncHandle._deserialize, (serialized_data,) diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index a1bcf816686b..aaaf521b1e49 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -46,6 +46,31 @@ def task(handle): assert ray.get(result_ref) == "hello" +def test_handle_serializable_in_deployment_init(serve_instance): + """Test that a handle can be passed into a constructor (#22110)""" + + @serve.deployment + class RayServer1: + def __init__(self): + pass + + def __call__(self, *args): + return {"count": self.count} + + @serve.deployment + class RayServer2: + def __init__(self, handle): + self.handle = handle + + def __call__(self, *args): + return {"count": self.count} + + RayServer1.deploy() + for sync in [True, False]: + rs1_handle = RayServer1.get_handle(sync=sync) + RayServer2.deploy(rs1_handle) + + def test_sync_handle_in_thread(serve_instance): @serve.deployment def f():