cleanup: fix test naming convention #190

Merged · 5 commits · Aug 6, 2024
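Context for the diff below: this PR renames the two test fixtures in tests/test_batch.py so that each name says what it exercises. SimpleLitAPI becomes SimpleBatchLitAPI (the batch-capable API used by the batching tests), and SimpleLitAPI2 becomes SimpleTorchAPI (a plain per-request torch API). As a rough sketch of the distinction, a batch-capable API looks something like the following; this assumes litserve's standard LitAPI hooks (setup, decode_request, batch, predict, unbatch, encode_response), and the method bodies are illustrative rather than copied from the repository.

import torch
import torch.nn as nn

from litserve import LitAPI, LitServer


class SimpleBatchLitAPI(LitAPI):
    def setup(self, device):
        # Load the model once per worker; LitServer supplies the device.
        self.model = nn.Linear(1, 1).to(device)
        self.device = device

    def decode_request(self, request):
        # One incoming JSON payload -> one input tensor.
        return torch.tensor([request["input"]], dtype=torch.float32, device=self.device)

    def batch(self, inputs):
        # Stack per-request tensors into a single batch; implementing
        # batch/unbatch is what makes this the "Batch" fixture.
        return torch.stack(inputs)

    def predict(self, x):
        return self.model(x)

    def unbatch(self, output):
        # Split the batched output back into per-request results.
        return list(output)

    def encode_response(self, output):
        return {"output": float(output)}


if __name__ == "__main__":
    # max_batch_size > 1 enables the batched loop that test_batched exercises
    # (constructor arguments mirror the ones used in the diff below).
    server = LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1,
                       timeout=10, max_batch_size=2, batch_timeout=4)
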
27 changes: 14 additions & 13 deletions tests/test_batch.py
@@ -22,9 +22,10 @@
 from asgi_lifespan import LifespanManager
 from fastapi import Request, Response
 from httpx import AsyncClient

 from litserve import LitAPI, LitServer
-from tests.conftest import wrap_litserve_start
+from litserve.server import run_batched_loop
+from tests.conftest import wrap_litserve_start


 class Linear(nn.Module):
@@ -38,7 +39,7 @@ def forward(self, x):
         return self.linear(x)


-class SimpleLitAPI(LitAPI):
+class SimpleBatchLitAPI(LitAPI):
     def setup(self, device):
         self.model = Linear().to(device)
         self.device = device
@@ -63,7 +64,7 @@ def encode_response(self, output) -> Response:
         return {"output": float(output)}


-class SimpleLitAPI2(LitAPI):
+class SimpleTorchAPI(LitAPI):
     def setup(self, device):
         self.model = Linear().to(device)
         self.device = device
@@ -82,7 +83,7 @@ def encode_response(self, output) -> Response:

 @pytest.mark.asyncio()
 async def test_batched():
-    api = SimpleLitAPI()
+    api = SimpleBatchLitAPI()
     server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=2, batch_timeout=4)

     with wrap_litserve_start(server) as server:
@@ -97,7 +98,7 @@ async def test_batched():

 @pytest.mark.asyncio()
 async def test_unbatched():
-    api = SimpleLitAPI2()
+    api = SimpleTorchAPI()
     server = LitServer(api, accelerator="cpu", devices=1, timeout=10, max_batch_size=1)
     with wrap_litserve_start(server) as server:
         async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
@@ -111,13 +112,13 @@ async def test_unbatched():

 def test_max_batch_size():
     with pytest.raises(ValueError, match="must be"):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=0)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=0)

     with pytest.raises(ValueError, match="must be"):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=-1)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=-1)

     with pytest.raises(ValueError, match="must be"):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=2, batch_timeout=5)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=2, batch_timeout=5)


 def test_max_batch_size_warning():
@@ -126,21 +127,21 @@ def test_max_batch_size_warning():
         UserWarning,
         match=warning,
     ):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2)

-    # Test no warnings are raised when max_batch_size is set
+    # Test no warnings are raised when max_batch_size is set and max_batch_size is not set
     with pytest.raises(pytest.fail.Exception), pytest.warns(
         UserWarning,
         match=warning,
     ):
-        LitServer(SimpleLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=2)
+        LitServer(SimpleBatchLitAPI(), accelerator="cpu", devices=1, timeout=2, max_batch_size=2)

-    # Test no max_batch_size warnings are raised with a different API
+    # Test no warning is set when LitAPI doesn't implement batch and unbatch
     with pytest.raises(pytest.fail.Exception), pytest.warns(
         UserWarning,
         match=warning,
     ):
-        LitServer(SimpleLitAPI2(), accelerator="cpu", devices=1, timeout=2)
+        LitServer(SimpleTorchAPI(), accelerator="cpu", devices=1, timeout=2)


 class FakeResponseQueue:
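
After a rename like this, a quick sanity check is to run the touched module on its own and confirm nothing still refers to the old class names. A minimal invocation through pytest's Python API (the flags here are just one reasonable choice, not taken from the PR):

import pytest

# Run only the renamed batch tests; -q keeps the output terse.
raise SystemExit(pytest.main(["tests/test_batch.py", "-q"]))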