Skip to content

Commit

Permalink
refactor tests (#209)
Browse files Browse the repository at this point in the history
* refactor tests

* parameterize ptype tests
  • Loading branch information
isabelizimm authored Mar 25, 2024
1 parent 4e372b0 commit 6d0e199
Show file tree
Hide file tree
Showing 14 changed files with 186 additions and 379 deletions.
32 changes: 17 additions & 15 deletions vetiver/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from importlib_resources import files as _files

sources = {
"mtcars": _files("vetiver") / "data/mtcars.csv",
"chicago": _files("vetiver") / "data/chicago.csv",
"sacramento": _files("vetiver") / "data/sacramento.csv",
}
__all__ = [
"mtcars",
"chicago",
"sacramento",
]


def __dir__():
return list(sources)
return __all__


def __getattr__(k):
def _load_data_csv(name):
import pandas as pd
import pkg_resources

fname = pkg_resources.resource_filename("vetiver.data", f"{name}.csv")
return pd.read_csv(fname)


def __getattr__(name):
if name not in __all__:
raise AttributeError(f"No dataset named: {name}")

f_path = sources.get("mtcars")
if k == "chicago":
f_path = sources.get("chicago")
elif k == "sacramento":
f_path = sources.get("sacramento")
return pd.read_csv(f_path)
return _load_data_csv(name)
7 changes: 3 additions & 4 deletions vetiver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uvicorn
import logging
import pandas as pd
from fastapi import FastAPI, Request, testclient
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
Expand Down Expand Up @@ -321,9 +321,8 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
>>> endpoint = vetiver.vetiver_endpoint(url='http://127.0.0.1:8000/predict')
>>> vetiver.predict(endpoint, X) # doctest: +SKIP
"""
if isinstance(endpoint, testclient.TestClient):
requester = endpoint
endpoint = requester.app.root_path
if "test_client" in kw:
requester = kw.pop("test_client")
else:
requester = requests

Expand Down
Empty file added vetiver/tests/__init__.py
Empty file.
30 changes: 30 additions & 0 deletions vetiver/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from vetiver import VetiverModel, VetiverAPI
from vetiver.helpers import api_data_to_frame
from starlette.testclient import TestClient


def sum_values(x):
return x.sum().to_list()


def sum_values_no_prototype(x):
return api_data_to_frame(x).sum().to_list()


@pytest.fixture
def client(model: VetiverModel) -> TestClient:
app = VetiverAPI(model, check_prototype=True)
app.vetiver_post(sum_values, "sum")
client = TestClient(app.app)

return client


@pytest.fixture
def client_no_prototype(model: VetiverModel) -> TestClient:
app = VetiverAPI(model, check_prototype=False)
app.vetiver_post(sum_values_no_prototype, "sum")
client = TestClient(app.app)

return client
73 changes: 15 additions & 58 deletions vetiver/tests/test_add_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,32 @@
import pytest

import numpy as np
import pandas as pd
from fastapi.testclient import TestClient

from vetiver import mock, VetiverModel, VetiverAPI
from vetiver.helpers import api_data_to_frame
import vetiver
from vetiver import mock, VetiverModel


@pytest.fixture
def vetiver_model():
np.random.seed(500)
@pytest.fixture()
def model():
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)
v = VetiverModel(
model=model,
prototype_data=X,
model_name="my_model",
versioned=None,
description="A regression model for testing purposes",
)

return v


def sum_values(x):
return x.sum().to_list()


def sum_values_no_prototype(x):
return api_data_to_frame(x).sum().to_list()

model = mock.get_mock_model()

@pytest.fixture
def vetiver_client(vetiver_model): # With check_prototype=True

app = VetiverAPI(vetiver_model, check_prototype=True)
app.vetiver_post(sum_values, "sum")

app.app.root_path = "/sum"
client = TestClient(app.app)

return client
return VetiverModel(model.fit(X, y), "model", prototype_data=X)


@pytest.fixture
def vetiver_client_check_ptype_false(vetiver_model): # With check_prototype=False

app = VetiverAPI(vetiver_model, check_prototype=False)
app.vetiver_post(sum_values_no_prototype, "sum")
def data() -> pd.DataFrame:
return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})

app.app.root_path = "/sum"
client = TestClient(app.app)

return client


def test_endpoint_adds_ptype(vetiver_client):

data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
response = vetiver.predict(endpoint=vetiver_client, data=data)
def test_endpoint_adds(client, data):
response = client.post("/sum/", data=data.to_json(orient="records"))

assert isinstance(response, pd.DataFrame)
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json()
assert response.status_code == 200
assert response.json() == {"sum": [3, 6, 9]}


def test_endpoint_adds_no_ptype(vetiver_client_check_ptype_false):
def test_endpoint_adds_no_prototype(client_no_prototype, data):

data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data)
response = client_no_prototype.post("/sum/", data=data.to_json(orient="records"))

assert isinstance(response, pd.DataFrame)
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json()
assert response.status_code == 200
assert response.json() == {"sum": [3, 6, 9]}
30 changes: 0 additions & 30 deletions vetiver/tests/test_build_api.py

This file was deleted.

Loading

0 comments on commit 6d0e199

Please sign in to comment.