Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve][Deployment Graph] Let .bind return ray DAGNode types and remove exposing DeploymentNode as public #24065

Merged
merged 12 commits into from
Apr 21, 2022
6 changes: 6 additions & 0 deletions python/ray/experimental/dag/class_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def _contains_input_node(self) -> bool:
return False

def __getattr__(self, method_name: str):
# User trying to call .bind() without a bind class method
if method_name == "bind" and "bind" not in dir(self._body):
Copy link
Member Author

@jiaodong jiaodong Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericl addressed, now we don't have the extra .bind() on @serve.deployment, consistent with Ray Core .bind now.

Only little nit i did here is to detect and surface the same DAGNode base exception message when user tries to Actor.bind().bind(), that previously we only surfaced type object 'Actor' has no attribute 'bind'

raise AttributeError(
f".bind() cannot be used again on {type(self)} "
f"(args: {self.get_args()}, kwargs: {self.get_kwargs()})."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why print args and kwargs here? Doesn't seem relevant and might be spammy.

)
# Raise an error if the method is invalid.
getattr(self._body, method_name)
call_node = _UnboundClassMethodNode(self, method_name)
Expand Down
6 changes: 0 additions & 6 deletions python/ray/experimental/dag/py_obj_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ def __init__(self):
from ray.serve.pipeline.deployment_node import DeploymentNode
from ray.serve.pipeline.deployment_method_node import DeploymentMethodNode
from ray.serve.pipeline.deployment_function_node import DeploymentFunctionNode
from ray.serve.deployment_graph import DeploymentNode as UserDeploymentNode
from ray.serve.deployment_graph import (
DeploymentFunctionNode as UserDeploymentFunctionNode,
)

self.dispatch_table[FunctionNode] = self._reduce_dag_node
self.dispatch_table[ClassNode] = self._reduce_dag_node
Expand All @@ -62,8 +58,6 @@ def __init__(self):
self.dispatch_table[DeploymentNode] = self._reduce_dag_node
self.dispatch_table[DeploymentMethodNode] = self._reduce_dag_node
self.dispatch_table[DeploymentFunctionNode] = self._reduce_dag_node
self.dispatch_table[UserDeploymentNode] = self._reduce_dag_node
self.dispatch_table[UserDeploymentFunctionNode] = self._reduce_dag_node
super().__init__(self._buf)

def find_nodes(self, obj: Any) -> List["DAGNode"]:
Expand Down
31 changes: 15 additions & 16 deletions python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
from ray.util.annotations import PublicAPI
import ray
from ray import cloudpickle
from ray.serve.deployment_graph import DeploymentNode, DeploymentFunctionNode
from ray.serve.deployment_graph import ClassNode, FunctionNode
from ray.serve.application import Application

logger = logging.getLogger(__file__)
Expand Down Expand Up @@ -1246,22 +1246,22 @@ def get_deployment_statuses() -> Dict[str, DeploymentStatusInfo]:

@PublicAPI(stability="alpha")
def run(
target: Union[DeploymentNode, DeploymentFunctionNode],
target: Union[ClassNode, FunctionNode],
_blocking: bool = True,
*,
host: str = DEFAULT_HTTP_HOST,
port: int = DEFAULT_HTTP_PORT,
) -> Optional[RayServeHandle]:
"""Run a Serve application and return a ServeHandle to the ingress.
Either a DeploymentNode, DeploymentFunctionNode, or a pre-built application
Either a ClassNode, FunctionNode, or a pre-built application
can be passed in. If a node is passed in, all of the deployments it depends
on will be deployed. If there is an ingress, its handle will be returned.
Args:
target (Union[DeploymentNode, DeploymentFunctionNode, Application]):
A user-built Serve Application or a DeploymentNode that acts as the
root node of DAG. By default DeploymentNode is the Driver
target (Union[ClassNode, FunctionNode, Application]):
A user-built Serve Application or a ClassNode that acts as the
root node of DAG. By default ClassNode is the Driver
deployment unless user provides a customized one.
host (str): The host passed into serve.start().
port (int): The port passed into serve.start().
Expand All @@ -1279,12 +1279,12 @@ def run(
if isinstance(target, Application):
deployments = list(target.deployments.values())
ingress = target.ingress
# Each DAG should always provide a valid Driver DeploymentNode
elif isinstance(target, DeploymentNode):
# Each DAG should always provide a valid Driver ClassNode
elif isinstance(target, ClassNode):
deployments = pipeline_build(target)
ingress = get_and_validate_ingress_deployment(deployments)
# Special case where user is doing single function serve.run(func.bind())
elif isinstance(target, DeploymentFunctionNode):
elif isinstance(target, FunctionNode):
deployments = pipeline_build(target)
ingress = get_and_validate_ingress_deployment(deployments)
if len(deployments) != 1:
Expand All @@ -1297,15 +1297,14 @@ def run(
elif isinstance(target, DAGNode):
raise ValueError(
"Invalid DAGNode type as entry to serve.run(), "
f"type: {type(target)}, accepted: DeploymentNode, "
"DeploymentFunctionNode please provide a driver class and bind it "
f"type: {type(target)}, accepted: ClassNode, "
"FunctionNode please provide a driver class and bind it "
"as entrypoint to your Serve DAG."
)
else:
raise TypeError(
"Expected a DeploymentNode, DeploymentFunctionNode, or "
"Application as target. Got unexpected type "
f'"{type(target)}" instead.'
"Expected a ClassNode, FunctionNode, or Application as target. "
f"Got unexpected type {type(target)} instead."
)

parameter_group = []
Expand All @@ -1332,10 +1331,10 @@ def run(
return ingress.get_handle()


def build(target: Union[DeploymentNode, DeploymentFunctionNode]) -> Application:
def build(target: Union[ClassNode, FunctionNode]) -> Application:
"""Builds a Serve application into a static application.
Takes in a DeploymentNode or DeploymentFunctionNode and converts it to a
Takes in a ClassNode or FunctionNode and converts it to a
Serve application consisting of one or more deployments. This is intended
to be used for production scenarios and deployed via the Serve REST API or
CLI, so there are some restrictions placed on the deployments:
Expand Down
12 changes: 6 additions & 6 deletions python/ray/serve/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
Tuple,
Union,
)

