-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* refactor tests * parameterize ptype tests
- Loading branch information
1 parent
4e372b0
commit 6d0e199
Showing
14 changed files
with
186 additions
and
379 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,24 @@ | ||
from importlib_resources import files as _files | ||
|
||
sources = { | ||
"mtcars": _files("vetiver") / "data/mtcars.csv", | ||
"chicago": _files("vetiver") / "data/chicago.csv", | ||
"sacramento": _files("vetiver") / "data/sacramento.csv", | ||
} | ||
__all__ = [ | ||
"mtcars", | ||
"chicago", | ||
"sacramento", | ||
] | ||
|
||
|
||
def __dir__(): | ||
return list(sources) | ||
return __all__ | ||
|
||
|
||
def __getattr__(k): | ||
def _load_data_csv(name): | ||
import pandas as pd | ||
import pkg_resources | ||
|
||
fname = pkg_resources.resource_filename("vetiver.data", f"{name}.csv") | ||
return pd.read_csv(fname) | ||
|
||
|
||
def __getattr__(name): | ||
if name not in __all__: | ||
raise AttributeError(f"No dataset named: {name}") | ||
|
||
f_path = sources.get("mtcars") | ||
if k == "chicago": | ||
f_path = sources.get("chicago") | ||
elif k == "sacramento": | ||
f_path = sources.get("sacramento") | ||
return pd.read_csv(f_path) | ||
return _load_data_csv(name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pytest | ||
from vetiver import VetiverModel, VetiverAPI | ||
from vetiver.helpers import api_data_to_frame | ||
from starlette.testclient import TestClient | ||
|
||
|
||
def sum_values(x): | ||
return x.sum().to_list() | ||
|
||
|
||
def sum_values_no_prototype(x): | ||
return api_data_to_frame(x).sum().to_list() | ||
|
||
|
||
@pytest.fixture | ||
def client(model: VetiverModel) -> TestClient: | ||
app = VetiverAPI(model, check_prototype=True) | ||
app.vetiver_post(sum_values, "sum") | ||
client = TestClient(app.app) | ||
|
||
return client | ||
|
||
|
||
@pytest.fixture | ||
def client_no_prototype(model: VetiverModel) -> TestClient: | ||
app = VetiverAPI(model, check_prototype=False) | ||
app.vetiver_post(sum_values_no_prototype, "sum") | ||
client = TestClient(app.app) | ||
|
||
return client |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,75 +1,32 @@ | ||
import pytest | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from fastapi.testclient import TestClient | ||
|
||
from vetiver import mock, VetiverModel, VetiverAPI | ||
from vetiver.helpers import api_data_to_frame | ||
import vetiver | ||
from vetiver import mock, VetiverModel | ||
|
||
|
||
@pytest.fixture | ||
def vetiver_model(): | ||
np.random.seed(500) | ||
@pytest.fixture() | ||
def model(): | ||
X, y = mock.get_mock_data() | ||
model = mock.get_mock_model().fit(X, y) | ||
v = VetiverModel( | ||
model=model, | ||
prototype_data=X, | ||
model_name="my_model", | ||
versioned=None, | ||
description="A regression model for testing purposes", | ||
) | ||
|
||
return v | ||
|
||
|
||
def sum_values(x): | ||
return x.sum().to_list() | ||
|
||
|
||
def sum_values_no_prototype(x): | ||
return api_data_to_frame(x).sum().to_list() | ||
|
||
model = mock.get_mock_model() | ||
|
||
@pytest.fixture | ||
def vetiver_client(vetiver_model): # With check_prototype=True | ||
|
||
app = VetiverAPI(vetiver_model, check_prototype=True) | ||
app.vetiver_post(sum_values, "sum") | ||
|
||
app.app.root_path = "/sum" | ||
client = TestClient(app.app) | ||
|
||
return client | ||
return VetiverModel(model.fit(X, y), "model", prototype_data=X) | ||
|
||
|
||
@pytest.fixture | ||
def vetiver_client_check_ptype_false(vetiver_model): # With check_prototype=False | ||
|
||
app = VetiverAPI(vetiver_model, check_prototype=False) | ||
app.vetiver_post(sum_values_no_prototype, "sum") | ||
def data() -> pd.DataFrame: | ||
return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) | ||
|
||
app.app.root_path = "/sum" | ||
client = TestClient(app.app) | ||
|
||
return client | ||
|
||
|
||
def test_endpoint_adds_ptype(vetiver_client): | ||
|
||
data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) | ||
response = vetiver.predict(endpoint=vetiver_client, data=data) | ||
def test_endpoint_adds(client, data): | ||
response = client.post("/sum/", data=data.to_json(orient="records")) | ||
|
||
assert isinstance(response, pd.DataFrame) | ||
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json() | ||
assert response.status_code == 200 | ||
assert response.json() == {"sum": [3, 6, 9]} | ||
|
||
|
||
def test_endpoint_adds_no_ptype(vetiver_client_check_ptype_false): | ||
def test_endpoint_adds_no_prototype(client_no_prototype, data): | ||
|
||
data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) | ||
response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data) | ||
response = client_no_prototype.post("/sum/", data=data.to_json(orient="records")) | ||
|
||
assert isinstance(response, pd.DataFrame) | ||
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json() | ||
assert response.status_code == 200 | ||
assert response.json() == {"sum": [3, 6, 9]} |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.