Skip to content

Commit

Permalink
Merge branch 'main' into renovate/pydantic-2.x
Browse files Browse the repository at this point in the history
  • Loading branch information
averikitsch authored Jan 29, 2024
2 parents 2a3e309 + 2584154 commit 9a0ce93
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 153 deletions.
2 changes: 1 addition & 1 deletion DEVELOPER.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
python main.py
```

Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --reload`
Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --reload --port 8081`

1. View app at `http://localhost:8081/`

Expand Down
64 changes: 23 additions & 41 deletions langchain_tools_demo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@
# limitations under the License.

import os
from datetime import date, timedelta
from typing import Dict, Optional
from datetime import date
from typing import Any, Dict, Optional

import aiohttp
import dateutil.parser as dparser
import google.auth.transport.requests # type: ignore
import google.oauth2.id_token # type: ignore
from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent import AgentExecutor
from langchain.globals import set_verbose # type: ignore
from langchain.llms.vertexai import VertexAI
from langchain.memory import ConversationBufferMemory
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts.chat import ChatPromptTemplate
from langchain_core import messages

from tools import initialize_tools

Expand All @@ -35,6 +35,8 @@
# aiohttp context
connector = None

CLOUD_RUN_AUTHORIZATION_TOKEN = None


# Class for setting up a dedicated llm agent for each individual user
class UserAgent:
Expand All @@ -49,36 +51,6 @@ def __init__(self, client, agent) -> None:
user_agents: Dict[str, UserAgent] = {}


def get_id_token(url: str) -> str:
    """Return a Google ID token for making authenticated requests to *url*."""
    if os.getenv("K_SERVICE"):
        # On Cloud Run (K_SERVICE is set): use Application Default Credentials.
        auth_request = google.auth.transport.requests.Request()
        return google.oauth2.id_token.fetch_id_token(auth_request, url)

    # Local development: shell out to gcloud for an identity token.
    import subprocess

    completed = subprocess.run(
        ["gcloud", "auth", "print-identity-token"],
        stdout=subprocess.PIPE,
        check=True,
    )
    return completed.stdout.strip().decode()


def get_header() -> Optional[dict]:
    """Build request headers for the backend service.

    Returns None for plain-HTTP (local) backends; otherwise attaches an ID
    token so requests to the Cloud Run service are authenticated.
    """
    if "http://" in BASE_URL:
        return None
    token = get_id_token(BASE_URL)
    return {"Authorization": f"Bearer {token}"}


async def get_connector():
global connector
if connector is None:
Expand All @@ -91,24 +63,34 @@ async def handle_error_response(response):
return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}"


async def create_client_session() -> aiohttp.ClientSession:
async def create_client_session(user_id_token: Optional[str]) -> aiohttp.ClientSession:
headers = {}
if user_id_token is not None:
# user-specific query authentication
headers["User-Id-Token"] = user_id_token

return aiohttp.ClientSession(
connector=await get_connector(),
connector_owner=False, # Prevents connector being closed when closing session
headers=get_header(),
raise_for_status=handle_error_response,
connector_owner=False,
headers=headers,
raise_for_status=True,
)


# Agent
async def init_agent() -> UserAgent:
async def init_agent(
user_id_token: Optional[Any], history: list[messages.BaseMessage]
) -> UserAgent:
"""Load an agent executor with tools and LLM"""
print("Initializing agent..")
llm = VertexAI(max_output_tokens=512, model_name="gemini-pro")
memory = ConversationBufferMemory(
memory_key="chat_history", input_key="input", output_key="output"
chat_memory=ChatMessageHistory(messages=history),
memory_key="chat_history",
input_key="input",
output_key="output",
)
client = await create_client_session()
client = await create_client_session(user_id_token)
tools = await initialize_tools(client)
agent = initialize_agent(
tools,
Expand Down
87 changes: 64 additions & 23 deletions langchain_tools_demo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,21 @@
import os
import uuid
from contextlib import asynccontextmanager
from typing import Any

import uvicorn
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.responses import PlainTextResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
message_to_dict,
messages_from_dict,
messages_to_dict,
)
from markdown import markdown
from starlette.middleware.sessions import SessionMiddleware

