Commit 11c8de2 (1 parent: 7f466b0)
Showing 7 changed files with 455 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Experimental Direct Ingress | ||
|
||
In Ray 2.1, Serve provides an alpha version of [gRPC](https://grpc.io/) ingress. | ||
|
||
With the gRPC protocol, you get: | ||
|
||
* A standardized inference request/response schema between client and server. | ||
* A higher-performance endpoint than the HTTP protocol. | ||
|
||
In this section, you will learn how to: | ||
|
||
* use Serve's built-in gRPC schema to receive client traffic | ||
* bring your own gRPC schema into your Serve application | ||
|
||
## Use Serve's schema | ||
|
||
Serve provides a simple gRPC schema for machine learning inference workloads. It is deliberately kept simple, and you are encouraged to adapt it to your own needs. | ||
``` | ||
message PredictRequest { | ||
  map<string, bytes> input = 2; | ||
} | ||
message PredictResponse { | ||
  bytes prediction = 1; | ||
} | ||
service PredictAPIsService { | ||
  rpc Predict(PredictRequest) returns (PredictResponse); | ||
} | ||
``` | ||
|
||
Take a look at the following code samples for using `DefaultgRPCDriver` in Ray Serve. | ||
|
||
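To implement the server, use `ray.serve.drivers.DefaultgRPCDriver` as the ingress driver: either bind it to your deployment graph, as the sample below does, or inherit from it in your own driver class. | ||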
```{literalinclude} ../serve/doc_code/direct_ingress.py | ||
:start-after: __begin_server__ | ||
:end-before: __end_server__ | ||
:language: python | ||
``` | ||
|
||
Client: | ||
You can use Serve's built-in gRPC client to send queries to the model. | ||
|
||
```{literalinclude} ../serve/doc_code/direct_ingress.py | ||
:start-after: __begin_client__ | ||
:end-before: __end_client__ | ||
:language: python | ||
``` | ||
|
||
:::{note} | ||
* `input` is a dictionary matching the `map<string, bytes>` field of the schema described above. | ||
* User input data must be serialized to `bytes` before it is placed in `input`. | ||
* The response is of type `bytes`, so the user code is responsible for serializing the output into bytes. | ||
* By default, the gRPC port is 9000. You can change it by passing a port number when calling the `DefaultgRPCDriver` bind function. | ||
* If the serialization/deserialization cost is high and unnecessary, you can bring your own schema instead. See the [Bring your own schema](bring-your-own-schema) section. | ||
* Scaling configuration for your business code is no different with gRPC: set the scaling/autoscaling config inside the `serve.deployment` decorator. | ||
::: | ||
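
The serialization notes above can be sketched without a running Serve instance. The helper names below (`encode_inputs`, `decode_inputs`, `encode_prediction`, `decode_prediction`) are hypothetical and not part of Ray; they only mirror the conventions of the sample in this commit, which sends amounts as UTF-8 digit strings and returns the price packed with the `struct` format `"f"`.

```python
import struct
from typing import Dict


def encode_inputs(amounts: Dict[str, int]) -> Dict[str, bytes]:
    """Client side: turn integer amounts into the bytes values of `input`."""
    return {name: str(amount).encode("utf-8") for name, amount in amounts.items()}


def decode_inputs(raw: Dict[str, bytes]) -> Dict[str, int]:
    """Server side: parse the bytes values back into integers."""
    return {name: int(value) for name, value in raw.items()}


def encode_prediction(total: float) -> bytes:
    """Server side: pack the result as a 4-byte single-precision float."""
    return struct.pack("f", total)


def decode_prediction(payload: bytes) -> float:
    """Client side: struct.unpack returns a 1-tuple, so unwrap it."""
    (total,) = struct.unpack("f", payload)
    return total


wire_inputs = encode_inputs({"ORANGE": 10, "APPLE": 3})
amounts = decode_inputs(wire_inputs)
total = decode_prediction(encode_prediction(amounts["ORANGE"] * 2.0 + amounts["APPLE"] * 3.0))
```

Because `"f"` is a 4-byte single-precision float, totals that are not exactly representable in 32 bits come back slightly rounded; the `"d"` format would trade four extra bytes for double precision.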
|
||
### Client schema code generation | ||
You can use the client by importing it from the `ray` Python package. Alternatively, you can copy [Serve's protobuf file](https://github.com/ray-project/ray/blob/e16f49b327bbc1c18e8fc5d0ac4fa8c2f1144412/src/ray/protobuf/serve.proto#L214-L225) and generate the gRPC client yourself. | ||
|
||
* Install the gRPC code generation tools | ||
``` | ||
pip install grpcio-tools | ||
``` | ||
|
||
* Generate gRPC code based on the schema | ||
``` | ||
python -m grpc_tools.protoc --proto_path=src/ray/protobuf/ --python_out=. --grpc_python_out=. src/ray/protobuf/serve.proto | ||
``` | ||
After the two steps above, you should have `serve_pb2.py` and `serve_pb2_grpc.py` files generated. | ||
|
||
(bring-your-own-schema)= | ||
|
||
## Bring your own schema | ||
|
||
If you have a customized schema to use, Serve also supports it! | ||
|
||
Assume you have the following customized schema and have generated the corresponding gRPC code: | ||
|
||
|
||
``` | ||
message PingRequest { | ||
  bool no_reply = 1; | ||
} | ||
message PingReply { | ||
} | ||
message PingTimeoutRequest {} | ||
message PingTimeoutReply {} | ||
service TestService { | ||
  rpc Ping(PingRequest) returns (PingReply); | ||
  rpc PingTimeout(PingTimeoutRequest) returns (PingTimeoutReply); | ||
} | ||
``` | ||
|
||
After the code is generated, implement the business logic for the gRPC server by creating a subclass of the generated `TestServiceServicer`. Two extra steps then adopt your schema into Ray Serve: | ||
|
||
* Inherit `ray.serve.drivers.gRPCIngress` in your implementation class. | ||
* Add the `@serve.deployment(is_driver_deployment=True)` decorator. | ||
|
||
Server: | ||
```{literalinclude} ../serve/doc_code/direct_ingress_with_customized_schema.py | ||
:start-after: __begin_server__ | ||
:end-before: __end_server__ | ||
:language: python | ||
``` | ||
|
||
Client: | ||
You can call the service directly with the generated client code: | ||
```{literalinclude} ../serve/doc_code/direct_ingress_with_customized_schema.py | ||
:start-after: __begin_client__ | ||
:end-before: __end_client__ | ||
:language: python | ||
``` | ||
|
||
:::{note} | ||
* `is_driver_deployment` (an experimental flag) marks the class as a driver. Serve makes sure the driver deployment runs one replica per node. | ||
* `gRPCIngress` starts the gRPC server. Your driver class needs to inherit from it. | ||
::: |
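
The server sample lists the generated servicer first and `gRPCIngress` second, then calls `super().__init__()`. The sketch below uses stand-in classes (`FakeServicer` and `FakeIngress` are hypothetical, not Ray or gRPC classes, and the `port` parameter is an assumption) to illustrate why that call still reaches the ingress base class: a generated servicer defines no `__init__`, so Python's method resolution order forwards the call to the next base.

```python
class FakeServicer:
    """Stand-in for a generated servicer: RPC methods only, no __init__."""

    def Ping(self, request, context):
        raise NotImplementedError("Method not implemented!")


class FakeIngress:
    """Stand-in for an ingress base class that owns server setup (assumed)."""

    def __init__(self, port: int = 9000):
        self.port = port


class MyDriver(FakeServicer, FakeIngress):
    def __init__(self):
        # FakeServicer has no __init__, so the MRO forwards this call
        # to FakeIngress.__init__, which records the port.
        super().__init__()

    def Ping(self, request, context):
        # Override the servicer's placeholder with real business logic.
        return "pong"


driver = MyDriver()
mro_names = [cls.__name__ for cls in MyDriver.__mro__]
```

Listing the servicer first also means its RPC method names take precedence in lookups, while the subclass's own overrides win over both bases.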
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# flake8: noqa | ||
|
||
# __begin_server__ | ||
import ray | ||
from ray import serve | ||
from ray.serve.drivers import DefaultgRPCDriver | ||
from ray.serve.handle import RayServeDeploymentHandle | ||
from ray.serve.deployment_graph import InputNode | ||
from typing import Dict | ||
import struct | ||
|
||
|
||
@serve.deployment | ||
class FruitMarket: | ||
    def __init__( | ||
        self, | ||
        orange_stand: RayServeDeploymentHandle, | ||
        apple_stand: RayServeDeploymentHandle, | ||
    ): | ||
        # Map each fruit name to the handle of the stand that prices it. | ||
        self.directory = { | ||
            "ORANGE": orange_stand, | ||
            "APPLE": apple_stand, | ||
        } | ||
|
||
    async def check_price(self, inputs: Dict[str, bytes]) -> bytearray: | ||
        costs = 0 | ||
        for fruit, amount in inputs.items(): | ||
            if fruit not in self.directory: | ||
                # Unknown fruit: bail out without a price. | ||
                return | ||
            fruit_stand = self.directory[fruit] | ||
            # `amount` arrives as bytes such as b"10"; int() parses it. | ||
            ref: ray.ObjectRef = await fruit_stand.remote(int(amount)) | ||
            result = await ref | ||
            costs += result | ||
        # Pack the total as a 4-byte float for the `bytes` prediction field. | ||
        return bytearray(struct.pack("f", costs)) | ||
|
||
|
||
@serve.deployment | ||
class OrangeStand: | ||
    def __init__(self): | ||
        self.price = 2.0 | ||
|
||
    def __call__(self, num_oranges: int): | ||
        return num_oranges * self.price | ||
|
||
|
||
@serve.deployment | ||
class AppleStand: | ||
    def __init__(self): | ||
        self.price = 3.0 | ||
|
||
    def __call__(self, num_apples: int): | ||
        return num_apples * self.price | ||
|
||
|
||
with InputNode() as input: | ||
    orange_stand = OrangeStand.bind() | ||
    apple_stand = AppleStand.bind() | ||
    fruit_market = FruitMarket.bind(orange_stand, apple_stand) | ||
    my_deployment = DefaultgRPCDriver.bind(fruit_market.check_price.bind(input)) | ||
|
||
serve.run(my_deployment) | ||
# __end_server__ | ||
|
||
# __begin_client__ | ||
import grpc | ||
from ray.serve.generated import serve_pb2, serve_pb2_grpc | ||
import asyncio | ||
import struct | ||
|
||
|
||
async def send_request(): | ||
    async with grpc.aio.insecure_channel("localhost:9000") as channel: | ||
        stub = serve_pb2_grpc.PredictAPIsServiceStub(channel) | ||
        response = await stub.Predict( | ||
            serve_pb2.PredictRequest( | ||
                input={"ORANGE": bytes("10", "utf-8"), "APPLE": bytes("3", "utf-8")} | ||
            ) | ||
        ) | ||
        return response | ||
|
||
|
||
async def main(): | ||
    resp = await send_request() | ||
    # struct.unpack returns a 1-tuple holding the total price. | ||
    print(struct.unpack("f", resp.prediction)) | ||
|
||
|
||
|
||
asyncio.run(main()) | ||
|
||
# __end_client__ |
doc/source/serve/doc_code/direct_ingress_with_customized_schema.py (32 changes: 32 additions & 0 deletions)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# flake8: noqa | ||
|
||
# __begin_server__ | ||
from ray import serve | ||
from ray.serve.drivers import gRPCIngress | ||
import test_service_pb2_grpc, test_service_pb2 | ||
|
||
|
||
@serve.deployment(is_driver_deployment=True) | ||
class MyDriver(test_service_pb2_grpc.TestServiceServicer, gRPCIngress): | ||
    def __init__(self): | ||
        super().__init__() | ||
|
||
    async def Ping(self, request, context): | ||
        # play with your dag and then reply | ||
        return test_service_pb2.PingReply() | ||
|
||
|
||
my_deployment = MyDriver.bind() | ||
|
||
serve.run(my_deployment) | ||
# __end_server__ | ||
|
||
|
||
# __begin_client__ | ||
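import asyncio | ||
import grpc | ||
import test_service_pb2_grpc, test_service_pb2 | ||
|
||
async def send_request(): | ||
    # grpc.aio calls are awaitable, so the request must run inside a coroutine. | ||
    async with grpc.aio.insecure_channel("localhost:9000") as channel: | ||
        stub = test_service_pb2_grpc.TestServiceStub(channel) | ||
        return await stub.Ping(test_service_pb2.PingRequest()) | ||
|
||
response = asyncio.run(send_request()) | ||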
# __end_client__ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# flake8: noqa | ||
|
||
# -*- coding: utf-8 -*- | ||
# Generated by the protocol buffer compiler. DO NOT EDIT! | ||
# source: src/ray/protobuf/test_service.proto | ||
"""Generated protocol buffer code.""" | ||
from google.protobuf import descriptor as _descriptor | ||
from google.protobuf import descriptor_pool as _descriptor_pool | ||
from google.protobuf import message as _message | ||
from google.protobuf import reflection as _reflection | ||
from google.protobuf import symbol_database as _symbol_database | ||
|
||
# @@protoc_insertion_point(imports) | ||
|
||
_sym_db = _symbol_database.Default() | ||
|
||
|
||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( | ||
b'\n#src/ray/protobuf/test_service.proto\x12\x07ray.rpc"(\n\x0bPingRequest\x12\x19\n\x08no_reply\x18\x01 \x01(\x08R\x07noReply"\x0b\n\tPingReply"\x14\n\x12PingTimeoutRequest"\x12\n\x10PingTimeoutReply2\x86\x01\n\x0bTestService\x12\x30\n\x04Ping\x12\x14.ray.rpc.PingRequest\x1a\x12.ray.rpc.PingReply\x12\x45\n\x0bPingTimeout\x12\x1b.ray.rpc.PingTimeoutRequest\x1a\x19.ray.rpc.PingTimeoutReplyb\x06proto3' | ||
) | ||
|
||
|
||
_PINGREQUEST = DESCRIPTOR.message_types_by_name["PingRequest"] | ||
_PINGREPLY = DESCRIPTOR.message_types_by_name["PingReply"] | ||
_PINGTIMEOUTREQUEST = DESCRIPTOR.message_types_by_name["PingTimeoutRequest"] | ||
_PINGTIMEOUTREPLY = DESCRIPTOR.message_types_by_name["PingTimeoutReply"] | ||
PingRequest = _reflection.GeneratedProtocolMessageType( | ||
"PingRequest", | ||
(_message.Message,), | ||
{ | ||
"DESCRIPTOR": _PINGREQUEST, | ||
"__module__": "src.ray.protobuf.test_service_pb2" | ||
# @@protoc_insertion_point(class_scope:ray.rpc.PingRequest) | ||
}, | ||
) | ||
_sym_db.RegisterMessage(PingRequest) | ||
|
||
PingReply = _reflection.GeneratedProtocolMessageType( | ||
"PingReply", | ||
(_message.Message,), | ||
{ | ||
"DESCRIPTOR": _PINGREPLY, | ||
"__module__": "src.ray.protobuf.test_service_pb2" | ||
# @@protoc_insertion_point(class_scope:ray.rpc.PingReply) | ||
}, | ||
) | ||
_sym_db.RegisterMessage(PingReply) | ||
|
||
PingTimeoutRequest = _reflection.GeneratedProtocolMessageType( | ||
"PingTimeoutRequest", | ||
(_message.Message,), | ||
{ | ||
"DESCRIPTOR": _PINGTIMEOUTREQUEST, | ||
"__module__": "src.ray.protobuf.test_service_pb2" | ||
# @@protoc_insertion_point(class_scope:ray.rpc.PingTimeoutRequest) | ||
}, | ||
) | ||
_sym_db.RegisterMessage(PingTimeoutRequest) | ||
|
||
PingTimeoutReply = _reflection.GeneratedProtocolMessageType( | ||
"PingTimeoutReply", | ||
(_message.Message,), | ||
{ | ||
"DESCRIPTOR": _PINGTIMEOUTREPLY, | ||
"__module__": "src.ray.protobuf.test_service_pb2" | ||
# @@protoc_insertion_point(class_scope:ray.rpc.PingTimeoutReply) | ||
}, | ||
) | ||
_sym_db.RegisterMessage(PingTimeoutReply) | ||
|
||
_TESTSERVICE = DESCRIPTOR.services_by_name["TestService"] | ||
if _descriptor._USE_C_DESCRIPTORS == False: | ||
|
||
    DESCRIPTOR._options = None | ||
    _PINGREQUEST._serialized_start = 48 | ||
    _PINGREQUEST._serialized_end = 88 | ||
    _PINGREPLY._serialized_start = 90 | ||
    _PINGREPLY._serialized_end = 101 | ||
    _PINGTIMEOUTREQUEST._serialized_start = 103 | ||
    _PINGTIMEOUTREQUEST._serialized_end = 123 | ||
    _PINGTIMEOUTREPLY._serialized_start = 125 | ||
    _PINGTIMEOUTREPLY._serialized_end = 143 | ||
    _TESTSERVICE._serialized_start = 146 | ||
    _TESTSERVICE._serialized_end = 280 | ||
# @@protoc_insertion_point(module_scope) |