Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make vetiver_post extensible for all endpoints #130

Merged
merged 9 commits into from
Nov 29, 2022
84 changes: 35 additions & 49 deletions vetiver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ def __init__(
self.model = model
self.check_ptype = check_ptype
self.app_factory = app_factory
self.app = self._init_app()
self.app = app_factory()

self._init_app()

def _init_app(self):
app = self.app_factory()
app = self.app
app.openapi = self._custom_openapi

@app.get("/", include_in_schema=False)
Expand All @@ -68,38 +70,13 @@ def pin_url():
async def ping():
return {"ping": "pong"}

if self.check_ptype is True:

@app.post("/predict")
async def prediction(
input_data: Union[self.model.ptype, List[self.model.ptype]]
):
if isinstance(input_data, List):
served_data = _batch_data(input_data)
else:
served_data = _prepare_data(input_data)

y = self.model.handler_predict(
served_data, check_ptype=self.check_ptype
)

return {"prediction": y.tolist()}

elif self.check_ptype is False:

@app.post("/predict")
async def prediction(input_data: Request):
y = await input_data.json()

prediction = self.model.handler_predict(y, check_ptype=self.check_ptype)

return {"prediction": prediction.tolist()}

else:
raise ValueError("cannot determine `check_ptype`")
self.vetiver_post(
self.model.handler_predict, "predict", check_ptype=self.check_ptype
)

@app.get("/__docs__", response_class=HTMLResponse, include_in_schema=False)
async def rapidoc():
# save as html html.tpl, .format {spec_url}
return f"""
<!doctype html>
<html>
Expand Down Expand Up @@ -137,9 +114,7 @@ async def rapidoc():

return app

def vetiver_post(
self, endpoint_fx: Callable, endpoint_name: str = "custom_endpoint"
):
def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
isabelizimm marked this conversation as resolved.
Show resolved Hide resolved
"""Create new POST endpoint that is aware of model input data

Parameters
Expand All @@ -151,29 +126,39 @@ def vetiver_post(

Example
-------
>>> import vetiver
>>> X, y = vetiver.get_mock_data()
>>> model = vetiver.get_mock_model().fit(X, y)
>>> v = vetiver.VetiverModel(model = model, model_name = "model", ptype_data = X)
>>> v_api = vetiver.VetiverAPI(model = v, check_ptype = True)
>>> import vetiver as vt
>>> X, y = vt.get_mock_data()
>>> model = vt.get_mock_model().fit(X, y)
>>> v = vt.VetiverModel(model = model, model_name = "model", ptype_data = X)
>>> v_api = vt.VetiverAPI(model = v, check_ptype = True)
>>> def sum_values(x):
... return x.sum()
>>> v_api.vetiver_post(sum_values, "sums")
"""
if not endpoint_name:
endpoint_name = endpoint_fx.__name__

if self.check_ptype is True:

@self.app.post("/" + endpoint_name)
async def custom_endpoint(input_data: self.model.ptype):
y = _prepare_data(input_data)
new = endpoint_fx(pd.DataFrame(y))
@self.app.post("/" + endpoint_name, name=endpoint_name)
isabelizimm marked this conversation as resolved.
Show resolved Hide resolved
async def custom_endpoint(
input_data: Union[self.model.ptype, List[self.model.ptype]]
):

if isinstance(input_data, List):
served_data = _batch_data(input_data)
else:
served_data = _prepare_data(input_data)

new = endpoint_fx(served_data, **kw)
return {endpoint_name: new.tolist()}

else:

@self.app.post("/" + endpoint_name)
async def custom_endpoint(input_data: Request):
y = await input_data.json()
new = endpoint_fx(pd.DataFrame(y))
served_data = await input_data.json()
new = endpoint_fx(served_data, **kw)

return {endpoint_name: new.tolist()}

Expand Down Expand Up @@ -214,10 +199,11 @@ def _custom_openapi(self):
)
openapi_schema["info"]["x-logo"] = {"url": "../docs/figures/logo.svg"}
self.app.openapi_schema = openapi_schema

return self.app.openapi_schema


def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.DataFrame:
"""Make a prediction from model endpoint

Parameters
Expand Down Expand Up @@ -274,14 +260,14 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
return response_df


def _prepare_data(pred_data):
def _prepare_data(pred_data: list) -> list:
isabelizimm marked this conversation as resolved.
Show resolved Hide resolved
served_data = []
for key, value in pred_data:
served_data.append(value)
return served_data


def _batch_data(pred_data):
def _batch_data(pred_data) -> pd.DataFrame:
columns = pred_data[0].dict().keys()

data = [line.dict() for line in pred_data]
Expand All @@ -290,7 +276,7 @@ def _batch_data(pred_data):
return served_data


def vetiver_endpoint(url="http://127.0.0.1:8000/predict"):
def vetiver_endpoint(url: str = "http://127.0.0.1:8000/predict") -> str:
"""Wrap url where VetiverModel will be deployed

Parameters
Expand Down
65 changes: 46 additions & 19 deletions vetiver/tests/test_add_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from vetiver import mock, VetiverModel, VetiverAPI
import pytest

import numpy as np
import pandas as pd
from fastapi.testclient import TestClient

from vetiver import mock, VetiverModel, VetiverAPI
import vetiver

def _start_application(check_ptype):

@pytest.fixture
def vetiver_model():
np.random.seed(500)
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)
v = VetiverModel(
Expand All @@ -13,31 +21,50 @@ def _start_application(check_ptype):
description="A regression model for testing purposes",
)

def sum_values(x):
return x.sum()
return v


app = VetiverAPI(v, check_ptype=check_ptype)
def sum_values(x):
return x.sum()


@pytest.fixture
def vetiver_client(vetiver_model): # With check_ptype=True

app = VetiverAPI(vetiver_model, check_ptype=True)
app.vetiver_post(sum_values, "sum")

return app
app.app.root_path = "/sum"
client = TestClient(app.app)

return client

def test_endpoint_adds_ptype():
app = _start_application(check_ptype=True).app

client = TestClient(app)
data = {"B": 0, "C": 0, "D": 0}
response = client.post("/sum", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"sum": [0]}, response.json()
@pytest.fixture
def vetiver_client_check_ptype_false(vetiver_model): # With check_ptype=False

app = VetiverAPI(vetiver_model, check_ptype=False)
app.vetiver_post(sum_values, "sum")

app.app.root_path = "/sum"
client = TestClient(app.app)

return client


def test_endpoint_adds_ptype(vetiver_client):

data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
response = vetiver.predict(endpoint=vetiver_client, data=data)

assert isinstance(response, pd.DataFrame)
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json()


def test_endpoint_adds_no_ptype(vetiver_client_check_ptype_false):

def test_endpoint_adds_no_ptype():
app = _start_application(check_ptype=False).app
data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data)
isabelizimm marked this conversation as resolved.
Show resolved Hide resolved

client = TestClient(app)
data = [0, 0, 0]
response = client.post("/sum", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"sum": [0]}, response.json()
assert response.json() == {"sum": [3, 6, 9]}, response.json()
8 changes: 4 additions & 4 deletions vetiver/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_torch_predict_ptype():
response = client.post("/predict", json=data)

assert response.status_code == 200, response.text
assert response.json() == {"prediction": [-4.060722351074219]}, response.text
assert response.json() == {"predict": [-4.060722351074219]}, response.text
isabelizimm marked this conversation as resolved.
Show resolved Hide resolved


def test_torch_predict_ptype_batch():
Expand All @@ -81,7 +81,7 @@ def test_torch_predict_ptype_batch():

assert response.status_code == 200, response.text
assert response.json() == {
"prediction": [[-4.060722351074219], [-4.060722351074219]]
"predict": [[-4.060722351074219], [-4.060722351074219]]
}, response.text


Expand Down Expand Up @@ -109,7 +109,7 @@ def test_torch_predict_no_ptype_batch():
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {
"prediction": [[-4.060722351074219], [-4.060722351074219]]
"predict": [[-4.060722351074219], [-4.060722351074219]]
}, response.text


Expand All @@ -123,4 +123,4 @@ def test_torch_predict_no_ptype():
data = [[3.3]]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [[-4.060722351074219]]}, response.text
assert response.json() == {"predict": [[-4.060722351074219]]}, response.text
8 changes: 4 additions & 4 deletions vetiver/tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_predict_endpoint_ptype():
data = {"B": 0, "C": 0, "D": 0}
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47]}, response.json()
assert response.json() == {"predict": [44.47]}, response.json()


def test_predict_endpoint_ptype_batch():
Expand All @@ -35,7 +35,7 @@ def test_predict_endpoint_ptype_batch():
data = [{"B": 0, "C": 0, "D": 0}, {"B": 0, "C": 0, "D": 0}]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47, 44.47]}, response.json()
assert response.json() == {"predict": [44.47, 44.47]}, response.json()


def test_predict_endpoint_ptype_error():
Expand All @@ -52,7 +52,7 @@ def test_predict_endpoint_no_ptype():
data = [{"B": 0, "C": 0, "D": 0}]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47]}, response.json()
assert response.json() == {"predict": [44.47]}, response.json()


def test_predict_endpoint_no_ptype_batch():
Expand All @@ -61,7 +61,7 @@ def test_predict_endpoint_no_ptype_batch():
data = [[0, 0, 0], [0, 0, 0]]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47, 44.47]}, response.json()
assert response.json() == {"predict": [44.47, 44.47]}, response.json()


def test_predict_endpoint_no_ptype_error():
Expand Down