Expand All @@ -35,7 +44,7 @@ async def lifespan(app: FastAPI):
yield
# FastAPI app shutdown event
close_client_tasks = [
asyncio.create_task(c.client.close()) for c in user_agents.values()
asyncio.create_task(a.client.close()) for a in user_agents.values()
]

asyncio.gather(*close_client_tasks)
Expand All @@ -47,61 +56,93 @@ async def lifespan(app: FastAPI):
# TODO: set secret_key for production
app.add_middleware(SessionMiddleware, secret_key="SECRET_KEY")
templates = Jinja2Templates(directory="templates")
BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}]
BASE_HISTORY: list[BaseMessage] = [
AIMessage(content="I am an SFO Airport Assistant, ready to assist you.")
]


@app.get("/", response_class=HTMLResponse)
@app.route("/", methods=["GET", "POST"])
async def index(request: Request):
"""Render the default template."""
if "uuid" not in request.session:
request.session["uuid"] = str(uuid.uuid4())
request.session["messages"] = BASE_HISTORY
# Agent setup
if request.session["uuid"] in user_agents:
user_agent = user_agents[request.session["uuid"]]
else:
user_agent = await init_agent()
user_agents[request.session["uuid"]] = user_agent
agent = await get_agent(request.session)
print(request.session["history"])
return templates.TemplateResponse(
"index.html", {"request": request, "messages": request.session["messages"]}
"index.html",
{
"request": request,
"messages": request.session["history"],
"client_id": request.session.get("client_id"),
},
)


@app.post("/login/google", response_class=RedirectResponse)
async def login_google(
    request: Request,
):
    """Handle the Google Sign-In form post and redirect back to the caller.

    Raises:
        HTTPException: 401 if the sign-in form carried no ``credential`` field.
    """
    form_data = await request.form()
    user_id_token = form_data.get("credential")
    if user_id_token is None:
        raise HTTPException(status_code=401, detail="No user credentials found")
    # create new request session
    # NOTE(review): the validated user_id_token is not passed to the agent
    # here — get_agent() derives its own id from the session. Confirm this
    # is intentional.
    _ = await get_agent(request.session)
    print("Logged in to Google.")

    # Redirect to the page the sign-in came from. Fall back to the app root
    # when no Referer header was sent — request.headers["Referer"] would
    # raise KeyError and surface as an HTTP 500.
    source_url = request.headers.get("Referer", "/")
    return RedirectResponse(url=source_url)


@app.post("/chat", response_class=PlainTextResponse)
async def chat_handler(request: Request, prompt: str = Body(embed=True)):
"""Handler for LangChain chat requests"""
# Retrieve user prompt
if not prompt:
raise HTTPException(status_code=400, detail="Error: No user query")

if "uuid" not in request.session:
raise HTTPException(
status_code=400, detail="Error: Invoke index handler before start chatting"
)

# Add user message to chat history
request.session["messages"] += [{"role": "user", "content": prompt}]

user_agent = user_agents[request.session["uuid"]]
request.session["history"].append(message_to_dict(HumanMessage(content=prompt)))
user_agent = await get_agent(request.session)
try:
# Send prompt to LLM
response = await user_agent.agent.ainvoke({"input": prompt})
request.session["messages"] += [
{"role": "assistant", "content": response["output"]}
]
# Return assistant response
request.session["history"].append(
message_to_dict(AIMessage(content=response["output"]))
)
return markdown(response["output"])
except Exception as err:
print(err)
raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")


async def get_agent(session: dict[str, Any]):
    """Return the cached UserAgent for this session, creating one if needed.

    Ensures the session has a ``uuid`` and a serialized ``history``; agents
    are cached per session id in the module-level ``user_agents`` dict.
    """
    global user_agents
    if "uuid" not in session:
        session["uuid"] = str(uuid.uuid4())
    # Renamed from `id`, which shadowed the builtin.
    session_id = session["uuid"]
    if "history" not in session:
        session["history"] = messages_to_dict(BASE_HISTORY)
    # BUG FIX: the original tested `uuid not in user_agents`, comparing the
    # `uuid` *module* against the dict keys — always true — so a new agent
    # (and its aiohttp client session) was created on every single request.
    if session_id not in user_agents:
        # NOTE(review): the session id is passed as init_agent's
        # user_id_token argument — confirm that is the intended credential.
        user_agents[session_id] = await init_agent(
            session_id, messages_from_dict(session["history"])
        )
    return user_agents[session_id]


