Skip to content

Commit

Permalink
feat: Update prompt and tools (#34)
Browse files Browse the repository at this point in the history
Co-authored-by: Yuan <[email protected]>
  • Loading branch information
averikitsch and Yuan325 authored Nov 6, 2023
1 parent 0dee8dd commit 499f6a7
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 150 deletions.
19 changes: 16 additions & 3 deletions extension_service/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from typing import Optional

from fastapi import APIRouter, Request
from fastapi import APIRouter, HTTPException, Request
from langchain.embeddings.base import Embeddings

import datastore
Expand Down Expand Up @@ -49,7 +49,7 @@ async def amenities_search(query: str, top_k: int, request: Request):
embed_service: Embeddings = request.app.state.embed_service
query_embedding = embed_service.embed_query(query)

results = await ds.amenities_search(query_embedding, 0.3, top_k)
results = await ds.amenities_search(query_embedding, 0.5, top_k)
return results


Expand All @@ -65,7 +65,20 @@ async def search_flights(
request: Request,
departure_airport: Optional[str] = None,
arrival_airport: Optional[str] = None,
date: Optional[str] = None,
airline: Optional[str] = None,
flight_number: Optional[str] = None,
):
ds: datastore.Client = request.app.state.datastore
flights = await ds.search_flights(departure_airport, arrival_airport)
if date and (arrival_airport or departure_airport):
flights = await ds.search_flights_by_airports(
date, departure_airport, arrival_airport
)
elif airline and flight_number:
flights = await ds.search_flights_by_number(airline, flight_number)
else:
raise HTTPException(
status_code=422,
detail="Request requires query params: arrival_airport, departure_airport, date, or both airline and flight_number",
)
return flights
11 changes: 10 additions & 1 deletion extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,17 @@ async def get_flight(self, flight_id: int) -> Optional[list[models.Flight]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def search_flights(
async def search_flights_by_number(
self,
airline: str,
flight_number: str,
) -> Optional[list[models.Flight]]:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def search_flights_by_airports(
self,
date,
departure_airport: Optional[str] = None,
arrival_airport: Optional[str] = None,
) -> Optional[list[models.Flight]]:
Expand Down
26 changes: 24 additions & 2 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
from datetime import datetime
from ipaddress import IPv4Address, IPv6Address
from typing import Any, Dict, Literal, Optional

Expand Down Expand Up @@ -251,8 +252,27 @@ async def get_flight(self, flight_id: int) -> Optional[list[models.Flight]]:
flights = [models.Flight.model_validate(dict(r)) for r in results]
return flights

async def search_flights(
async def search_flights_by_number(
self,
airline: str,
number: str,
) -> Optional[list[models.Flight]]:
results = await self.__pool.fetch(
"""
SELECT * FROM flights
WHERE airline = $1
AND flight_number = $2;
""",
airline,
number,
timeout=10,
)
flights = [models.Flight.model_validate(dict(r)) for r in results]
return flights

async def search_flights_by_airports(
self,
date: str,
departure_airport: Optional[str] = None,
arrival_airport: Optional[str] = None,
) -> Optional[list[models.Flight]]:
Expand All @@ -261,15 +281,17 @@ async def search_flights(
departure_airport = "%"
if arrival_airport is None:
arrival_airport = "%"

results = await self.__pool.fetch(
"""
SELECT * FROM flights
WHERE departure_airport LIKE $1
AND arrival_airport LIKE $2
AND departure_time > $3::timestamp - interval '1 day'
AND departure_time < $3::timestamp + interval '1 day';
""",
departure_airport,
arrival_airport,
datetime.strptime(date, "%Y-%m-%d"),
timeout=10,
)
flights = [models.Flight.model_validate(dict(r)) for r in results]
Expand Down
210 changes: 68 additions & 142 deletions langchain_tools_demo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,182 +13,108 @@
# limitations under the License.

import os
from typing import Optional

import google.auth.transport.requests # type: ignore
import google.oauth2.id_token # type: ignore
import requests
from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent import AgentExecutor
from langchain.globals import set_verbose # type: ignore
from langchain.llms.vertexai import VertexAI
from langchain.memory import ConversationBufferMemory
from langchain.tools import tool
from pydantic.v1 import BaseModel, Field
from langchain.prompts.chat import ChatPromptTemplate

DEBUG = bool(os.getenv("DEBUG", default=False))
BASE_URL = os.getenv("BASE_URL", default="http://127.0.0.1:8080")
from tools import convert_date, tools

set_verbose(bool(os.getenv("DEBUG", default=True)))


# Agent
def init_agent() -> AgentExecutor:
"""Load an agent executor with tools and LLM"""
print("Initializing agent..")
llm = VertexAI(max_output_tokens=512, verbose=DEBUG)
llm = VertexAI(max_output_tokens=512)
memory = ConversationBufferMemory(
memory_key="chat_history",
memory_key="chat_history", input_key="input", output_key="output"
)

agent = initialize_agent(
tools,
llm,
agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
verbose=DEBUG,
memory=memory,
handle_parsing_errors=True,
max_iterations=3,
early_stopping_method="generate",
return_intermediate_steps=True,
)
agent.agent.llm_chain.verbose = DEBUG # type: ignore

# Create new prompt template
tool_strings = "\n".join([f"> {tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = FORMAT_INSTRUCTIONS.format(
tool_names=tool_names,
)
date = convert_date("today")
today = f"Today is {date}."
template = "\n\n".join([PREFIX, tool_strings, format_instructions, SUFFIX, today])
human_message_template = "{input}\n\n{agent_scratchpad}"
prompt = ChatPromptTemplate.from_messages(
[("system", template), ("human", human_message_template)]
)
agent.agent.llm_chain.prompt = prompt # type: ignore
return agent


# Helper functions
def get_request(url: str, params: dict) -> requests.Response:
"""Helper method to make backend requests"""
if "http://" in url:
response = requests.get(
url,
params=params,
)
else:
response = requests.get(
url,
params=params,
headers={"Authorization": f"Bearer {get_id_token(url)}"},
)
return response


def get_id_token(url: str) -> str:
"""Helper method to generate ID tokens for authenticated requests"""
auth_req = google.auth.transport.requests.Request()
return google.oauth2.id_token.fetch_id_token(auth_req, url)

PREFIX = """SFO Airport Assistant helps travelers find their way at the airport.
def get_date():
from datetime import datetime
Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to
complex multi-query questions that require passing results from one query to another. As a language model, Assistant is
able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding
conversations and provide responses that are coherent and relevant to the topic at hand.
now = datetime.now()
return now.strftime("%Y-%m-%dT%H:%M:%S")
Overall, Assistant is a powerful tool that can help answer a wide range of questions pertaining to the San
Francisco Airport. SFO Airport Assistant is here to assist. It currently does not have access to user info.
TOOLS:
------
# Arg Schema for tools
class IdInput(BaseModel):
id: int = Field(description="Unique identifier")
Assistant has access to the following tools:"""

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name)
and an action_input key (tool input).
class QueryInput(BaseModel):
query: str = Field(description="Search query")
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
class ListFlights(BaseModel):
departure_airport: Optional[str] = Field(
description="Departure airport 3-letter code"
)
arrival_airport: Optional[str] = Field(description="Arrival airport 3-letter code")
date: str = Field(description="Date of flight departure", default=get_date())


# Tool Functions
@tool(
"Get Flight",
args_schema=IdInput,
)
def get_flight(id: int):
"""
Use this tool to get info for a specific flight.
Takes an id and returns info on the flight.
"""
response = get_request(
f"{BASE_URL}/flights",
{"flight_id": id},
)
if response.status_code != 200:
return f"Error trying to find flight: {response.text}"

return response.json()


@tool(
"List Flights",
args_schema=ListFlights,
)
def list_flights(departure_airport: str, arrival_airport: str, date: str):
"""Use this tool to list all flights matching search criteria."""
response = get_request(
f"{BASE_URL}/flights/search",
{
"departure_airport": departure_airport,
"arrival_airport": arrival_airport,
"date": date,
},
)
if response.status_code != 200:
return f"Error searching flights: {response.text}"

return response.json()


@tool("Get Amenity", args_schema=IdInput)
def get_amenity(id: int):
"""
Use this tool to get info for a specific airport amenity.
Takes an id and returns info on the amenity.
Always use the id from the search_amenities tool.
"""
response = get_request(
f"{BASE_URL}/amenities",
{"id": id},
)
if response.status_code != 200:
return f"Error trying to find amenity: {response.text}"
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
return response.json()


@tool("Search Amenities", args_schema=QueryInput)
def search_amenities(query: str):
"""Use this tool to recommended airport amenities at SFO.
Returns several amenities that are related to the query.
Only recommend amenities that are returned by this query.
"""
response = get_request(
f"{BASE_URL}/amenities/search", {"top_k": "5", "query": query}
)
if response.status_code != 200:
return f"Error searching amenities: {response.text}"

return response.json()


@tool(
"Get Airport",
args_schema=IdInput,
)
def get_airport(id: int):
"""
Use this tool to get info for a specific airport.
Takes an id and returns info on the airport.
Always use the id from the search_airports tool.
"""
response = get_request(
f"{BASE_URL}/airports",
{"id": id},
)
if response.status_code != 200:
return f"Error trying to find airport: {response.text}"
Follow this format:
return response.json()
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""

SUFFIX = """Begin! Use tools if necessary. Respond directly if appropriate.
If using a tool, reminder to ALWAYS respond with a valid json blob of a single action.
Format is Action:```$JSON_BLOB```then Observation:.
Thought:
# Tools for agent
tools = [get_flight, list_flights, get_amenity, search_amenities, get_airport]
Previous conversation history:
{chat_history}
"""
3 changes: 2 additions & 1 deletion langchain_tools_demo/requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ black==23.9.1
pytest==7.4.0
mypy==1.5.1
isort==5.12.0
types-requests==2.31.0.7
types-requests==2.31.0.7
types-python-dateutil==2.8.19.14
4 changes: 3 additions & 1 deletion langchain_tools_demo/static/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ body {

.chat-content {
overflow-y: scroll;
margin-bottom: 10px;
}

div.chat-content>span {
Expand Down Expand Up @@ -125,6 +126,7 @@ div.chat-wrapper div.chat-content span.user {
width: 100%;
display: none;
}

.mdl-progress.mdl-progress__indeterminate>.bar1 {
background-color: #394867
}
}
Loading

0 comments on commit 499f6a7

Please sign in to comment.