[Serve] Unified Controller API for Cross Language Client (#23004)
liuyang-my authored Apr 5, 2022
1 parent 434265e commit bdd3b9a
Showing 10 changed files with 333 additions and 34 deletions.
47 changes: 38 additions & 9 deletions python/ray/serve/api.py
@@ -54,6 +54,11 @@
from ray.serve.controller import ServeController
from ray.serve.deployment import Deployment
from ray.serve.exceptions import RayServeException
from ray.serve.generated.serve_pb2 import (
DeploymentRoute,
DeploymentRouteList,
DeploymentStatusInfoList,
)
from ray.experimental.dag import DAGNode
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.http_util import ASGIHTTPSender, make_fastapi_class_based_view
@@ -283,7 +288,7 @@ def _wait_for_deployments_shutdown(self, timeout_s: int = 60):
"""
start = time.time()
while time.time() - start < timeout_s:
statuses = ray.get(self._controller.get_deployment_statuses.remote())
statuses = self.get_deployment_statuses()
if len(statuses) == 0:
break
else:
@@ -308,7 +313,7 @@ def _wait_for_deployment_healthy(self, name: str, timeout_s: int = -1):
"""
start = time.time()
while time.time() - start < timeout_s or timeout_s < 0:
statuses = ray.get(self._controller.get_deployment_statuses.remote())
statuses = self.get_deployment_statuses()
try:
status = statuses[name]
except KeyError:
@@ -341,7 +346,7 @@ def _wait_for_deployment_deleted(self, name: str, timeout_s: int = 60):
"""
start = time.time()
while time.time() - start < timeout_s:
statuses = ray.get(self._controller.get_deployment_statuses.remote())
statuses = self.get_deployment_statuses()
if name not in statuses:
break
else:
@@ -435,15 +440,38 @@ def delete_deployments(self, names: Iterable[str], blocking: bool = True) -> None:

@_ensure_connected
def get_deployment_info(self, name: str) -> Tuple[DeploymentInfo, str]:
return ray.get(self._controller.get_deployment_info.remote(name))
deployment_route = DeploymentRoute.FromString(
ray.get(self._controller.get_deployment_info.remote(name))
)
return (
DeploymentInfo.from_proto(deployment_route.deployment_info),
deployment_route.route if deployment_route.route != "" else None,
)

@_ensure_connected
def list_deployments(self) -> Dict[str, Tuple[DeploymentInfo, str]]:
return ray.get(self._controller.list_deployments.remote())
deployment_route_list = DeploymentRouteList.FromString(
ray.get(self._controller.list_deployments.remote())
)
return {
deployment_route.deployment_info.name: (
DeploymentInfo.from_proto(deployment_route.deployment_info),
deployment_route.route if deployment_route.route != "" else None,
)
for deployment_route in deployment_route_list.deployment_routes
}

@_ensure_connected
def get_deployment_statuses(self) -> Dict[str, DeploymentStatusInfo]:
return ray.get(self._controller.get_deployment_statuses.remote())
proto = DeploymentStatusInfoList.FromString(
ray.get(self._controller.get_deployment_statuses.remote())
)
return {
deployment_status_info.name: DeploymentStatusInfo.from_proto(
deployment_status_info
)
for deployment_status_info in proto.deployment_status_infos
}

@_ensure_connected
def get_handle(
@@ -582,6 +610,9 @@ def get_deploy_args(
else:
raise TypeError("config must be a DeploymentConfig or a dictionary.")

deployment_config.version = version
deployment_config.prev_version = prev_version

if (
deployment_config.autoscaling_config is not None
and deployment_config.max_concurrent_queries
@@ -596,9 +627,7 @@
controller_deploy_args = {
"name": name,
"deployment_config_proto_bytes": deployment_config.to_proto_bytes(),
"replica_config": replica_config,
"version": version,
"prev_version": prev_version,
"replica_config_proto_bytes": replica_config.to_proto_bytes(),
"route_prefix": route_prefix,
"deployer_job_id": ray.get_runtime_context().job_id,
}
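The hunks above change the client so that it decodes raw protobuf bytes returned by the controller instead of receiving Python objects, which is what lets a non-Python (e.g. Java) client call the same controller API. Below is a minimal sketch of that bytes-level contract; it is not part of the diff, the field names follow the usages visible above, and the literal values are purely illustrative.

from ray.serve.generated.serve_pb2 import DeploymentStatusInfoList

# Controller side (conceptually): build the reply and return plain bytes.
status_list = DeploymentStatusInfoList()
info = status_list.deployment_status_infos.add()
info.name = "my_deployment"          # illustrative name
info.message = "replicas are ready"  # illustrative message
payload = status_list.SerializeToString()

# Client side: get_deployment_statuses() decodes the same bytes, so any
# language with the generated stubs can do the equivalent.
decoded = DeploymentStatusInfoList.FromString(payload)
statuses = {i.name: i.message for i in decoded.deployment_status_infos}
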
57 changes: 57 additions & 0 deletions python/ray/serve/common.py
@@ -6,6 +6,12 @@
from ray.actor import ActorHandle
from ray.serve.config import DeploymentConfig, ReplicaConfig
from ray.serve.autoscaling_policy import AutoscalingPolicy
from ray.serve.generated.serve_pb2 import (
DeploymentInfo as DeploymentInfoProto,
DeploymentStatusInfo as DeploymentStatusInfoProto,
DeploymentStatus as DeploymentStatusProto,
DeploymentLanguage,
)

EndpointTag = str
ReplicaTag = str
@@ -29,6 +35,16 @@ class DeploymentStatusInfo:
status: DeploymentStatus
message: str = ""

def to_proto(self):
return DeploymentStatusInfoProto(status=self.status, message=self.message)

@classmethod
def from_proto(cls, proto: DeploymentStatusInfoProto):
return cls(
status=DeploymentStatus(DeploymentStatusProto.Name(proto.status)),
message=proto.message,
)


class DeploymentInfo:
def __init__(
@@ -95,6 +111,47 @@ def actor_def(self):

return self._cached_actor_def

@classmethod
def from_proto(cls, proto: DeploymentInfoProto):
deployment_config = (
DeploymentConfig.from_proto(proto.deployment_config)
if proto.deployment_config
else None
)
data = {
"deployment_config": deployment_config,
"replica_config": ReplicaConfig.from_proto(
proto.replica_config,
deployment_config.deployment_language
if deployment_config
else DeploymentLanguage.PYTHON,
),
"start_time_ms": proto.start_time_ms,
"actor_name": proto.actor_name if proto.actor_name != "" else None,
"serialized_deployment_def": proto.serialized_deployment_def
if proto.serialized_deployment_def != b""
else None,
"version": proto.version if proto.version != "" else None,
"end_time_ms": proto.end_time_ms if proto.end_time_ms != 0 else None,
"deployer_job_id": ray.get_runtime_context().job_id,
}

return cls(**data)

def to_proto(self):
data = {
"start_time_ms": self.start_time_ms,
"actor_name": self.actor_name,
"serialized_deployment_def": self.serialized_deployment_def,
"version": self.version,
"end_time_ms": self.end_time_ms,
}
if self.deployment_config:
data["deployment_config"] = self.deployment_config.to_proto()
if self.replica_config:
data["replica_config"] = self.replica_config.to_proto()
return DeploymentInfoProto(**data)


@dataclass
class ReplicaName:
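For reference, a small sketch (not from the diff) of the round-trip the new DeploymentStatusInfo.to_proto/from_proto helpers are meant to support. It assumes DeploymentStatus exposes a HEALTHY member with a string value, as the from_proto conversion above implies.

from ray.serve.common import DeploymentStatus, DeploymentStatusInfo

original = DeploymentStatusInfo(
    status=DeploymentStatus.HEALTHY,  # assumed enum member; see from_proto above
    message="all replicas are up",
)
proto = original.to_proto()                        # generated serve_pb2 message
restored = DeploymentStatusInfo.from_proto(proto)  # back to the Python dataclass
print(restored.status, restored.message)
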
76 changes: 71 additions & 5 deletions python/ray/serve/config.py
@@ -1,4 +1,5 @@
import inspect
import json
import pickle
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -27,7 +28,9 @@
DeploymentConfig as DeploymentConfigProto,
DeploymentLanguage,
AutoscalingConfig as AutoscalingConfigProto,
ReplicaConfig as ReplicaConfigProto,
)
from ray.serve.utils import ServeEncoder


class AutoscalingConfig(BaseModel):
@@ -129,6 +132,9 @@ class DeploymentConfig(BaseModel):
# the deployment use.
deployment_language: Any = DeploymentLanguage.PYTHON

version: Optional[str] = None
prev_version: Optional[str] = None

class Config:
validate_assignment = True
extra = "forbid"
@@ -144,19 +150,21 @@ def set_max_queries_by_mode(cls, v, values): # noqa 805
raise ValueError("max_concurrent_queries must be >= 0")
return v

def to_proto_bytes(self):
def to_proto(self):
data = self.dict()
if data.get("user_config"):
data["user_config"] = pickle.dumps(data["user_config"])
if data.get("autoscaling_config"):
data["autoscaling_config"] = AutoscalingConfigProto(
**data["autoscaling_config"]
)
return DeploymentConfigProto(**data).SerializeToString()
return DeploymentConfigProto(**data)

def to_proto_bytes(self):
return self.to_proto().SerializeToString()

@classmethod
def from_proto_bytes(cls, proto_bytes: bytes):
proto = DeploymentConfigProto.FromString(proto_bytes)
def from_proto(cls, proto: DeploymentConfigProto):
data = MessageToDict(
proto,
including_default_value_fields=True,
@@ -170,9 +178,19 @@ def from_proto_bytes(cls, proto_bytes: bytes):
data["user_config"] = None
if "autoscaling_config" in data:
data["autoscaling_config"] = AutoscalingConfig(**data["autoscaling_config"])

if "prev_version" in data:
if data["prev_version"] == "":
data["prev_version"] = None
if "version" in data:
if data["version"] == "":
data["version"] = None
return cls(**data)

@classmethod
def from_proto_bytes(cls, proto_bytes: bytes):
proto = DeploymentConfigProto.FromString(proto_bytes)
return cls.from_proto(proto)


class ReplicaConfig:
def __init__(
@@ -286,6 +304,54 @@ def _validate(self):
raise TypeError("resources in ray_actor_options must be a dictionary.")
self.resource_dict.update(custom_resources)

@classmethod
def from_proto(
cls, proto: ReplicaConfigProto, deployment_language: DeploymentLanguage
):
deployment_def = None
if proto.serialized_deployment_def != b"":
if deployment_language == DeploymentLanguage.PYTHON:
deployment_def = cloudpickle.loads(proto.serialized_deployment_def)
else:
# TODO use messagepack
deployment_def = cloudpickle.loads(proto.serialized_deployment_def)

init_args = pickle.loads(proto.init_args) if proto.init_args != b"" else None
init_kwargs = (
pickle.loads(proto.init_kwargs) if proto.init_kwargs != b"" else None
)
ray_actor_options = (
json.loads(proto.ray_actor_options)
if proto.ray_actor_options != ""
else None
)

return ReplicaConfig(deployment_def, init_args, init_kwargs, ray_actor_options)

@classmethod
def from_proto_bytes(
cls, proto_bytes: bytes, deployment_language: DeploymentLanguage
):
proto = ReplicaConfigProto.FromString(proto_bytes)
return cls.from_proto(proto, deployment_language)

def to_proto(self):
data = {
"serialized_deployment_def": self.serialized_deployment_def,
}
if self.init_args:
data["init_args"] = pickle.dumps(self.init_args)
if self.init_kwargs:
data["init_kwargs"] = pickle.dumps(self.init_kwargs)
if self.ray_actor_options:
data["ray_actor_options"] = json.dumps(
self.ray_actor_options, cls=ServeEncoder
)
return ReplicaConfigProto(**data)

def to_proto_bytes(self):
return self.to_proto().SerializeToString()


class DeploymentMode(str, Enum):
NoServer = "NoServer"
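Tying the config changes back to get_deploy_args: a minimal sketch, not part of the diff, of how a DeploymentConfig and ReplicaConfig now travel to the controller as protobuf bytes and are rebuilt on the other side. The hello function and the num_replicas/version arguments are illustrative assumptions; the to_proto_bytes/from_proto_bytes calls follow the code above.

from ray.serve.config import DeploymentConfig, ReplicaConfig
from ray.serve.generated.serve_pb2 import DeploymentLanguage

def hello(request):
    return "hello"

deployment_config = DeploymentConfig(num_replicas=2, version="v1")
replica_config = ReplicaConfig(hello)

# Sent across the actor boundary as plain bytes (see get_deploy_args above).
config_bytes = deployment_config.to_proto_bytes()
replica_bytes = replica_config.to_proto_bytes()

# Restored on the controller (or by a non-Python client with the same stubs).
restored_config = DeploymentConfig.from_proto_bytes(config_bytes)
restored_replica = ReplicaConfig.from_proto_bytes(
    replica_bytes, DeploymentLanguage.PYTHON
)
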
(Diffs for the remaining seven changed files are not shown here.)