@app.post("/reset")
async def reset(request: Request):
"""Reset agent"""
global user_agents
uuid = request.session["uuid"]

if "uuid" not in request.session:
raise HTTPException(status_code=400, detail=f"No session to reset.")

uuid = request.session["uuid"]
global user_agents
if uuid not in user_agents.keys():
raise HTTPException(status_code=500, detail=f"Current agent not found")

Expand Down
4 changes: 2 additions & 2 deletions langchain_tools_demo/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
black==23.12.0
black==24.1.0
pytest==7.4.4
mypy==1.7.1
isort==5.13.2
types-requests==2.31.0.20231231
types-requests==2.31.0.20240106
types-python-dateutil==2.8.19.20240106
9 changes: 5 additions & 4 deletions langchain_tools_demo/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
fastapi==0.108.0
fastapi==0.109.0
google-cloud-aiplatform==1.38.1
google-auth==2.26.1
google-auth==2.26.2
itsdangerous==2.1.2
jinja2==3.1.3
langchain==0.0.354
markdown==3.5.1
types-Markdown==3.5.0.3
markdown==3.5.2
types-Markdown==3.5.0.20240106
uvicorn[standard]==0.25.0
python-multipart==0.0.6
35 changes: 18 additions & 17 deletions langchain_tools_demo/static/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,23 @@ body {
text-align: center;
}

.chat-header span.reset-button {
#g_id_onload,
.g_id_signin {
position: absolute;
margin-right: 0px;
font-size: 38px;
top: 20px;
right: 30px;
}

#resetButton {
font-size: 28px;
cursor: pointer;
position: absolute;
top: 27px;
right: 0px;
}

#resetButton:hover {
background-color: #c9d4e9;
}

.chat-wrapper {
Expand Down Expand Up @@ -84,18 +97,6 @@ div.chat-content>span {
border: none;
}

#resetButton {
font-size: 35px;
cursor: pointer;
position: absolute;
top: 47px;
right: 40px;
}

#resetButton:hover {
background-color: #c9d4e9;
}

.chat-bubble {
display: block;
padding: 50px;
Expand All @@ -107,7 +108,7 @@ div.chat-content>span {
padding: 0;
}

div.chat-wrapper div.chat-content span.assistant {
div.chat-wrapper div.chat-content span.ai {
position: relative;
width: 70%;
height: auto;
Expand All @@ -118,7 +119,7 @@ div.chat-wrapper div.chat-content span.assistant {
border-radius: 2px 15px 15px 15px;
}

div.chat-wrapper div.chat-content span.user {
div.chat-wrapper div.chat-content span.human {
position: relative;
float: right;
width: 70%;
Expand Down
8 changes: 4 additions & 4 deletions langchain_tools_demo/static/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ $('#resetButton').click(async (e) => {
async function submitMessage() {
let msg = $('.chat-bar input').val();
// Add message to UI
log("user", msg)
log("human", msg)
// Clear message
$('.chat-bar input').val('');
$('.mdl-progress').show()
Expand All @@ -43,15 +43,15 @@ async function submitMessage() {
let answer = await askQuestion(msg);
$('.mdl-progress').hide();
// Add response to UI
log("assistant", answer)
log("ai", answer)
} catch (err) {
window.alert(`Error when submitting question: ${err}`);
}
}

// Send request to backend
async function askQuestion(prompt) {
const response = await fetch('/chat', {
const response = await fetch('chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
Expand All @@ -68,7 +68,7 @@ async function askQuestion(prompt) {
}

async function reset() {
await fetch('/reset', {
await fetch('reset', {
method: 'POST',
}).then(()=>{
window.location.reload()
Expand Down
Loading

0 comments on commit 9a0ce93

Please sign in to comment.