Skip to content

Commit

Permalink
feat: Create google sign in button and send id token with request (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
duwenxin99 authored Jan 25, 2024
1 parent 854d44a commit c0a34a7
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 74 deletions.
2 changes: 1 addition & 1 deletion DEVELOPER.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
python main.py
```

Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --reload`
Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --reload --port 8081`

1. View app at `http://localhost:8081/`

Expand Down
52 changes: 14 additions & 38 deletions langchain_tools_demo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.

import os
from datetime import date, timedelta
from typing import Dict, Optional
from datetime import date
from typing import Any, Dict, Optional

import aiohttp
import dateutil.parser as dparser
import google.auth.transport.requests # type: ignore
import google.oauth2.id_token # type: ignore
from langchain.agents import AgentType, initialize_agent
Expand All @@ -35,6 +34,8 @@
# aiohttp context
connector = None

CLOUD_RUN_AUTHORIZATION_TOKEN = None


# Class for setting up a dedicated llm agent for each individual user
class UserAgent:
Expand All @@ -49,36 +50,6 @@ def __init__(self, client, agent) -> None:
user_agents: Dict[str, UserAgent] = {}


def get_id_token(url: str) -> str:
"""Helper method to generate ID tokens for authenticated requests"""
# Use Application Default Credentials on Cloud Run
if os.getenv("K_SERVICE"):
auth_req = google.auth.transport.requests.Request()
return google.oauth2.id_token.fetch_id_token(auth_req, url)
else:
# Use gcloud credentials locally
import subprocess

return (
subprocess.run(
["gcloud", "auth", "print-identity-token"],
stdout=subprocess.PIPE,
check=True,
)
.stdout.strip()
.decode()
)


def get_header() -> Optional[dict]:
if "http://" in BASE_URL:
return None
else:
# Append ID Token to make authenticated requests to Cloud Run services
headers = {"Authorization": f"Bearer {get_id_token(BASE_URL)}"}
return headers


async def get_connector():
global connector
if connector is None:
Expand All @@ -91,24 +62,29 @@ async def handle_error_response(response):
return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}"


async def create_client_session() -> aiohttp.ClientSession:
async def create_client_session(user_id_token: Optional[str]) -> aiohttp.ClientSession:
headers = {}
if user_id_token is not None:
# user-specific query authentication
headers["User-Id-Token"] = user_id_token

return aiohttp.ClientSession(
connector=await get_connector(),
connector_owner=False, # Prevents connector being closed when closing session
headers=get_header(),
connector_owner=False,
headers=headers,
raise_for_status=handle_error_response,
)


# Agent
async def init_agent() -> UserAgent:
async def init_agent(user_id_token: Optional[Any]) -> UserAgent:
"""Load an agent executor with tools and LLM"""
print("Initializing agent..")
llm = VertexAI(max_output_tokens=512, model_name="gemini-pro")
memory = ConversationBufferMemory(
memory_key="chat_history", input_key="input", output_key="output"
)
client = await create_client_session()
client = await create_client_session(user_id_token)
tools = await initialize_tools(client)
agent = initialize_agent(
tools,
Expand Down
39 changes: 31 additions & 8 deletions langchain_tools_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import uvicorn
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from markdown import markdown
Expand All @@ -35,7 +35,7 @@ async def lifespan(app: FastAPI):
yield
# FastAPI app shutdown event
close_client_tasks = [
asyncio.create_task(c.client.close()) for c in user_agents.values()
asyncio.create_task(a.client.close()) for a in user_agents.values()
]

asyncio.gather(*close_client_tasks)
Expand All @@ -50,38 +50,62 @@ async def lifespan(app: FastAPI):
BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}]


@app.get("/", response_class=HTMLResponse)
@app.route("/", methods=["GET", "POST"])
async def index(request: Request):
"""Render the default template."""
request.session["client_id"] = os.getenv("CLIENT_ID")
if "uuid" not in request.session:
request.session["uuid"] = str(uuid.uuid4())
request.session["messages"] = BASE_HISTORY
# Agent setup
if request.session["uuid"] in user_agents:
user_agent = user_agents[request.session["uuid"]]
else:
user_agent = await init_agent()
user_agent = await init_agent(user_id_token=None)
user_agents[request.session["uuid"]] = user_agent
return templates.TemplateResponse(
"index.html", {"request": request, "messages": request.session["messages"]}
"index.html",
{
"request": request,
"messages": request.session["messages"],
"client_id": request.session["client_id"],
},
)


