Make vetiver_post extensible for all endpoints #130

Merged · 9 commits · Nov 29, 2022
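
As context for the diff below, here is a minimal sketch of the workflow this change enables, assembled from the `vetiver_post` docstring example in server.py; the mock helpers and `sum_values` come from that example and are illustrative only, not part of the changeset.

import vetiver

X, y = vetiver.get_mock_data()
model = vetiver.get_mock_model().fit(X, y)
v = vetiver.VetiverModel(model=model, model_name="model", ptype_data=X)
v_api = vetiver.VetiverAPI(model=v, check_ptype=True)

def sum_values(x):
    # any callable that accepts the served data can back a POST endpoint
    return x.sum()

# endpoint_name now defaults to the function name, and extra keyword
# arguments passed at registration are forwarded to the endpoint function
v_api.vetiver_post(sum_values)          # registers POST /sum_values
v_api.vetiver_post(sum_values, "sums")  # registers POST /sums
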
104 changes: 45 additions & 59 deletions vetiver/server.py
@@ -1,16 +1,16 @@
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.openapi.utils import get_openapi
from fastapi import testclient
import httpx
from typing import Callable, List, Union
from urllib.parse import urljoin

import uvicorn
import requests
import httpx
import pandas as pd
from typing import Callable, Union, List
import requests
import uvicorn
from fastapi import FastAPI, Request, testclient
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, RedirectResponse

from .vetiver_model import VetiverModel
from .utils import _jupyter_nb
from .vetiver_model import VetiverModel


class VetiverAPI:
@@ -45,10 +45,12 @@ def __init__(
self.model = model
self.check_ptype = check_ptype
self.app_factory = app_factory
self.app = self._init_app()
self.app = app_factory()

self._init_app()

def _init_app(self):
app = self.app_factory()
app = self.app
app.openapi = self._custom_openapi

@app.get("/", include_in_schema=False)
@@ -68,38 +70,13 @@ def pin_url()
async def ping():
return {"ping": "pong"}

if self.check_ptype is True:

@app.post("/predict")
async def prediction(
input_data: Union[self.model.ptype, List[self.model.ptype]]
):
if isinstance(input_data, List):
served_data = _batch_data(input_data)
else:
served_data = _prepare_data(input_data)

y = self.model.handler_predict(
served_data, check_ptype=self.check_ptype
)

return {"prediction": y.tolist()}

elif self.check_ptype is False:

@app.post("/predict")
async def prediction(input_data: Request):
y = await input_data.json()

prediction = self.model.handler_predict(y, check_ptype=self.check_ptype)

return {"prediction": prediction.tolist()}

else:
raise ValueError("cannot determine `check_ptype`")
self.vetiver_post(
self.model.handler_predict, "predict", check_ptype=self.check_ptype
)

@app.get("/__docs__", response_class=HTMLResponse, include_in_schema=False)
async def rapidoc():
# save as html html.tpl, .format {spec_url}
return f"""
<!doctype html>
<html>
@@ -137,9 +114,7 @@ async def rapidoc():

return app

def vetiver_post(
self, endpoint_fx: Callable, endpoint_name: str = "custom_endpoint"
):
def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
"""Create new POST endpoint that is aware of model input data

Parameters
@@ -151,29 +126,39 @@ def vetiver_post(

Example
-------
>>> import vetiver
>>> X, y = vetiver.get_mock_data()
>>> model = vetiver.get_mock_model().fit(X, y)
>>> v = vetiver.VetiverModel(model = model, model_name = "model", ptype_data = X)
>>> v_api = vetiver.VetiverAPI(model = v, check_ptype = True)
>>> import vetiver as vt
>>> X, y = vt.get_mock_data()
>>> model = vt.get_mock_model().fit(X, y)
>>> v = vt.VetiverModel(model = model, model_name = "model", ptype_data = X)
>>> v_api = vt.VetiverAPI(model = v, check_ptype = True)
>>> def sum_values(x):
... return x.sum()
>>> v_api.vetiver_post(sum_values, "sums")
"""
if not endpoint_name:
endpoint_name = endpoint_fx.__name__

if self.check_ptype is True:

@self.app.post("/" + endpoint_name)
async def custom_endpoint(input_data: self.model.ptype):
y = _prepare_data(input_data)
new = endpoint_fx(pd.DataFrame(y))
@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
async def custom_endpoint(
input_data: Union[self.model.ptype, List[self.model.ptype]]
):

if isinstance(input_data, List):
served_data = _batch_data(input_data)
else:
served_data = _prepare_data(input_data)

new = endpoint_fx(served_data, **kw)
return {endpoint_name: new.tolist()}

else:

@self.app.post("/" + endpoint_name)
@self.app.post(urljoin("/", endpoint_name))
async def custom_endpoint(input_data: Request):
y = await input_data.json()
new = endpoint_fx(pd.DataFrame(y))
served_data = await input_data.json()
new = endpoint_fx(served_data, **kw)

return {endpoint_name: new.tolist()}

@@ -214,10 +199,11 @@ def _custom_openapi(self):
)
openapi_schema["info"]["x-logo"] = {"url": "../docs/figures/logo.svg"}
self.app.openapi_schema = openapi_schema

return self.app.openapi_schema


def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.DataFrame:
"""Make a prediction from model endpoint

Parameters
@@ -274,14 +260,14 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
return response_df


def _prepare_data(pred_data):
def _prepare_data(pred_data: list) -> list:
served_data = []
for key, value in pred_data:
served_data.append(value)
return served_data


def _batch_data(pred_data):
def _batch_data(pred_data) -> pd.DataFrame:
columns = pred_data[0].dict().keys()

data = [line.dict() for line in pred_data]
@@ -290,7 +276,7 @@ def vetiver_endpoint(url="http://127.0.0.1:8000/predict"):
return served_data


def vetiver_endpoint(url="http://127.0.0.1:8000/predict"):
def vetiver_endpoint(url: str = "http://127.0.0.1:8000/predict") -> str:
"""Wrap url where VetiverModel will be deployed

Parameters
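
One small detail in the server.py changes above: routes are now built with urllib.parse.urljoin rather than plain string concatenation, which keeps the path well-formed even if a caller passes an endpoint name with a leading slash. A quick illustration:

from urllib.parse import urljoin

urljoin("/", "sum")    # -> "/sum"
urljoin("/", "/sum")   # -> "/sum"
# naive concatenation ("/" + "/sum") would have produced "//sum"
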
72 changes: 52 additions & 20 deletions vetiver/tests/test_add_endpoint.py
@@ -1,8 +1,16 @@
from vetiver import mock, VetiverModel, VetiverAPI
import pytest

import numpy as np
import pandas as pd
from fastapi.testclient import TestClient

from vetiver import mock, VetiverModel, VetiverAPI
import vetiver

def _start_application(check_ptype):

@pytest.fixture
def vetiver_model():
np.random.seed(500)
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)
v = VetiverModel(
@@ -13,31 +21,55 @@ def _start_application(check_ptype):
description="A regression model for testing purposes",
)

def sum_values(x):
return x.sum()
return v


def sum_values(x):
return x.sum()


def sum_dict(x):
x = pd.DataFrame(x)
return x.sum()


app = VetiverAPI(v, check_ptype=check_ptype)
@pytest.fixture
def vetiver_client(vetiver_model): # With check_ptype=True

app = VetiverAPI(vetiver_model, check_ptype=True)
app.vetiver_post(sum_values, "sum")

return app
app.app.root_path = "/sum"
client = TestClient(app.app)

return client


@pytest.fixture
def vetiver_client_check_ptype_false(vetiver_model): # With check_ptype=False

app = VetiverAPI(vetiver_model, check_ptype=False)
app.vetiver_post(sum_dict, "sum")

app.app.root_path = "/sum"
client = TestClient(app.app)

return client


def test_endpoint_adds_ptype(vetiver_client):

data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
response = vetiver.predict(endpoint=vetiver_client, data=data)

def test_endpoint_adds_ptype():
app = _start_application(check_ptype=True).app
assert isinstance(response, pd.DataFrame)
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json()

client = TestClient(app)
data = {"B": 0, "C": 0, "D": 0}
response = client.post("/sum", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"sum": [0]}, response.json()

def test_endpoint_adds_no_ptype(vetiver_client_check_ptype_false):

def test_endpoint_adds_no_ptype():
app = _start_application(check_ptype=False).app
data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data)

client = TestClient(app)
data = [0, 0, 0]
response = client.post("/sum", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"sum": [0]}, response.json()
assert isinstance(response, pd.DataFrame)
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json()
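
The pytorch and sklearn test updates that follow are a knock-on effect of routing the built-in /predict endpoint through vetiver_post: the response key now mirrors the endpoint name ("predict") instead of the hard-coded "prediction". In raw-request terms, roughly as below; the client setup is assumed, and the 44.47 value is taken from the sklearn tests further down.

from fastapi.testclient import TestClient

# assumes `v_api` is a VetiverAPI wrapping the mock sklearn model, as in the tests
client = TestClient(v_api.app)
response = client.post("/predict", json={"B": 0, "C": 0, "D": 0})
response.json()
# previously: {"prediction": [44.47]}
# now:        {"predict": [44.47]}  -- the key follows the endpoint name
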
8 changes: 4 additions & 4 deletions vetiver/tests/test_pytorch.py
@@ -66,7 +66,7 @@ def test_torch_predict_ptype():
response = client.post("/predict", json=data)

assert response.status_code == 200, response.text
assert response.json() == {"prediction": [-4.060722351074219]}, response.text
assert response.json() == {"predict": [-4.060722351074219]}, response.text


def test_torch_predict_ptype_batch():
@@ -81,7 +81,7 @@ def test_torch_predict_ptype_batch():

assert response.status_code == 200, response.text
assert response.json() == {
"prediction": [[-4.060722351074219], [-4.060722351074219]]
"predict": [[-4.060722351074219], [-4.060722351074219]]
}, response.text


@@ -109,7 +109,7 @@ def test_torch_predict_no_ptype_batch():
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {
"prediction": [[-4.060722351074219], [-4.060722351074219]]
"predict": [[-4.060722351074219], [-4.060722351074219]]
}, response.text


@@ -123,4 +123,4 @@ def test_torch_predict_no_ptype():
data = [[3.3]]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [[-4.060722351074219]]}, response.text
assert response.json() == {"predict": [[-4.060722351074219]]}, response.text
8 changes: 4 additions & 4 deletions vetiver/tests/test_sklearn.py
@@ -26,7 +26,7 @@ def test_predict_endpoint_ptype():
data = {"B": 0, "C": 0, "D": 0}
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47]}, response.json()
assert response.json() == {"predict": [44.47]}, response.json()


def test_predict_endpoint_ptype_batch():
@@ -35,7 +35,7 @@ def test_predict_endpoint_ptype_batch():
data = [{"B": 0, "C": 0, "D": 0}, {"B": 0, "C": 0, "D": 0}]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47, 44.47]}, response.json()
assert response.json() == {"predict": [44.47, 44.47]}, response.json()


def test_predict_endpoint_ptype_error():
@@ -52,7 +52,7 @@ def test_predict_endpoint_no_ptype():
data = [{"B": 0, "C": 0, "D": 0}]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47]}, response.json()
assert response.json() == {"predict": [44.47]}, response.json()


def test_predict_endpoint_no_ptype_batch():
@@ -61,7 +61,7 @@ def test_predict_endpoint_no_ptype_batch():
data = [[0, 0, 0], [0, 0, 0]]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47, 44.47]}, response.json()
assert response.json() == {"predict": [44.47, 44.47]}, response.json()


def test_predict_endpoint_no_ptype_error():
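
Finally, a hedged sketch of how a deployed custom endpoint would be queried from client code, mirroring the test_add_endpoint assertions above. The URL is hypothetical, and the sketch assumes the API from the earlier example is served at that address and that vetiver_endpoint is exported at the package level alongside predict.

import pandas as pd
import vetiver

endpoint = vetiver.vetiver_endpoint("http://127.0.0.1:8000/sums")
data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
response = vetiver.predict(endpoint=endpoint, data=data)
# response is a pandas DataFrame keyed by the endpoint name, e.g.
# response.to_json() == '{"sums":{"0":3,"1":6,"2":9}}'
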