
scale uvicorn servers #186

Merged: 46 commits from aniket/parallel-servers into main, Aug 2, 2024
Changes from 43 commits
Commits (46)
d348e79  threaded uvicorn run (aniketmaurya, Jul 31, 2024)
ba63390  update (aniketmaurya, Jul 31, 2024)
5bbd1ab  update (aniketmaurya, Jul 31, 2024)
4592ecf  update (aniketmaurya, Jul 31, 2024)
2c2177b  update (aniketmaurya, Jul 31, 2024)
dec04bd  update (aniketmaurya, Jul 31, 2024)
16c0444  update (aniketmaurya, Jul 31, 2024)
7254b88  update (aniketmaurya, Jul 31, 2024)
b4c653e  clean (aniketmaurya, Jul 31, 2024)
57c0cb1  update (aniketmaurya, Jul 31, 2024)
48a1cf0  update (aniketmaurya, Jul 31, 2024)
058f4a7  update (aniketmaurya, Jul 31, 2024)
593c14f  fixes (aniketmaurya, Jul 31, 2024)
c9a938c  update (aniketmaurya, Jul 31, 2024)
43361f8  MP (aniketmaurya, Jul 31, 2024)
7a7a957  update (aniketmaurya, Jul 31, 2024)
e0fb5c9  update (aniketmaurya, Aug 1, 2024)
e4b5c98  format (aniketmaurya, Aug 1, 2024)
cc5720a  update (aniketmaurya, Aug 1, 2024)
e38f300  single queue (aniketmaurya, Aug 1, 2024)
0ee56d8  update (aniketmaurya, Aug 1, 2024)
9743ac8  update (aniketmaurya, Aug 1, 2024)
a80523f  clean up (aniketmaurya, Aug 1, 2024)
2215fc1  update (aniketmaurya, Aug 1, 2024)
68dcc06  update (aniketmaurya, Aug 1, 2024)
8a9e790  update (aniketmaurya, Aug 1, 2024)
cf6a478  fix tests (aniketmaurya, Aug 1, 2024)
c30815b  fix tests (aniketmaurya, Aug 1, 2024)
06257af  fix (aniketmaurya, Aug 1, 2024)
78de09e  remove uvloop import (aniketmaurya, Aug 1, 2024)
ded829b  auto loop (aniketmaurya, Aug 1, 2024)
ea9f2b4  format (aniketmaurya, Aug 1, 2024)
4eb8cba  formatting (aniketmaurya, Aug 1, 2024)
27f7bba  remove uvloop (aniketmaurya, Aug 1, 2024)
41cb99b  fix test (aniketmaurya, Aug 1, 2024)
b86afa1  wrap ls start (aniketmaurya, Aug 1, 2024)
1e2307e  update (aniketmaurya, Aug 1, 2024)
9f48b5a  fixes (aniketmaurya, Aug 1, 2024)
950f457  fix host (aniketmaurya, Aug 1, 2024)
c462ece  downgrade numpy (aniketmaurya, Aug 1, 2024)
3827ccc  fix windows (aniketmaurya, Aug 1, 2024)
e6e04ff  Update test.txt (aniketmaurya, Aug 1, 2024)
9cf3c22  Merge branch 'main' into aniket/parallel-servers (aniketmaurya, Aug 2, 2024)
264486f  Update src/litserve/server.py (lantiga, Aug 2, 2024)
2941949  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2024)
7597b16  upgrade Python (aniketmaurya, Aug 2, 2024)
1 change: 1 addition & 0 deletions _requirements/test.txt
@@ -11,3 +11,4 @@ lightning >2.0.0
torch >2.0.0
transformers
openai>=1.12.0
numpy <2.0
242 changes: 143 additions & 99 deletions src/litserve/server.py

Large diffs are not rendered by default.
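
Since the server.py diff is not rendered, here is the start/stop lifecycle that the test changes below exercise. launch_inference_worker(num_uvicorn_servers=...) and its (manager, processes) return value appear in this PR's conftest.py; treat the body of the sketch as an illustration of the pattern, not the file's actual contents:

import litserve as ls

# Assumes the example API shipped with litserve, as used in tests/test_examples.py.
server = ls.LitServer(ls.examples.SimpleLitAPI(), accelerator="cpu", devices=1)
manager, processes = server.launch_inference_worker(num_uvicorn_servers=2)
try:
    ...  # each uvicorn server process handles HTTP traffic here
finally:
    for p in processes:
        p.terminate()  # same teardown wrap_litserve_start performs below
    manager.shutdown()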

4 changes: 2 additions & 2 deletions src/litserve/specs/openai.py
@@ -317,8 +317,8 @@ async def options_chat_completions(self, request: Request):
return Response(status_code=200)

async def chat_completion(self, request: ChatCompletionRequest, background_tasks: BackgroundTasks):
response_queue_id = self.response_queue_id
logger.debug("Received chat completion request %s", request)

uids = [uuid.uuid4() for _ in range(request.n)]
self.queues = []
self.events = []
@@ -328,7 +328,7 @@ async def chat_completion(self, request: ChatCompletionRequest, background_tasks
q = deque()
event = asyncio.Event()
self._server.response_buffer[uid] = (q, event)
self._server.request_queue.put((uid, time.monotonic(), request_el))
self._server.request_queue.put((response_queue_id, uid, time.monotonic(), request_el))
self.queues.append(q)
self.events.append(event)

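
The handler now snapshots self.response_queue_id into a local before enqueuing, and the request tuple grows from (uid, timestamp, request) to (response_queue_id, uid, timestamp, request). The consumer side lives in server.py, which is not rendered; a plausible sketch of how each server process could drain its own response queue into the response_buffer that chat_completion populates, with the loop itself being an assumption:

import asyncio

async def drain_responses(response_queue, response_buffer):
    # Runs inside one uvicorn server process. response_buffer maps
    # uid -> (deque, asyncio.Event), as chat_completion sets it up above.
    loop = asyncio.get_running_loop()
    while True:
        uid, payload = await loop.run_in_executor(None, response_queue.get)
        q, event = response_buffer[uid]
        q.append(payload)
        event.set()  # wake the coroutine awaiting this uid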
17 changes: 16 additions & 1 deletion tests/conftest.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import time
import psutil
from typing import Generator
@@ -97,9 +98,23 @@ def simple_batched_stream_api():
return SimpleBatchedStreamAPI()


@contextmanager
def wrap_litserve_start(server: LitServer):
server.app.response_queue_id = 0
if server.lit_spec:
server.lit_spec.response_queue_id = 0
manager, processes = server.launch_inference_worker(num_uvicorn_servers=1)
yield server
for p in processes:
p.terminate()
manager.shutdown()


@pytest.fixture()
def lit_server(simple_litapi):
return LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
with wrap_litserve_start(server) as s:
yield s


@pytest.fixture()
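
The new helper launches the inference worker and assigns response_queue_id = 0 before a test touches the app, then tears the processes down afterwards; constructing a TestClient alone no longer appears to be enough. Every test below adopts the same pattern, shown once here as a sketch:

import litserve as ls
from fastapi.testclient import TestClient
from tests.conftest import wrap_litserve_start

def test_predict():
    # Assumes the bundled example API; the tests below define their own LitAPIs.
    server = ls.LitServer(ls.examples.SimpleLitAPI(), accelerator="cpu", devices=1)
    with wrap_litserve_start(server) as server, TestClient(server.app) as client:
        assert client.post("/predict", json={"input": 4.0}).status_code == 200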
11 changes: 5 additions & 6 deletions tests/test_auth.py
@@ -18,6 +18,7 @@

from litserve import LitAPI, LitServer
import litserve.server
from tests.conftest import wrap_litserve_start


class SimpleAuthedLitAPI(LitAPI):
@@ -40,17 +41,15 @@ def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):

def test_authorized_custom():
server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"Authorization": "Bearer 1234"}, json=input)
assert response.status_code == 200


def test_not_authorized_custom():
server = LitServer(SimpleAuthedLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"Authorization": "Bearer wrong"}, json=input)
assert response.status_code == 401
@@ -74,7 +73,7 @@ def test_authorized_api_key():
litserve.server.LIT_SERVER_API_KEY = "abcd"
server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"X-API-Key": "abcd"}, json=input)
assert response.status_code == 200
@@ -86,7 +85,7 @@ def test_not_authorized_api_key():
litserve.server.LIT_SERVER_API_KEY = "abcd"
server = LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
input = {"input": 4.0}
response = client.post("/predict", headers={"X-API-Key": "wrong"}, json=input)
assert response.status_code == 401
26 changes: 14 additions & 12 deletions tests/test_batch.py
@@ -23,7 +23,7 @@
from httpx import AsyncClient

from litserve import LitAPI, LitServer

from tests.conftest import wrap_litserve_start
import torch
import torch.nn as nn

@@ -86,10 +86,11 @@ async def test_batched():
api = SimpleLitAPI()
server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=2, batch_timeout=4)

async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)

assert response1.json() == {"output": 9.0}
assert response2.json() == {"output": 11.0}
@@ -99,11 +100,11 @@
async def test_unbatched():
api = SimpleLitAPI2()
server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=1)

async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response1 = ac.post("/predict", json={"input": 4.0})
response2 = ac.post("/predict", json={"input": 5.0})
response1, response2 = await asyncio.gather(response1, response2)

assert response1.json() == {"output": 9.0}
assert response2.json() == {"output": 11.0}
@@ -127,8 +128,9 @@ def put(self, *args):

def test_batched_loop():
requests_queue = Queue()
requests_queue.put(("uuid-1234", time.monotonic(), {"input": 4.0}))
requests_queue.put(("uuid-1235", time.monotonic(), {"input": 5.0}))
response_queue_id = 0
requests_queue.put((response_queue_id, "uuid-1234", time.monotonic(), {"input": 4.0}))
requests_queue.put((response_queue_id, "uuid-1235", time.monotonic(), {"input": 5.0}))

lit_api_mock = MagicMock()
lit_api_mock.request_timeout = 2
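
test_batched_loop now enqueues (response_queue_id, uid, timestamp, payload) tuples. For reference, here is a hedged sketch of a batch-collection step compatible with that layout and with the max_batch_size / batch_timeout knobs used above; the function name and structure are assumptions, not LitServe's actual loop:

import time
from queue import Empty, Queue

def collect_batch(request_queue: Queue, max_batch_size: int, batch_timeout: float):
    # Block for the first request, then top the batch up until either the
    # size cap or the batch timeout is hit.
    batch = [request_queue.get()]
    deadline = time.monotonic() + batch_timeout
    while len(batch) < max_batch_size and time.monotonic() < deadline:
        try:
            batch.append(request_queue.get(timeout=max(deadline - time.monotonic(), 0.0)))
        except Empty:
            break
    # Each entry is (response_queue_id, uid, enqueue_time, payload).
    return batch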
9 changes: 4 additions & 5 deletions tests/test_compression.py
@@ -14,7 +14,7 @@

from fastapi import Request, Response
from fastapi.testclient import TestClient

from tests.conftest import wrap_litserve_start
from litserve import LitAPI, LitServer

# trivially compressible content
@@ -38,17 +38,16 @@ def encode_response(self, output) -> Response:
def test_compression():
server = LitServer(LargeOutputLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

# compressed
with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
# compressed
response = client.post("/predict", headers={"Accept-Encoding": "gzip"}, json={})
assert response.status_code == 200
assert response.headers["Content-Encoding"] == "gzip"
content_length = int(response.headers["Content-Length"])
assert 0 < content_length < 100000
assert response.json() == test_output

# uncompressed
with TestClient(server.app) as client:
# uncompressed
response = client.post("/predict", headers={"Accept-Encoding": ""}, json={})
assert response.status_code == 200
assert "Content-Encoding" not in response.headers
23 changes: 13 additions & 10 deletions tests/test_examples.py
@@ -1,32 +1,35 @@
import pytest
from asgi_lifespan import LifespanManager
from httpx import AsyncClient

from tests.conftest import wrap_litserve_start
import litserve as ls


@pytest.mark.asyncio()
async def test_simple_pytorch_api():
api = ls.examples.SimpleTorchAPI()
server = ls.LitServer(api, accelerator="cpu")
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 9.0}
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 9.0}


@pytest.mark.asyncio()
async def test_simple_batched_api():
api = ls.examples.SimpleBatchedAPI()
server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.1)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}


@pytest.mark.asyncio()
async def test_simple_api():
api = ls.examples.SimpleLitAPI()
server = ls.LitServer(api)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}
with wrap_litserve_start(server) as server:
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
response = await ac.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}
8 changes: 4 additions & 4 deletions tests/test_form.py
@@ -14,6 +14,7 @@

from fastapi import Request, Response
from fastapi.testclient import TestClient
from tests.conftest import wrap_litserve_start

from litserve import LitAPI, LitServer

@@ -39,7 +40,7 @@ def test_multipart_form_data(tmp_path):
SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length * 2)
)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
file_path = f"{tmp_path}/big_file.txt"
with open(file_path, "wb") as f:
f.write(bytearray([1] * file_length))
@@ -56,7 +57,7 @@ def test_file_too_big(tmp_path):
SimpleFileLitAPI(), accelerator="cpu", devices=1, workers_per_device=1, max_payload_size=(file_length / 2)
)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
file_path = f"{tmp_path}/big_file.txt"
with open(file_path, "wb") as f:
f.write(bytearray([1] * file_length))
@@ -86,8 +87,7 @@ def encode_response(self, output) -> Response:

def test_urlencoded_form_data():
server = LitServer(SimpleFormLitAPI(), accelerator="cpu", devices=1, workers_per_device=1)

with TestClient(server.app) as client:
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
file = {"input": "4.0"}
response = client.post("/predict", data=file)
assert response.json() == {"output": 16.0}