scale uvicorn servers (#186)
* threaded uvicorn run

* update

* clean

* update

* fixes

* update

* MP

* format

* single queue


* fix tests

* fix tests

* remove uvloop import

* auto loop

* format

* formatting

* remove uvloop

* fix test

* wrap ls start

* update

* fixes

* fix host

* downgrade numpy

* fix windows

* Update test.txt

Co-authored-by: Jirka Borovec <[email protected]>

* Update src/litserve/server.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Aug 2, 2024
1 parent 4d5d146 commit ba1f692
Showing 15 changed files with 396 additions and 298 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-minimal-dependency-check.yml
@@ -21,7 +21,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: "3.10"

- name: Install LitServe
run: |
1 change: 1 addition & 0 deletions _requirements/test.txt
@@ -11,3 +11,4 @@ lightning >2.0.0
torch >2.0.0
transformers
openai>=1.12.0
numpy <2.0
240 changes: 141 additions & 99 deletions src/litserve/server.py

Large diffs are not rendered by default.
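
The src/litserve/server.py rewrite is the heart of this commit, but its diff is not rendered. From the commit title and the pieces visible elsewhere in this diff (launch_inference_worker(num_uvicorn_servers=1) in tests/conftest.py and the response_queue_id that now prefixes every request tuple), the apparent pattern is: several uvicorn server processes share a single request queue, each server owns its own response queue, and the inference worker routes every result back to the server that accepted the request. The sketch below is a hypothetical, simplified illustration of that routing idea using plain multiprocessing functions in place of real uvicorn processes; inference_worker and server_process are made-up names for this example, not LitServe's API.

# Hypothetical sketch (not the actual src/litserve/server.py change): N "server"
# processes share ONE request queue, each owns its own response queue, and the
# inference worker uses the response_queue_id carried in every request tuple to
# route results back to the server that accepted the request.
import multiprocessing as mp
import time
import uuid


def inference_worker(request_queue, response_queues):
    # Single consumer: pull (response_queue_id, uid, timestamp, payload) tuples
    # and send each result back to the originating server's response queue.
    while True:
        response_queue_id, uid, _ts, payload = request_queue.get()
        if payload is None:  # shutdown sentinel
            break
        result = {"output": payload["input"] ** 2}
        response_queues[response_queue_id].put((uid, result))


def server_process(response_queue_id, request_queue, response_queue):
    # Stand-in for one uvicorn server process: enqueue a request, wait for its reply.
    uid = str(uuid.uuid4())
    request_queue.put((response_queue_id, uid, time.monotonic(), {"input": 4.0}))
    got_uid, result = response_queue.get()
    print(f"server {response_queue_id}: matched={got_uid == uid}, result={result}")


if __name__ == "__main__":
    manager = mp.Manager()
    num_servers = 2
    request_queue = manager.Queue()
    response_queues = [manager.Queue() for _ in range(num_servers)]

    worker = mp.Process(target=inference_worker, args=(request_queue, response_queues))
    worker.start()

    servers = [
        mp.Process(target=server_process, args=(i, request_queue, response_queues[i]))
        for i in range(num_servers)
    ]
    for p in servers:
        p.start()
    for p in servers:
        p.join()

    request_queue.put((0, None, time.monotonic(), None))  # stop the worker
    worker.join()
    manager.shutdown()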

4 changes: 2 additions & 2 deletions src/litserve/specs/openai.py
@@ -317,8 +317,8 @@ async def options_chat_completions(self, request: Request):
return Response(status_code=200)

async def chat_completion(self, request: ChatCompletionRequest, background_tasks: BackgroundTasks):
response_queue_id = self.response_queue_id
logger.debug("Received chat completion request %s", request)

uids = [uuid.uuid4() for _ in range(request.n)]
self.queues = []
self.events = []
@@ -328,7 +328,7 @@ async def chat_completion(self, request: ChatCompletionRequest, background_tasks
q = deque()
event = asyncio.Event()
self._server.response_buffer[uid] = (q, event)
self._server.request_queue.put((uid, time.monotonic(), request_el))
self._server.request_queue.put((response_queue_id, uid, time.monotonic(), request_el))
self.queues.append(q)
self.events.append(event)

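
What is visible of the OpenAI spec change is twofold: requests are now enqueued as (response_queue_id, uid, timestamp, request) instead of (uid, timestamp, request), and each uid still gets a (deque, asyncio.Event) pair registered in the server's response_buffer. The snippet below is an assumed, stripped-down illustration of that buffer hand-off (a producer appends results and sets the event; the endpoint waits, then drains the deque); the producer/endpoint names and timing are invented for the example and are not LitServe code.

import asyncio
import uuid
from collections import deque

response_buffer = {}  # uid -> (deque of results, event signalling "new data available")


async def producer(uid, n_chunks=3):
    # Stand-in for the inference worker: push results for this uid, then wake the waiter.
    q, event = response_buffer[uid]
    for i in range(n_chunks):
        await asyncio.sleep(0.01)
        q.append(f"chunk-{i}")
        event.set()


async def endpoint():
    # Register a (deque, Event) pair for this request, as the spec does per uid.
    uid = uuid.uuid4()
    q, event = deque(), asyncio.Event()
    response_buffer[uid] = (q, event)
    asyncio.create_task(producer(uid))

    received = []
    while len(received) < 3:
        await event.wait()
        event.clear()
        while q:
            received.append(q.popleft())
    print(received)  # ['chunk-0', 'chunk-1', 'chunk-2']


asyncio.run(endpoint())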
17 changes: 16 additions & 1 deletion tests/conftest.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import time
import psutil
from typing import Generator
@@ -97,9 +98,23 @@ def simple_batched_stream_api():
return SimpleBatchedStreamAPI()


@contextmanager
def wrap_litserve_start(server: LitServer):
server.app.response_queue_id = 0
if server.lit_spec:
server.lit_spec.response_queue_id = 0
manager, processes = server.launch_inference_worker(num_uvicorn_servers=1)
yield server
for p in processes:
p.terminate()
manager.shutdown()


@pytest.fixture()
def lit_server(simple_litapi):
return LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
with wrap_litserve_start(server) as s:
yield s


@pytest.fixture()
11 changes: 5 additions & 6 deletions tests/test_auth.py
@@ -18,6 +18,7 @@

from litserve import LitAPI, LitServer
import litserve.server
from tests.conftest import wrap_litserve_start


class SimpleAuthedLitAPI(LitAPI):
@@ -40,17 +41,15 @@ def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):

def test_authorized_custom():
server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input)
assert response.status_code == 200


def test_not_authorized_custom():
server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input)
assert response.status_code == 401
@@ -74,7 +73,7 @@ def test_authorized_api_key():
litserve.server.LIT_SERVER_API_KEY = "abcd"
server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input)
assert response.status_code == 200
@@ -86,7 +85,7 @@ def test_not_authorized_api_key():
litserve.server.LIT_SERVER_API_KEY = "abcd"
server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input)
assert response.status_code == 401
26 changes: 14 additions & 12 deletions tests/test_batch.py
@@ -23,7 +23,7 @@
from httpx import AsyncClient

from litserve import LitAPI, LitServer

from tests.conftest import wrap_litserve_start
import torch
import torch.nn as nn

@@ -86,10 +86,11 @@ async def test_batched():
api = SimpleLitAPI()
server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=2, batch_timeout=4)

async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)

assert response1.json() == {"output": 9.0}
assert response2.json() == {"output": 11.0}
@@ -99,11 +100,11 @@
async def test_unbatched():
api = SimpleLitAPI2()
server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=1)

async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)

assert response1.json() == {"output": 9.0}
assert response2.json() == {"output": 11.0}
@@ -127,8 +128,9 @@ def put(self, *args):

def test_batched_loop():
requests_queue = Queue()
requests_queue.put(("uuid-1234", time.monotonic(), {"input": 4.0}))
requests_queue.put(("uuid-1235", time.monotonic(), {"input": 5.0}))
response_queue_id = 0
requests_queue.put((response_queue_id, "uuid-1234", time.monotonic(), {"input": 4.0}))
requests_queue.put((response_queue_id, "uuid-1235", time.monotonic(), {"input": 5.0}))

lit_api_mock = MagicMock()
lit_api_mock.request_timeout = 2
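
The updated test_batched_loop seeds the queue with the new 4-tuples. As a rough sketch (assumed for illustration, not LitServe's actual loop code), a batched worker on the other end of that queue gathers up to max_batch_size items within batch_timeout and only then runs prediction once over the whole batch:

import time
from queue import Empty, Queue


def collect_batch(request_queue: Queue, max_batch_size: int, batch_timeout: float):
    # Gather up to max_batch_size items, waiting at most batch_timeout seconds in total.
    batch = []
    deadline = time.monotonic() + batch_timeout
    while len(batch) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            # Each item mirrors the test: (response_queue_id, uid, enqueue_time, payload).
            batch.append(request_queue.get(timeout=remaining))
        except Empty:
            break
    return batch


requests_queue = Queue()
requests_queue.put((0, "uuid-1234", time.monotonic(), {"input": 4.0}))
requests_queue.put((0, "uuid-1235", time.monotonic(), {"input": 5.0}))
batch = collect_batch(requests_queue, max_batch_size=2, batch_timeout=4)
print([payload for _, _, _, payload in batch])  # [{'input': 4.0}, {'input': 5.0}]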
9 changes: 4 additions & 5 deletions tests/test_compression.py
@@ -14,7 +14,7 @@

from fastapi import Request, Response
from fastapi.testclient import TestClient

from tests.conftest import wrap_litserve_start
from litserve import LitAPI, LitServer

# trivially compressible content
@@ -38,17 +38,16 @@ def encode_response(self, output) -> Response:
def test_compression():
server = LitServer(LargeOutputLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

# compressed
with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
# compressed
response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={})
assert response.status_code == 200
assert response.headers["Content-Encoding"] == "gzip"
content_length = int(response.headers["Content-Length"])
assert 0 < content_length < 100000
assert response.json() == test_output

# uncompressed
with TestClient(server.app) as client:
# uncompressed
response = client.post("/predict", headers={"Accept-Encoding": ""}, json={})
assert response.status_code == 200
assert "Content-Encoding" not in response.headers
23 changes: 13 additions & 10 deletions tests/test_examples.py
@@ -1,32 +1,35 @@
import pytest
from asgi_lifespan import LifespanManager
from httpx import AsyncClient

from tests.conftest import wrap_litserve_start
import litserve as ls


@pytest.mark.asyncio()
async def test_simple_pytorch_api():
api = ls.examples.SimpleTorchAPI()
server = ls.LitServer(api, accelerator="cpu")
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 9.0}
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 9.0}


@pytest.mark.asyncio()
async def test_simple_batched_api():
api = ls.examples.SimpleBatchedAPI()
server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.1)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}


@pytest.mark.asyncio()
async def test_simple_api():
api = ls.examples.SimpleLitAPI()
server = ls.LitServer(api)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}
8 changes: 4 additions & 4 deletions tests/test_form.py
@@ -14,6 +14,7 @@

from fastapi import Request, Response
from fastapi.testclient import TestClient
from tests.conftest import wrap_litserve_start

from litserve import LitAPI, LitServer

@@ -39,7 +40,7 @@ def test_multipart_form_data(tmp_path):
SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length * 2)
)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
file_path = f"{tmp_path}/big_file.txt"
with open(file_path, "wb") as f:
f.write(bytearray([1] * file_length))
@@ -56,7 +57,7 @@ def test_file_too_big(tmp_path):
SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length / 2)
)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
file_path = f"{tmp_path}/big_file.txt"
with open(file_path, "wb") as f:
f.write(bytearray([1] * file_length))
@@ -86,8 +87,7 @@ def encode_response(self, output) -> Response:

def test_urlencoded_form_data():
server = LitServer(SimpleFormLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
file = {"input": "4.0"}
response = client.post("/predict", data=file)
assert response.json() == {"output": 16.0}