from ray.experimental.dag.class_node import ClassNode
from ray.experimental.dag.function_node import FunctionNode
from ray.serve.config import (
AutoscalingConfig,
DeploymentConfig,
)
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.deployment_graph import DeploymentNode, DeploymentFunctionNode
from ray.serve.utils import DEFAULT, get_deployment_import_path
from ray.util.annotations import PublicAPI
from ray.serve.schema import (
Expand Down Expand Up @@ -186,8 +186,8 @@ def __call__(self):
)

@PublicAPI(stability="alpha")
def bind(self, *args, **kwargs) -> Union[DeploymentNode, DeploymentFunctionNode]:
"""Bind the provided arguments and return a DeploymentNode.
def bind(self, *args, **kwargs) -> Union[ClassNode, FunctionNode]:
"""Bind the provided arguments and return a class or function node.
The returned bound deployment can be deployed or bound to other
deployments to create a deployment graph.
Expand All @@ -200,7 +200,7 @@ def bind(self, *args, **kwargs) -> Union[DeploymentNode, DeploymentFunctionNode]
schema_shell = deployment_to_schema(copied_self)

if inspect.isfunction(self._func_or_class):
return DeploymentFunctionNode(
return FunctionNode(
self._func_or_class,
args, # Used to bind and resolve DAG only, can take user input
kwargs, # Used to bind and resolve DAG only, can take user input
Expand All @@ -211,7 +211,7 @@ def bind(self, *args, **kwargs) -> Union[DeploymentNode, DeploymentFunctionNode]
},
)
else:
return DeploymentNode(
return ClassNode(
self._func_or_class,
args,
kwargs,
Expand Down
60 changes: 7 additions & 53 deletions python/ray/serve/deployment_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from ray.experimental.dag.class_node import ClassNode
from ray.experimental.dag.function_node import FunctionNode
from ray.experimental.dag import DAGNode
from ray.experimental.dag.class_node import ClassNode # noqa: F401
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these imports happen at deployment_graph level so from api.py we're just importing from serve module, rather than directly exposing experimental folder stuff.

Ideally we should have another effort to cleanup imports while eliminating circular imports.

from ray.experimental.dag.function_node import FunctionNode # noqa: F401
from ray.experimental.dag.input_node import InputNode # noqa: F401
from ray.experimental.dag import DAGNode # noqa: F401
from ray.util.annotations import PublicAPI


Expand All @@ -14,7 +15,9 @@ class RayServeDAGHandle:
"""

def __init__(self, dag_node_json: str) -> None:
from ray.serve.pipeline.json_serde import dagnode_from_json

self.dagnode_from_json = dagnode_from_json
self.dag_node_json = dag_node_json

# NOTE(simon): Making this lazy to avoid deserialization in controller for now
Expand All @@ -31,57 +34,8 @@ def __reduce__(self):
return RayServeDAGHandle._deserialize, (self.dag_node_json,)

def remote(self, *args, **kwargs):
from ray.serve.pipeline.json_serde import dagnode_from_json

if self.dag_node is None:
self.dag_node = json.loads(
self.dag_node_json, object_hook=dagnode_from_json
self.dag_node_json, object_hook=self.dagnode_from_json
)
return self.dag_node.execute(*args, **kwargs)


@PublicAPI(stability="alpha")
class DeploymentMethodNode(DAGNode):
"""Represents a method call on a bound deployment node.
These method calls can be composed into an optimized call DAG and passed
to a "driver" deployment that will orchestrate the calls at runtime.
This class cannot be called directly. Instead, when it is bound to a
deployment node, it will be resolved to a DeployedCallGraph at runtime.
"""

# TODO (jiaodong): Later unify and refactor this with pipeline node class
pass


@PublicAPI(stability="alpha")
class DeploymentNode(ClassNode):
"""Represents a deployment with its bound config options and arguments.
The bound deployment can be run using serve.run().
A bound deployment can be passed as an argument to other bound deployments
to build a deployment graph. When the graph is deployed, the
bound deployments passed into a constructor will be converted to
RayServeHandles that can be used to send requests.
Calling deployment.method.bind() will return a DeploymentMethodNode
that can be used to compose an optimized call graph.
"""

# TODO (jiaodong): Later unify and refactor this with pipeline node class
def bind(self, *args, **kwargs):
"""Bind the default __call__ method and return a DeploymentMethodNode"""
return self.__call__.bind(*args, **kwargs)


@PublicAPI(stability="alpha")
class DeploymentFunctionNode(FunctionNode):
"""Represents a serve.deployment decorated function from user.
It's the counterpart of DeploymentNode that represents function as body
instead of class.
"""

pass
1 change: 1 addition & 0 deletions python/ray/serve/pipeline/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List

from ray.experimental.dag.dag_node import DAGNode
from ray.serve.pipeline.generate import (
transform_ray_dag_to_serve_dag,
Expand Down
20 changes: 10 additions & 10 deletions python/ray/serve/tests/test_pipeline_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

import ray
from ray import serve
from ray.experimental.dag.input_node import InputNode
from ray.serve.application import Application
from ray.serve.api import build as build_app
from ray.serve.deployment_graph import DeploymentNode, RayServeDAGHandle
from ray.serve.deployment_graph import RayServeDAGHandle
from ray.serve.pipeline.api import build as pipeline_build
from ray.serve.deployment_graph import ClassNode, InputNode
from ray.serve.drivers import DAGDriver
import starlette.requests

Expand All @@ -21,9 +21,7 @@
NESTED_HANDLE_KEY = "nested_handle"


def maybe_build(
node: DeploymentNode, use_build: bool
) -> Union[Application, DeploymentNode]:
def maybe_build(node: ClassNode, use_build: bool) -> Union[Application, ClassNode]:
if use_build:
return Application.from_dict(build_app(node).to_dict())
else:
Expand Down Expand Up @@ -202,7 +200,7 @@ def test_multi_instantiation_class_deployment_in_init_args(serve_instance, use_b
m1 = Model.bind(2)
m2 = Model.bind(3)
combine = Combine.bind(m1, m2=m2)
combine_output = combine.bind(dag_input)
combine_output = combine.__call__.bind(dag_input)
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)

handle = serve.run(serve_dag)
Expand All @@ -215,7 +213,7 @@ def test_shared_deployment_handle(serve_instance, use_build):
with InputNode() as dag_input:
m = Model.bind(2)
combine = Combine.bind(m, m2=m)
combine_output = combine.bind(dag_input)
combine_output = combine.__call__.bind(dag_input)
serve_dag = DAGDriver.bind(combine_output, input_schema=json_resolver)

handle = serve.run(serve_dag)
Expand All @@ -229,7 +227,7 @@ def test_multi_instantiation_class_nested_deployment_arg_dag(serve_instance, use
m1 = Model.bind(2)
m2 = Model.bind(3)
combine = Combine.bind(m1, m2={NESTED_HANDLE_KEY: m2}, m2_nested=True)
output = combine.bind(dag_input)
output = combine.__call__.bind(dag_input)
serve_dag = DAGDriver.bind(output, input_schema=json_resolver)

handle = serve.run(serve_dag)
Expand Down Expand Up @@ -418,8 +416,10 @@ def ping(self):
return "hello"

with pytest.raises(AttributeError, match=r"\.bind\(\) cannot be used again on"):
# Special for serve: Actor.bind().bind() returns DeploymentMethodNode
_ = Actor.bind().bind().bind()
_ = Actor.bind().bind()

with pytest.raises(AttributeError, match=r"\.bind\(\) cannot be used again on"):
_ = Actor.bind().ping.bind().bind()

with pytest.raises(
AttributeError,
Expand Down