Skip to content

Commit

Permalink
[Serve] Ensure SimpleSchemaIngress uses FastAPI custom serializers (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-mo authored May 6, 2022
1 parent 6680494 commit 95c11c9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/ray/serve/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import abstractmethod
from typing import Any, Callable, Optional, Type, Union
from pydantic import BaseModel
from ray.serve.utils import install_serve_encoders_to_fastapi

import starlette
from fastapi import Body, Depends, FastAPI
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
resolver. When you pass in a string, Serve will import it.
Please refer to Serve HTTP adatper documentation to learn more.
"""
install_serve_encoders_to_fastapi()
http_adapter = load_http_adapter(http_adapter)
self.app = FastAPI()

Expand Down
14 changes: 14 additions & 0 deletions python/ray/serve/tests/test_pipeline_driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
import numpy as np
from pydantic import BaseModel

import pytest
Expand Down Expand Up @@ -156,5 +157,18 @@ def test_dag_driver_partial_input(serve_instance):
assert resp.json() == [1, 2, [3, 4]]


@serve.deployment
def return_np_int(_):
return [np.int64(42)]


def test_driver_np_serializer(serve_instance):
# https://github.com/ray-project/ray/pull/24215#issuecomment-1115237058
with InputNode() as inp:
dag = DAGDriver.bind(return_np_int.bind(inp))
serve.run(dag)
assert requests.get("http://127.0.0.1:8000/").json() == [42]


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))
7 changes: 7 additions & 0 deletions python/ray/serve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import pydantic
import pydantic.json
import fastapi.encoders

import ray
import ray.serialization_addons
Expand Down Expand Up @@ -80,6 +81,12 @@ def install_serve_encoders_to_fastapi():
"""Inject Serve's encoders so FastAPI's jsonable_encoder can pick it up."""
# https://stackoverflow.com/questions/62311401/override-default-encoders-for-jsonable-encoder-in-fastapi # noqa
pydantic.json.ENCODERS_BY_TYPE.update(serve_encoders)
# FastAPI cache these encoders at import time, so we also needs to refresh it.
fastapi.encoders.encoders_by_class_tuples = (
fastapi.encoders.generate_encoders_by_class_tuples(
pydantic.json.ENCODERS_BY_TYPE
)
)


@ray.remote(num_cpus=0)
Expand Down

0 comments on commit 95c11c9

Please sign in to comment.