diff --git a/DEVELOPER.md b/DEVELOPER.md index 01e2f3ae..0b0cb07f 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -76,7 +76,7 @@ python main.py ``` - Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --reload` + Note: for hot reloading of the app use: `uvicorn main:app --host 0.0.0.0 --reload --port 8081` 1. View app at `http://localhost:8081/` diff --git a/langchain_tools_demo/agent.py b/langchain_tools_demo/agent.py index ca2a77ea..d0bdf5e4 100644 --- a/langchain_tools_demo/agent.py +++ b/langchain_tools_demo/agent.py @@ -13,11 +13,10 @@ # limitations under the License. import os -from datetime import date, timedelta -from typing import Dict, Optional +from datetime import date +from typing import Any, Dict, Optional import aiohttp -import dateutil.parser as dparser import google.auth.transport.requests # type: ignore import google.oauth2.id_token # type: ignore from langchain.agents import AgentType, initialize_agent @@ -35,6 +34,8 @@ # aiohttp context connector = None +CLOUD_RUN_AUTHORIZATION_TOKEN = None + # Class for setting up a dedicated llm agent for each individual user class UserAgent: @@ -49,36 +50,6 @@ def __init__(self, client, agent) -> None: user_agents: Dict[str, UserAgent] = {} -def get_id_token(url: str) -> str: - """Helper method to generate ID tokens for authenticated requests""" - # Use Application Default Credentials on Cloud Run - if os.getenv("K_SERVICE"): - auth_req = google.auth.transport.requests.Request() - return google.oauth2.id_token.fetch_id_token(auth_req, url) - else: - # Use gcloud credentials locally - import subprocess - - return ( - subprocess.run( - ["gcloud", "auth", "print-identity-token"], - stdout=subprocess.PIPE, - check=True, - ) - .stdout.strip() - .decode() - ) - - -def get_header() -> Optional[dict]: - if "http://" in BASE_URL: - return None - else: - # Append ID Token to make authenticated requests to Cloud Run services - headers = {"Authorization": f"Bearer {get_id_token(BASE_URL)}"} - return headers - - async def get_connector(): global connector if connector is None: @@ -91,24 +62,29 @@ async def handle_error_response(response): return f"Error sending {response.method} request to {str(response.url)}): {await response.text()}" -async def create_client_session() -> aiohttp.ClientSession: +async def create_client_session(user_id_token: Optional[str]) -> aiohttp.ClientSession: + headers = {} + if user_id_token is not None: + # user-specific query authentication + headers["User-Id-Token"] = user_id_token + return aiohttp.ClientSession( connector=await get_connector(), - connector_owner=False, # Prevents connector being closed when closing session - headers=get_header(), + connector_owner=False, + headers=headers, raise_for_status=handle_error_response, ) # Agent -async def init_agent() -> UserAgent: +async def init_agent(user_id_token: Optional[Any]) -> UserAgent: """Load an agent executor with tools and LLM""" print("Initializing agent..") llm = VertexAI(max_output_tokens=512, model_name="gemini-pro") memory = ConversationBufferMemory( memory_key="chat_history", input_key="input", output_key="output" ) - client = await create_client_session() + client = await create_client_session(user_id_token) tools = await initialize_tools(client) agent = initialize_agent( tools, diff --git a/langchain_tools_demo/main.py b/langchain_tools_demo/main.py index 211ba803..daffcd90 100644 --- a/langchain_tools_demo/main.py +++ b/langchain_tools_demo/main.py @@ -19,7 +19,7 @@ import uvicorn from fastapi import Body, FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse, PlainTextResponse +from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from markdown import markdown @@ -35,7 +35,7 @@ async def lifespan(app: FastAPI): yield # FastAPI app shutdown event close_client_tasks = [ - asyncio.create_task(c.client.close()) for c in user_agents.values() + asyncio.create_task(a.client.close()) for a in user_agents.values() ] asyncio.gather(*close_client_tasks) @@ -50,9 +50,10 @@ async def lifespan(app: FastAPI): BASE_HISTORY = [{"role": "assistant", "content": "How can I help you?"}] -@app.get("/", response_class=HTMLResponse) +@app.route("/", methods=["GET", "POST"]) async def index(request: Request): """Render the default template.""" + request.session["client_id"] = os.getenv("CLIENT_ID") if "uuid" not in request.session: request.session["uuid"] = str(uuid.uuid4()) request.session["messages"] = BASE_HISTORY @@ -60,20 +61,44 @@ async def index(request: Request): if request.session["uuid"] in user_agents: user_agent = user_agents[request.session["uuid"]] else: - user_agent = await init_agent() + user_agent = await init_agent(user_id_token=None) user_agents[request.session["uuid"]] = user_agent return templates.TemplateResponse( - "index.html", {"request": request, "messages": request.session["messages"]} + "index.html", + { + "request": request, + "messages": request.session["messages"], + "client_id": request.session["client_id"], + }, ) +@app.post("/login/google", response_class=RedirectResponse) +async def login_google( + request: Request, +): + form_data = await request.form() + user_id_token = form_data.get("credential") + if user_id_token is None: + raise HTTPException(status_code=401, detail="No user credentials found") + # create new request session + request.session["uuid"] = str(uuid.uuid4()) + request.session["messages"] = BASE_HISTORY + user_agent = await init_agent(user_id_token) + user_agents[request.session["uuid"]] = user_agent + print("Logged in to Google.") + + # Redirect to source URL + source_url = request.headers["Referer"] + return RedirectResponse(url=source_url) + + @app.post("/chat", response_class=PlainTextResponse) async def chat_handler(request: Request, prompt: str = Body(embed=True)): """Handler for LangChain chat requests""" # Retrieve user prompt if not prompt: raise HTTPException(status_code=400, detail="Error: No user query") - if "uuid" not in request.session: raise HTTPException( status_code=400, detail="Error: Invoke index handler before start chatting" @@ -81,7 +106,6 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Add user message to chat history request.session["messages"] += [{"role": "user", "content": prompt}] - user_agent = user_agents[request.session["uuid"]] try: # Send prompt to LLM @@ -92,7 +116,6 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)): # Return assistant response return markdown(response["output"]) except Exception as err: - print(err) raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}") diff --git a/langchain_tools_demo/static/index.css b/langchain_tools_demo/static/index.css index 4e945144..29f5234f 100644 --- a/langchain_tools_demo/static/index.css +++ b/langchain_tools_demo/static/index.css @@ -41,10 +41,23 @@ body { text-align: center; } -.chat-header span.reset-button { +#g_id_onload, +.g_id_signin { position: absolute; - margin-right: 0px; - font-size: 38px; + top: 20px; + right: 30px; +} + +#resetButton { + font-size: 28px; + cursor: pointer; + position: absolute; + top: 27px; + right: 0px; +} + +#resetButton:hover { + background-color: #c9d4e9; } .chat-wrapper { @@ -84,18 +97,6 @@ div.chat-content>span { border: none; } -#resetButton { - font-size: 35px; - cursor: pointer; - position: absolute; - top: 47px; - right: 40px; -} - -#resetButton:hover { - background-color: #c9d4e9; -} - .chat-bubble { display: block; padding: 50px; diff --git a/langchain_tools_demo/templates/index.html b/langchain_tools_demo/templates/index.html index 7e77699b..f9ca6a25 100644 --- a/langchain_tools_demo/templates/index.html +++ b/langchain_tools_demo/templates/index.html @@ -29,7 +29,7 @@ - +
@@ -39,7 +39,20 @@