diff --git a/DEVELOPER.md b/DEVELOPER.md index 01e2f3ae..0b0cb07f 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -76,7 +76,7 @@ python main.py ``` - Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --reload` + Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --reload --port 8081` 1. View app at `http://localhost:8081/` diff --git a/langchain_tools_demo/agent.py b/langchain_tools_demo/agent.py index ca2a77ea..d0bdf5e4 100644 --- a/langchain_tools_demo/agent.py +++ b/langchain_tools_demo/agent.py @@ -13,11 +13,10 @@ # limitations under the License. import os -from datetime import date, timedelta -from typing import Dict, Optional +from datetime import date +from typing import Any, Dict, Optional import aiohttp -import dateutil.parser as dparser import google.auth.transport.requests # type: ignore import google.oauth2.id_token # type: ignore from langchain.agents import AgentType, initialize_agent @@ -35,6 +34,8 @@ # aiohttp context connector = None +CLOUD_RUN_AUTHORIZATION_TOKEN = None + # Class for setting up a dedicated llm agent for each individual user class UserAgent: @@ -49,36 +50,6 @@ def __init__(self, client, agent) -> None: user_agents: Dict[str, UserAgent] = {} -def get_id_token(url: str) -> str: - """Helper method to generate ID tokens for authenticated requests""" - # Use Application Default Credentials on Cloud Run - if os.getenv("K_SERVICE"): - auth_req = google.auth.transport.requests.Request() - return google.oauth2.id_token.fetch_id_token(auth_req, url) - else: - # Use gcloud credentials locally - import subprocess - - return ( - subprocess.run( - ["gcloud", "auth", "print-identity-token"], - stdout=subprocess.PIPE, - check=True, - ) - .stdout.strip() - .decode() - ) - - -def get_header() -> Optional[dict]: - if "http://" in BASE_URL: - return None - else: - # Append ID Token to make authenticated requests to Cloud Run services - headers = {"Authorization": f"Bearer {get_id_token(BASE_URL)}"} - return headers - - async def get_connector(): global connector if connector is None: @@ -91,24 +62,29 @@ async def handle_error_response(response): return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}" -async def create_client_session() -> aiohttp.ClientSession: +async def create_client_session(user_id_token: Optional[str]) -> aiohttp.ClientSession: + headers = {} + if user_id_token is not None: + # user-specific query authentication + headers["User-Id-Token"] = user_id_token + return aiohttp.ClientSession( connector=await get_connector(), - connector_owner=False, # Prevents connector being closed when closing session - headers=get_header(), + connector_owner=False, + headers=headers, raise_for_status=handle_error_response, ) # Agent -async def init_agent() -> UserAgent: +async def init_agent(user_id_token: Optional[Any]) -> UserAgent: """Load an agent executor with tools and LLM""" print("Initializing agent..") llm = VertexAI(max_output_tokens=512, model_name="gemini-pro") memory = ConversationBufferMemory( memory_key="chat_history", input_key="input", output_key="output" ) - client = await create_client_session() + client = await create_client_session(user_id_token) tools = await initialize_tools(client) agent = initialize_agent( tools, diff --git a/langchain_tools_demo/main.py b/langchain_tools_demo/main.py index 211ba803..daffcd90 100644 --- a/langchain_tools_demo/main.py +++ b/langchain_tools_demo/main.py @@ -19,7 +19,7 @@ import uvicorn from fastapi import Body, FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse, PlainTextResponse +from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from markdown import markdown @@ -35,7 +35,7 @@ async def lifespan(app: FastAPI): yield # FastAPI app shutdown event close_client_tasks = [ - asyncio.create_task(c.client.close()) for c in user_agents.values() + asyncio.create_task(a.client.close()) for a in user_agents.values() ] asyncio.gather(*close_client_tasks) @@ -50,9 +50,10 @@ async def lifespan(app: FastAPI): BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}] -@app.get("/", response_class=HTMLResponse) +@app.route("/", methods=["GET", "POST"]) async def index(request: Request): """Render the default template.""" + request.session["client_id"] = os.getenv("CLIENT_ID") if "uuid" not in request.session: request.session["uuid"] = str(uuid.uuid4()) request.session["messages"] = BASE_HISTORY @@ -60,20 +61,44 @@ async def index(request: Request): if request.session["uuid"] in user_agents: user_agent = user_agents[request.session["uuid"]] else: - user_agent = await init_agent() + user_agent = await init_agent(user_id_token=None) user_agents[request.session["uuid"]] = user_agent return templates.TemplateResponse( - "index.html", {"request": request, "messages": request.session["messages"]} + "index.html", + { + "request": request, + "messages": request.session["messages"], + "client_id": request.session["client_id"], + }, ) +@app.post("/login/google", response_class=RedirectResponse) +async def login_google( + request: Request, +): + form_data = await request.form() + user_id_token = form_data.get("credential") + if user_id_token is None: + raise HTTPException(status_code=401, detail="No user credentials found") + # create new request session + request.session["uuid"] = str(uuid.uuid4()) + request.session["messages"] = BASE_HISTORY + user_agent = await init_agent(user_id_token) + user_agents[request.session["uuid"]] = user_agent + print("Logged in to Google.") + + # Redirect to source URL + source_url = request.headers["Referer"] + return RedirectResponse(url=source_url) + + @app.post("/chat", response_class=PlainTextResponse) async def chat_handler(request: Request, prompt: str = Body(embed=True)): """Handler for LangChain chat requests""" # Retrieve user prompt if not prompt: raise HTTPException(status_code=400, detail="Error: No user query") - if "uuid" not in request.session: raise HTTPException( status_code=400, detail="Error: Invoke index handler before start chatting" @@ -81,7 +106,6 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Add user message to chat history request.session["messages"] += [{"role": "user", "content": prompt}] - user_agent = user_agents[request.session["uuid"]] try: # Send prompt to LLM @@ -92,7 +116,6 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Return assistant response return markdown(response["output"]) except Exception as err: - print(err) raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") diff --git a/langchain_tools_demo/static/index.css b/langchain_tools_demo/static/index.css index 4e945144..29f5234f 100644 --- a/langchain_tools_demo/static/index.css +++ b/langchain_tools_demo/static/index.css @@ -41,10 +41,23 @@ body { text-align: center; } -.chat-header span.reset-button { +#g_id_onload, +.g_id_signin { position: absolute; - margin-right: 0px; - font-size: 38px; + top: 20px; + right: 30px; +} + +#resetButton { + font-size: 28px; + cursor: pointer; + position: absolute; + top: 27px; + right: 0px; +} + +#resetButton:hover { + background-color: #c9d4e9; } .chat-wrapper { @@ -84,18 +97,6 @@ div.chat-content>span { border: none; } -#resetButton { - font-size: 35px; - cursor: pointer; - position: absolute; - top: 47px; - right: 40px; -} - -#resetButton:hover { - background-color: #c9d4e9; -} - .chat-bubble { display: block; padding: 50px; diff --git a/langchain_tools_demo/templates/index.html b/langchain_tools_demo/templates/index.html index 7e77699b..f9ca6a25 100644 --- a/langchain_tools_demo/templates/index.html +++ b/langchain_tools_demo/templates/index.html @@ -29,7 +29,7 @@ - + @@ -39,7 +39,20 @@

