Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] Support kwargs to deployment constructor #19023

Merged
merged 8 commits into from
Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def _wait_for_goal(self,
def deploy(self,
name: str,
backend_def: Union[Callable, Type[Callable], str],
*init_args: Any,
init_args: Tuple[Any],
init_kwargs: Dict[Any],
ray_actor_options: Optional[Dict] = None,
config: Optional[Union[BackendConfig, Dict[str, Any]]] = None,
version: Optional[str] = None,
Expand All @@ -212,7 +213,10 @@ def deploy(self,
del ray_actor_options["runtime_env"]["working_dir"]

replica_config = ReplicaConfig(
backend_def, *init_args, ray_actor_options=ray_actor_options)
backend_def,
init_args,
init_kwargs,
ray_actor_options=ray_actor_options)

if isinstance(config, dict):
backend_config = BackendConfig.parse_obj(config)
Expand Down Expand Up @@ -601,6 +605,7 @@ def __init__(self,
version: Optional[str] = None,
prev_version: Optional[str] = None,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Tuple[Any]] = None,
route_prefix: Optional[str] = None,
ray_actor_options: Optional[Dict] = None,
_internal=False) -> None:
Expand All @@ -626,6 +631,8 @@ def __init__(self,
raise TypeError("prev_version must be a string.")
if not (init_args is None or isinstance(init_args, tuple)):
raise TypeError("init_args must be a tuple.")
if not (init_kwargs is None or isinstance(init_kwargs, dict)):
raise TypeError("init_kwargs must be a dict.")
if route_prefix is not None:
if not isinstance(route_prefix, str):
raise TypeError("route_prefix must be a string.")
Expand All @@ -642,13 +649,16 @@ def __init__(self,

if init_args is None:
init_args = ()
if init_kwargs is None:
init_kwargs = {}

self._func_or_class = func_or_class
self._name = name
self._version = version
self._prev_version = prev_version
self._config = config
self._init_args = init_args
self._init_kwargs = init_kwargs
self._route_prefix = route_prefix
self._ray_actor_options = ray_actor_options

Expand Down Expand Up @@ -706,7 +716,12 @@ def ray_actor_options(self) -> Optional[Dict]:

@property
def init_args(self) -> Tuple[Any, ...]:
    """Positional args passed to the underlying class's constructor.

    Returns:
        The tuple stored on this object as ``_init_args``.
    """
    return self._init_args

@property
def init_kwargs(self) -> Dict[str, Any]:
    """Keyword args passed to the underlying class's constructor.

    Returns:
        The dict stored on this object as ``_init_kwargs``.
    """
    # Must read the kwargs attribute; returning ``_init_args`` here would
    # hand callers the positional tuple instead of the keyword dict.
    return self._init_kwargs

@property
Expand All @@ -720,20 +735,25 @@ def __call__(self):
"Use `deployment.deploy() instead.`")

@PublicAPI
def deploy(self, *init_args, _blocking=True):
def deploy(self, *init_args, _blocking=True, **init_kwargs):
Copy link
Contributor

@simon-mo simon-mo Sep 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should blocking move to options?

"""Deploy or update this deployment.

Args:
init_args (optional): args to pass to the class __init__
method. Not valid if this deployment wraps a function.
init_kwargs (optional): kwargs to pass to the class __init__
method. Not valid if this deployment wraps a function.
"""
if len(init_args) == 0 and self._init_args is not None:
init_args = self._init_args
if len(init_kwargs) == 0 and self._init_kwargs is not None:
init_kwargs = self._init_kwargs

return _get_global_client().deploy(
self._name,
self._func_or_class,
*init_args,
init_args,
init_kwargs,
ray_actor_options=self._ray_actor_options,
config=self._config,
version=self._version,
Expand Down Expand Up @@ -772,6 +792,7 @@ def options(
version: Optional[str] = None,
prev_version: Optional[str] = None,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any]] = None,
route_prefix: Optional[str] = None,
num_replicas: Optional[int] = None,
ray_actor_options: Optional[Dict] = None,
Expand Down Expand Up @@ -803,6 +824,9 @@ def options(
if init_args is None:
init_args = self._init_args

if init_kwargs is None:
init_kwargs = self._init_kwargs

if route_prefix is None:
if self._route_prefix == f"/{self._name}":
route_prefix = None
Expand All @@ -819,6 +843,7 @@ def options(
version=version,
prev_version=prev_version,
init_args=init_args,
init_kwargs=init_kwargs,
route_prefix=route_prefix,
ray_actor_options=ray_actor_options,
_internal=True,
Expand All @@ -830,6 +855,7 @@ def __eq__(self, other):
self._version == other._version,
self._config == other._config,
self._init_args == other._init_args,
self._init_kwargs == other._init_kwargs,
self._route_prefix == other._route_prefix,
self._ray_actor_options == self._ray_actor_options,
])
Expand Down Expand Up @@ -858,6 +884,7 @@ def deployment(name: Optional[str] = None,
prev_version: Optional[str] = None,
num_replicas: Optional[int] = None,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any]] = None,
ray_actor_options: Optional[Dict] = None,
user_config: Optional[Any] = None,
max_concurrent_queries: Optional[int] = None,
Expand All @@ -874,6 +901,7 @@ def deployment(
prev_version: Optional[str] = None,
num_replicas: Optional[int] = None,
init_args: Optional[Tuple[Any]] = None,
init_kwargs: Optional[Dict[Any]] = None,
route_prefix: Optional[str] = None,
ray_actor_options: Optional[Dict] = None,
user_config: Optional[Any] = None,
Expand All @@ -897,7 +925,10 @@ def deployment(
not check the existing deployment's version.
num_replicas (Optional[int]): The number of processes to start up that
will handle requests to this deployment. Defaults to 1.
init_args (Optional[Tuple]): Arguments to be passed to the class
init_args (Optional[Tuple]): Positional args to be passed to the class
constructor when starting up deployment replicas. These can also be
passed when you call `.deploy()` on the returned Deployment.
init_kwargs (Optional[Dict]): Keyword args to be passed to the class
constructor when starting up deployment replicas. These can also be
passed when you call `.deploy()` on the returned Deployment.
route_prefix (Optional[str]): Requests to paths under this HTTP path
Expand Down Expand Up @@ -955,6 +986,7 @@ def decorator(_func_or_class):
version=version,
prev_version=prev_version,
init_args=init_args,
init_kwargs=init_kwargs,
route_prefix=route_prefix,
ray_actor_options=ray_actor_options,
_internal=True,
Expand Down Expand Up @@ -996,6 +1028,7 @@ def get_deployment(name: str) -> Deployment:
backend_info.backend_config,
version=backend_info.version,
init_args=backend_info.replica_config.init_args,
init_kwargs=backend_info.replica_config.init_kwargs,
route_prefix=route_prefix,
ray_actor_options=backend_info.replica_config.ray_actor_options,
_internal=True,
Expand All @@ -1019,6 +1052,7 @@ def list_deployments() -> Dict[str, Deployment]:
backend_info.backend_config,
version=backend_info.version,
init_args=backend_info.replica_config.init_args,
init_kwargs=backend_info.replica_config.init_kwargs,
route_prefix=route_prefix,
ray_actor_options=backend_info.replica_config.ray_actor_options,
_internal=True,
Expand Down
1 change: 1 addition & 0 deletions python/ray/serve/backend_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def start(self, backend_info: BackendInfo, version: BackendVersion):
**backend_info.replica_config.ray_actor_options).remote(
self.backend_tag, self.replica_tag,
backend_info.replica_config.init_args,
backend_info.replica_config.init_kwargs,
backend_info.backend_config.to_proto_bytes(), version,
self._controller_name, self._detached)

Expand Down
5 changes: 3 additions & 2 deletions python/ray/serve/backend_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def create_backend_replica(name: str, serialized_backend_def: bytes):
# TODO(architkulkarni): Add type hints after upgrading cloudpickle
class RayServeWrappedReplica(object):
async def __init__(self, backend_tag, replica_tag, init_args,
backend_config_proto_bytes: bytes,
init_kwargs, backend_config_proto_bytes: bytes,
version: BackendVersion, controller_name: str,
detached: bool):
backend = cloudpickle.loads(serialized_backend_def)
Expand Down Expand Up @@ -72,7 +72,8 @@ async def __init__(self, backend_tag, replica_tag, init_args,
# This allows backends to define an async __init__ method
# (required for FastAPI backend definition).
_callable = backend.__new__(backend)
await sync_to_async(_callable.__init__)(*init_args)
await sync_to_async(_callable.__init__)(*init_args,
**init_kwargs)
# Setting the context again to update the servable_object.
ray.serve.api._set_internal_replica_context(
backend_tag,
Expand Down
17 changes: 13 additions & 4 deletions python/ray/serve/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,11 @@ def from_proto_bytes(cls, proto_bytes: bytes):


class ReplicaConfig:
def __init__(self, backend_def, *init_args, ray_actor_options=None):
def __init__(self,
backend_def,
init_args,
init_kwargs,
ray_actor_options=None):
# Validate that backend_def is an import path, function, or class.
if isinstance(backend_def, str):
self.func_or_class_name = backend_def
Expand All @@ -131,6 +135,9 @@ def __init__(self, backend_def, *init_args, ray_actor_options=None):
if len(init_args) != 0:
raise ValueError(
"init_args not supported for function backend.")
if len(init_kwargs) != 0:
raise ValueError(
"init_kwargs not supported for function backend.")
elif inspect.isclass(backend_def):
self.func_or_class_name = backend_def.__name__
else:
Expand All @@ -140,6 +147,7 @@ def __init__(self, backend_def, *init_args, ray_actor_options=None):

self.serialized_backend_def = cloudpickle.dumps(backend_def)
self.init_args = init_args
self.init_kwargs = init_kwargs
if ray_actor_options is None:
self.ray_actor_options = {}
else:
Expand All @@ -158,12 +166,13 @@ def _validate(self):
raise TypeError("ray_actor_options must be a dictionary.")
elif "lifetime" in self.ray_actor_options:
raise ValueError(
"Specifying lifetime in init_args is not allowed.")
"Specifying lifetime in ray_actor_options is not allowed.")
elif "name" in self.ray_actor_options:
raise ValueError("Specifying name in init_args is not allowed.")
raise ValueError(
"Specifying name in ray_actor_options is not allowed.")
elif "max_restarts" in self.ray_actor_options:
raise ValueError("Specifying max_restarts in "
"init_args is not allowed.")
"ray_actor_options is not allowed.")
else:
# Ray defaults to zero CPUs for placement, we default to one here.
if "num_cpus" not in self.ray_actor_options:
Expand Down
53 changes: 53 additions & 0 deletions python/ray/serve/tests/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,59 @@ def check(*args):
check(10, 11, 12)


def test_init_kwargs(serve_instance):
    """Check that constructor kwargs reach the deployment's replicas.

    Covers kwargs supplied through the decorator, through `.deploy()`,
    and through `.options()`, plus how they interact with `version`.
    """
    # init_kwargs must be a dict; any other type is rejected up front.
    with pytest.raises(TypeError):

        @serve.deployment(init_kwargs=[1, 2, 3])
        class BadInitArgs:
            pass

    @serve.deployment(init_kwargs={"a": 1, "b": 2})
    class D:
        # A bare `*` directly before `**kwargs` is a syntax error, so the
        # constructor simply collects all keyword arguments.
        def __init__(self, **kwargs):
            self._kwargs = kwargs

        def get_args(self, *args):
            return self._kwargs

    D.deploy()
    handle = D.get_handle()

    def check(kwargs):
        assert ray.get(handle.get_args.remote()) == kwargs

    # Basic sanity check.
    check({"a": 1, "b": 2})

    # Check passing args to `.deploy()`.
    D.deploy(a=3, b=4)
    check({"a": 3, "b": 4})

    # Passing args to `.deploy()` shouldn't override those passed in decorator.
    D.deploy()
    check({"a": 1, "b": 2})

    # Check setting with `.options()`.
    new_D = D.options(init_kwargs={"c": 8, "d": 10})
    new_D.deploy()
    check({"c": 8, "d": 10})

    # Should not have changed old deployment object.
    D.deploy()
    check({"a": 1, "b": 2})

    # Check that args are only updated on version change.
    D.options(version="1").deploy()
    check({"a": 1, "b": 2})

    D.options(version="1").deploy(c=10, d=11)
    check({"a": 1, "b": 2})

    D.options(version="2").deploy(c=10, d=11)
    check({"c": 10, "d": 11})


def test_input_validation():
name = "test"

Expand Down
30 changes: 30 additions & 0 deletions python/ray/serve/tests/test_get_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,36 @@ def __call__(self, *arg):
assert pid3 != pid2


def test_init_kwargs(serve_instance):
    """Check that init kwargs survive a `serve.get_deployment()` round trip.

    Deploys with a keyword argument, re-fetches the deployment by name,
    and checks that redeploying keeps or replaces the stored kwargs.
    """
    name = "test"

    @serve.deployment(name=name)
    class D:
        # The default must be a literal; `val=val` would raise NameError
        # when the class body is evaluated.
        def __init__(self, *, val=None):
            self._val = val

        def __call__(self, *arg):
            return self._val, os.getpid()

    D.deploy(val="1")
    val1, pid1 = ray.get(D.get_handle().remote())
    assert val1 == "1"

    del D

    # Redeploying the fetched deployment with no kwargs reuses the stored
    # ones; the test expects a new replica process.
    D2 = serve.get_deployment(name=name)
    D2.deploy()
    val2, pid2 = ray.get(D2.get_handle().remote())
    assert val2 == "1"
    assert pid2 != pid1

    # Explicit kwargs passed to `.deploy()` replace the stored ones.
    D2 = serve.get_deployment(name=name)
    D2.deploy(val="2")
    val3, pid3 = ray.get(D2.get_handle().remote())
    assert val3 == "2"
    assert pid3 != pid2


def test_scale_replicas(serve_instance):
name = "test"

Expand Down