Skip to content

Commit

Permalink
update app with drift detection endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
alex000kim committed Nov 27, 2023
1 parent f6d4b83 commit b9e18b1
Showing 1 changed file with 49 additions and 4 deletions.
53 changes: 49 additions & 4 deletions src/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@

from typing import List

import datetime
import json
import os
import warnings

import pandas as pd
from fastapi import Body, FastAPI, Request
from alibi_detect.saving import load_detector
from fastapi import BackgroundTasks, FastAPI, Request, Body
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from joblib import load
from pydantic import BaseModel

from sqlalchemy import create_engine
from utils.load_params import load_params

app = FastAPI()
Expand All @@ -27,9 +33,14 @@
allow_headers=["*"],
)

# Database connection string, read from the environment at import time
# (a missing DATABASE_URL raises KeyError here). Every 'postgres://'
# substring is rewritten to 'postgresql://'.
# NOTE(review): presumably needed because SQLAlchemy only accepts the
# 'postgresql' scheme name - confirm against the deployed SQLAlchemy version.
DATABASE_URL = os.environ['DATABASE_URL'].replace('postgres://', 'postgresql://')

# Project settings loaded from params.yaml; the attributes below are the ones
# this module reads.
params = load_params(params_path='params.yaml')
model_path = params.train.model_path
feat_cols = params.base.feat_cols
min_batch_size = params.drift_detect.min_batch_size

# The drift detector and the trained model are both loaded once, at import
# time, and shared by the request handlers and the background task.
cd = load_detector(Path('models')/'drift_detector')
model = load(filename=model_path)

class Customer(BaseModel):
Expand All @@ -46,7 +57,7 @@ class Request(BaseModel):
data: List[Customer]

@app.post("/predict")
async def predict(info: Request = Body(..., example={
async def predict(background_tasks: BackgroundTasks, info: Request = Body(..., example={
"data": [
{
"CreditScore": 619,
Expand All @@ -72,10 +83,44 @@ async def predict(info: Request = Body(..., example={
})):
json_list = json.loads(info.json())
data = json_list['data']
try:
background_tasks.add_task(collect_batch, json_list)
except:
warnings.warn("Unable to process batch data for drift detection")
input_data = pd.DataFrame(data)
probs = model.predict_proba(input_data)[:,0]
probs = probs.tolist()
return probs

@app.get("/drift_data")
async def get_drift_data():
    """Return the full contents of ``p_val_table`` as a JSON-encoded string.

    A short-lived SQLAlchemy engine is created for the request, the whole
    table is read into a DataFrame, and the DataFrame is serialised with
    ``DataFrame.to_json`` and re-encoded through the ``json`` module.

    Returns:
        str: JSON text of the table contents.

    Raises:
        Whatever ``create_engine``, ``engine.connect`` or ``pd.read_sql``
        raise when the database is unreachable or the table is missing.
    """
    engine = create_engine(DATABASE_URL)
    try:
        with engine.connect() as conn:
            sql_query = "SELECT * FROM p_val_table"
            df_p_val = pd.read_sql(sql_query, con=conn)
    finally:
        # Release the connection pool even when the query fails; previously
        # an exception skipped dispose() and leaked the engine's pool.
        engine.dispose()
    # Round-trip through json so the output formatting is that of json.dumps.
    # NOTE(review): the handler returns a str, so the framework presumably
    # JSON-encodes it a second time and clients decode twice - confirm before
    # changing the return type.
    parsed = json.loads(df_p_val.to_json())
    return json.dumps(parsed)

def collect_batch(json_list, batch_size_thres = min_batch_size, batch = []):
    """Accumulate request records and, once enough are buffered, run drift detection.

    Each call appends the records under ``json_list['data']`` to ``batch``.
    When the buffer holds at least ``batch_size_thres`` records, the drift
    detector ``cd`` is run on them, the resulting p-values are appended as one
    timestamped row to ``p_val_table``, and the buffer is emptied.

    Args:
        json_list: Mapping with a ``'data'`` key holding an iterable of
            per-record dicts.
        batch_size_thres: Minimum number of buffered records that triggers a
            drift check. Defaults to ``min_batch_size`` from the project params.
        batch: Buffer of pending records. The default list is created once,
            at definition time, and is therefore shared by every call that
            does not pass its own list - this is what lets records accumulate
            across requests, so it must stay a mutable default.

    Returns:
        None.

    Raises:
        Whatever the detector or the database layer raise. In that case the
        buffer is left intact, so the next call retries with the same
        records plus the new ones.
    """
    batch.extend(json_list['data'])
    if len(batch) < batch_size_thres:
        # Not enough records yet; keep buffering.
        return

    X = pd.DataFrame.from_records(batch)
    preds = cd.predict(X)
    p_val = preds['data']['p_val']
    now = datetime.datetime.now()
    # One row: timestamp followed by the detector's p-values, labelled with
    # the columns in feat_cols.
    row = [[now] + p_val.tolist()]
    columns = ['time'] + feat_cols
    df_p_val = pd.DataFrame(data=row, columns=columns)
    print('Writing to database')
    engine = create_engine(DATABASE_URL)
    try:
        with engine.connect() as conn:
            df_p_val.to_sql('p_val_table', con=conn, if_exists='append', index=False)
    finally:
        # Always release the pool; previously a failed write skipped
        # dispose() and leaked the engine.
        engine.dispose()
    # Only reached after a successful write.
    batch.clear()

if __name__ == "__main__":
    # Serve the app directly when this module is executed as a script.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)

1 comment on commit b9e18b1

@alex000kim
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metrics

Path f1 roc_auc
metrics.json 0.52953 0.81388

Feature Importances

feature importance
num__Age 0.19
num__NumOfProducts 0.17
num__IsActiveMember 0.08
num__Balance 0.06
num__CreditScore 0.03
num__HasCrCard 0.00
num__Tenure -0.00
num__EstimatedSalary -0.00

Confusion Matrix

Please sign in to comment.