SFO Airport Assistant

refresh - +
+
+
+
{# Add Chat history #} @@ -68,5 +81,11 @@

SFO Airport Assistant

integrity="sha256-kmHvs0B+OpCW5GVHUNjv9rOmY0IvSIRcf7zGUDTDQM8=" crossorigin="anonymous"> + diff --git a/langchain_tools_demo/tools.py b/langchain_tools_demo/tools.py index 02b32327..e22531fb 100644 --- a/langchain_tools_demo/tools.py +++ b/langchain_tools_demo/tools.py @@ -16,16 +16,37 @@ from typing import Optional import aiohttp -from langchain.tools import StructuredTool, tool +import google.oauth2.id_token # type: ignore +from google.auth.transport.requests import Request # type: ignore +from langchain.tools import StructuredTool from pydantic.v1 import BaseModel, Field BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080") +CREDENTIALS = None def filter_none_values(params: dict) -> dict: return {key: value for key, value in params.items() if value is not None} +def get_id_token(): + global CREDENTIALS + if CREDENTIALS is None: + CREDENTIALS, _ = google.auth.default() + if not CREDENTIALS.valid: + CREDENTIALS.refresh(Request()) + return CREDENTIALS.id_token + + +def get_headers(client: aiohttp.ClientSession): + """Helper method to generate ID tokens for authenticated requests""" + headers = client.headers + if not "http://" in BASE_URL: + # Append ID Token to make authenticated requests to Cloud Run services + headers["Authorization"] = f"Bearer {get_id_token()}" + return headers + + # Tools class AirportSearchInput(BaseModel): country: Optional[str] = Field(description="Country") @@ -33,7 +54,7 @@ class AirportSearchInput(BaseModel): name: Optional[str] = Field(description="Airport name") -async def generate_search_airports(client: aiohttp.ClientSession): +def generate_search_airports(client: aiohttp.ClientSession): async def search_airports(country: str, city: str, name: str): params = { "country": country, @@ -41,7 +62,9 @@ async def search_airports(country: str, city: str, name: str): "name": name, } response = await client.get( - url=f"{BASE_URL}/airports/search", params=filter_none_values(params) + url=f"{BASE_URL}/airports/search", + params=filter_none_values(params), + headers=get_headers(client), ) num = 2 @@ -64,11 +87,12 @@ class FlightNumberInput(BaseModel): flight_number: str = Field(description="1 to 4 digit number") -async def generate_search_flights_by_number(client: aiohttp.ClientSession): +def generate_search_flights_by_number(client: aiohttp.ClientSession): async def search_flights_by_number(airline: str, flight_number: str): response = await client.get( url=f"{BASE_URL}/flights/search", params={"airline": airline, "flight_number": flight_number}, + headers=get_headers(client), ) return await response.json() @@ -84,7 +108,7 @@ class ListFlights(BaseModel): date: Optional[str] = Field(description="Date of flight departure") -async def generate_list_flights(client: aiohttp.ClientSession): +def generate_list_flights(client: aiohttp.ClientSession): async def list_flights( departure_airport: str, arrival_airport: str, @@ -98,6 +122,7 @@ async def list_flights( response = await client.get( url=f"{BASE_URL}/flights/search", params=filter_none_values(params), + headers=get_headers(client), ) num = 2 @@ -119,7 +144,7 @@ class QueryInput(BaseModel): query: str = Field(description="Search query") -async def generate_search_amenities(client: aiohttp.ClientSession): +def generate_search_amenities(client: aiohttp.ClientSession): async def search_amenities(query: str): """ Use this tool to search amenities by name or to recommended airport amenities at SFO. @@ -133,6 +158,7 @@ async def search_amenities(query: str): response = await client.get( url=f"{BASE_URL}/amenities/search", params={"top_k": "5", "query": query}, + headers=get_headers(client), ) response = await response.json() @@ -145,7 +171,7 @@ async def search_amenities(query: str): async def initialize_tools(client: aiohttp.ClientSession): return [ StructuredTool.from_function( - coroutine=await generate_search_airports(client), + coroutine=generate_search_airports(client), name="Search Airport", description=""" Use this tool to list all airports matching search criteria. @@ -174,7 +200,7 @@ async def initialize_tools(client: aiohttp.ClientSession): args_schema=AirportSearchInput, ), StructuredTool.from_function( - coroutine=await generate_search_flights_by_number(client), + coroutine=generate_search_flights_by_number(client), name="Search Flights By Flight Number", description=""" Use this tool to get info for a specific flight. Do NOT use this tool with a flight id. @@ -187,7 +213,7 @@ async def initialize_tools(client: aiohttp.ClientSession): args_schema=FlightNumberInput, ), StructuredTool.from_function( - coroutine=await generate_list_flights(client), + coroutine=generate_list_flights(client), name="List Flights", description=""" Use this tool to list all flights matching search criteria. @@ -216,7 +242,7 @@ async def initialize_tools(client: aiohttp.ClientSession): args_schema=ListFlights, ), StructuredTool.from_function( - coroutine=await generate_search_amenities(client), + coroutine=generate_search_amenities(client), name="Search Amenities", description=""" Use this tool to search amenities by name or to recommended airport amenities at SFO.