@app.post("/login/google", response_class=RedirectResponse)
async def login_google(
request: Request,
):
form_data = await request.form()
user_id_token = form_data.get("credential")
if user_id_token is None:
raise HTTPException(status_code=401, detail="No user credentials found")
# create new request session
request.session["uuid"] = str(uuid.uuid4())
request.session["messages"] = BASE_HISTORY
user_agent = await init_agent(user_id_token)
user_agents[request.session["uuid"]] = user_agent
print("Logged in to Google.")

# Redirect to source URL
source_url = request.headers["Referer"]
return RedirectResponse(url=source_url)


@app.post("/chat", response_class=PlainTextResponse)
async def chat_handler(request: Request, prompt: str = Body(embed=True)):
"""Handler for LangChain chat requests"""
# Retrieve user prompt
if not prompt:
raise HTTPException(status_code=400, detail="Error: No user query")

if "uuid" not in request.session:
raise HTTPException(
status_code=400, detail="Error: Invoke index handler before start chatting"
)

# Add user message to chat history
request.session["messages"] += [{"role": "user", "content": prompt}]

user_agent = user_agents[request.session["uuid"]]
try:
# Send prompt to LLM
Expand All @@ -92,7 +116,6 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
# Return assistant response
return markdown(response["output"])
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")


Expand Down
31 changes: 16 additions & 15 deletions langchain_tools_demo/static/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,23 @@ body {
text-align: center;
}

.chat-header span.reset-button {
#g_id_onload,
.g_id_signin {
position: absolute;
margin-right: 0px;
font-size: 38px;
top: 20px;
right: 30px;
}

#resetButton {
font-size: 28px;
cursor: pointer;
position: absolute;
top: 27px;
right: 0px;
}

#resetButton:hover {
background-color: #c9d4e9;
}

.chat-wrapper {
Expand Down Expand Up @@ -84,18 +97,6 @@ div.chat-content>span {
border: none;
}

#resetButton {
font-size: 35px;
cursor: pointer;
position: absolute;
top: 47px;
right: 40px;
}

#resetButton:hover {
background-color: #c9d4e9;
}

.chat-bubble {
display: block;
padding: 50px;
Expand Down
23 changes: 21 additions & 2 deletions langchain_tools_demo/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/material.indigo-pink.min.css">
<script defer src="https://code.getmdl.io/1.3.0/material.min.js"></script>
<link rel="stylesheet" href="/static/index.css">

<script src="https://accounts.google.com/gsi/client" async defer></script>
</head>

<body>
Expand All @@ -39,7 +39,20 @@
<h1>SFO Airport Assistant</h1>
<span class="material-symbols-outlined" id="resetButton">refresh</span>
</div>

<div id="g_id_onload"
data-context="signin"
data-ux_mode="popup"
data-auto_prompt="false">
</div>
<div class="g_id_signin"
data-type="standard"
data-shape="rectangular"
data-theme="outline"
data-text="signin_with"
data-size="large"
data-logo_alignment="left"
data-onsuccess="onSignIn">
</div>
<div class="chat-wrapper">
<div class="chat-content">
{# Add Chat history #}
Expand Down Expand Up @@ -68,5 +81,11 @@ <h1>SFO Airport Assistant</h1>
integrity="sha256-kmHvs0B+OpCW5GVHUNjv9rOmY0IvSIRcf7zGUDTDQM8="
crossorigin="anonymous"></script>
<script src="/static/index.js"></script>
<script>
document.getElementById('g_id_onload').setAttribute('data-client_id', '{{ client_id }}');
var currentUrl = window.location.href;
var loginUri = currentUrl + 'login/google';
document.getElementById('g_id_onload').setAttribute('data-login_uri', loginUri);
</script>

</html>
Loading

0 comments on commit c0a34a7

Please sign in to comment.