Skip to content

Commit

Permalink
feat: add cymbal air passenger policy (#265)
Browse files Browse the repository at this point in the history
Co-authored-by: Mahyar Ebadi <[email protected]>
  • Loading branch information
Yuan325 and mahyareb authored Mar 21, 2024
1 parent 5195d6c commit 199a67c
Show file tree
Hide file tree
Showing 23 changed files with 750 additions and 77 deletions.
25 changes: 25 additions & 0 deletions data/cymbalair_policy.csv

Large diffs are not rendered by default.

52 changes: 49 additions & 3 deletions llm_demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from contextlib import asynccontextmanager
from typing import Any, Optional
Expand Down Expand Up @@ -71,6 +72,11 @@ async def index(request: Request):
if "user_info" in request.session
else None
),
"user_name": (
request.session["user_info"]["name"]
if "user_info" in request.session
else None
),
},
)

Expand Down Expand Up @@ -144,10 +150,50 @@ async def chat_handler(request: Request, prompt: str = Body(embed=True)):
# Add user message to chat history
request.session["history"].append({"type": "human", "data": {"content": prompt}})
orchestrator = request.app.state.orchestrator
output = await orchestrator.user_session_invoke(request.session["uuid"], prompt)
response = await orchestrator.user_session_invoke(request.session["uuid"], prompt)
output = response.get("output")
confirmation = response.get("confirmation")
# Return assistant response
request.session["history"].append({"type": "ai", "data": {"content": output}})
return markdown(output)
if confirmation:
return json.dumps({"type": "confirmation", "content": confirmation})
else:
request.session["history"].append({"type": "ai", "data": {"content": output}})
return json.dumps({"type": "message", "content": markdown(output)})


@routes.post("/book/flight", response_class=PlainTextResponse)
async def book_flight(request: Request, params: str = Body(embed=True)):
"""Handler for LangChain chat requests"""
# Retrieve the params for the booking
if not params:
raise HTTPException(status_code=400, detail="Error: No booking params")
if "uuid" not in request.session:
raise HTTPException(
status_code=400, detail="Error: Invoke index handler before start chatting"
)
orchestrator = request.app.state.orchestrator
response = await orchestrator.user_session_insert_ticket(
request.session["uuid"], params
)
# Note in the history, that the ticket has been successfully booked
request.session["history"].append(
{"type": "ai", "data": {"content": "I have booked your ticket."}}
)
return response


@routes.post("/book/flight/decline", response_class=PlainTextResponse)
async def decline_flight(request: Request):
"""Handler for LangChain chat requests"""
# Note in the history, that the ticket was not booked
# This is helpful in case of reloads so there doesn't seem to be a break in communication.
request.session["history"].append(
{"type": "ai", "data": {"content": "Please confirm if you would like to book."}}
)
request.session["history"].append(
{"type": "human", "data": {"content": "I changed my mind."}}
)
return None


@routes.post("/reset")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pytz import timezone

from ..orchestrator import BaseOrchestrator, classproperty
from .tools import initialize_tools
from .tools import get_confirmation_needing_tools, initialize_tools, insert_ticket

set_verbose(bool(os.getenv("DEBUG", default=False)))
BASE_HISTORY = {
Expand Down Expand Up @@ -93,6 +93,9 @@ async def invoke(self, prompt: str) -> Dict[str, Any]:
raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")
return response

async def insert_ticket(self, params: str):
return await insert_ticket(self.client, params)

def reset_memory(self, base_message: List[BaseMessage]):
self.memory.clear()
self.memory.chat_memory = ChatMessageHistory(messages=base_message)
Expand All @@ -113,6 +116,22 @@ def kind(cls):
def user_session_exist(self, uuid: str) -> bool:
return uuid in self._user_sessions

async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
user_session = self.get_user_session(uuid)
response = await user_session.insert_ticket(params)
return response

def check_and_add_confirmations(cls, response: Dict[str, Any]):
for step in response.get("intermediate_steps") or []:
if len(step) > 0:
# Find the called tool in the step
called_tool = step[0]
# Check to see if the agent has made a decision to call Prepare Insert Ticket
# This tool is a no-op and requires user confirmation before continuing
if called_tool.tool in cls.confirmation_needing_tools:
return {"tool": called_tool.tool, "params": called_tool.tool_input}
return None

async def user_session_create(self, session: dict[str, Any]):
"""Create and load an agent executor with tools and LLM."""
print("Initializing agent..")
Expand All @@ -127,12 +146,21 @@ async def user_session_create(self, session: dict[str, Any]):
prompt = self.create_prompt_template(tools)
agent = UserAgent.initialize_agent(client, tools, history, prompt, self.MODEL)
self._user_sessions[id] = agent
self.confirmation_needing_tools = get_confirmation_needing_tools()
self.client = client

async def user_session_invoke(self, uuid: str, prompt: str) -> str:
async def user_session_invoke(self, uuid: str, prompt: str) -> dict[str, Any]:
user_session = self.get_user_session(uuid)
# Send prompt to LLM
response = await user_session.invoke(prompt)
return response["output"]
agent_response = await user_session.invoke(prompt)
# Check for calls that may require confirmation to proceed
confirmation = self.check_and_add_confirmations(agent_response)
# Build final response
response = {}
response["output"] = agent_response.get("output")
if confirmation:
response["confirmation"] = confirmation
return response

def user_session_reset(self, session: dict[str, Any], uuid: str):
user_session = self.get_user_session(uuid)
Expand Down Expand Up @@ -223,21 +251,22 @@ def close_clients(self):

PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs.
Cymbal Air (airline unique two letter identifier as CY) is a passenger airline offering convenient flights to many cities around the world from its
Cymbal Air (airline unique two letter identifier as CY) is a passenger airline offering convenient flights to many cities around the world from its
hub in San Francisco. Cymbal Air takes pride in using the latest technology to offer the best customer
service!
Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist
with a wide range of tasks, from answering simple questions to complex multi-query questions that
require passing results from one query to another. Using the latest AI models, Assistant is able to
Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist
with a wide range of tasks, from answering simple questions to complex multi-query questions that
require passing results from one query to another. Using the latest AI models, Assistant is able to
generate human-like text based on the input it receives, allowing it to engage in natural-sounding
conversations and provide responses that are coherent and relevant to the topic at hand.
Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air
as well as ammenities of San Francisco Airport.
"""
TOOLS_PREFIX = """TOOLS:
as well as ammenities of San Francisco Airport."""

TOOLS_PREFIX = """
TOOLS:
------
Assistant has access to the following tools:"""
Expand Down
41 changes: 25 additions & 16 deletions llm_demo/orchestrator/langchain_tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from datetime import datetime
from typing import Optional
from typing import Dict, Optional

import aiohttp
import google.oauth2.id_token # type: ignore
Expand Down Expand Up @@ -191,25 +192,29 @@ async def insert_ticket(
departure_time: datetime,
arrival_time: datetime,
):
response = await client.post(
url=f"{BASE_URL}/tickets/insert",
params={
"airline": airline,
"flight_number": flight_number,
"departure_airport": departure_airport,
"arrival_airport": arrival_airport,
"departure_time": departure_time.strftime("%Y-%m-%d %H:%M:%S"),
"arrival_time": arrival_time.strftime("%Y-%m-%d %H:%M:%S"),
},
headers=get_headers(client),
)

response = await response.json()
return response
return f"Booking ticket on {airline} {flight_number}"

return insert_ticket


async def insert_ticket(client: aiohttp.ClientSession, params: str):
ticket_info = json.loads(params)
response = await client.post(
url=f"{BASE_URL}/tickets/insert",
params={
"airline": ticket_info.get("airline"),
"flight_number": ticket_info.get("flight_number"),
"departure_airport": ticket_info.get("departure_airport"),
"arrival_airport": ticket_info.get("arrival_airport"),
"departure_time": ticket_info.get("departure_time").replace("T", " "),
"arrival_time": ticket_info.get("arrival_time").replace("T", " "),
},
headers=get_headers(client),
)
response = await response.json()
return response


def generate_list_tickets(client: aiohttp.ClientSession):
async def list_tickets():
response = await client.get(
Expand Down Expand Up @@ -370,3 +375,7 @@ async def initialize_tools(client: aiohttp.ClientSession):
""",
),
]


def get_confirmation_needing_tools():
return ["Insert Ticket"]
6 changes: 5 additions & 1 deletion llm_demo/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def user_session_create(self, session: dict[str, Any]):
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def user_session_invoke(self, uuid: str, prompt: str) -> str:
async def user_session_invoke(self, uuid: str, prompt: str) -> dict[str, Any]:
"""Invoke user session and return a response from llm orchestrator."""
raise NotImplementedError("Subclass should implement this!")

Expand All @@ -57,6 +57,10 @@ def user_session_reset(self, session: dict[str, Any], uuid: str):
def get_user_session(self, uuid: str) -> Any:
raise NotImplementedError("Subclass should implement this!")

@abstractmethod
async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
raise NotImplementedError("Subclass should implement this!")

def set_user_session_header(self, uuid: str, user_id_token: str):
user_session = self.get_user_session(uuid)
user_session.client.headers["User-Id-Token"] = f"Bearer {user_id_token}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@
)

from ..orchestrator import BaseOrchestrator, classproperty
from .functions import BASE_URL, assistant_tool, function_request, get_headers
from .functions import (
BASE_URL,
assistant_tool,
function_request,
get_confirmation_needing_tools,
get_headers,
insert_ticket,
)

DEBUG = os.getenv("DEBUG", default=False)
BASE_HISTORY = {
Expand Down Expand Up @@ -66,11 +73,22 @@ async def invoke(self, input_prompt: str) -> Dict[str, Any]:
self.debug_log(f"Prompt:\n{prompt}\n\nQuestion: {input_prompt}.")
self.debug_log(f"\nFunction call response:\n{model_response}")
part_response = model_response.candidates[0].content.parts[0]
confirmation = None

# implement multi turn chat with while loop
while "function_call" in part_response._raw_part:
function_call = MessageToDict(part_response.function_call._pb)
function_response = await self.request_function(function_call)
function_name = function_call.get("name")
if function_name in get_confirmation_needing_tools():
function_response = self.confirmation_response(
function_name, function_call.get("args")
)
confirmation = {
"tool": function_name,
"params": function_call.get("args"),
}
else:
function_response = await self.request_function(function_call)
self.debug_log(f"Function response:\n{function_response}")
part = Part.from_function_response(
name=function_call["name"],
Expand All @@ -84,7 +102,7 @@ async def invoke(self, input_prompt: str) -> Dict[str, Any]:
if "text" in part_response._raw_part:
content = part_response.text
self.debug_log(f"Output content: {content}")
return {"output": content}
return {"output": content, "confirmation": confirmation}
else:
raise HTTPException(
status_code=500, detail="Error: Chat model response unknown"
Expand All @@ -93,7 +111,7 @@ async def invoke(self, input_prompt: str) -> Dict[str, Any]:
def get_prompt(self) -> str:
formatter = "%A, %m/%d/%Y, %H:%M:%S"
now = datetime.now(timezone("US/Pacific")).strftime("%A, %m/%d/%Y, %H:%M:%S")
prompt = f"{PREFIX}. Today's date and current time is {now}."
prompt = f"{PREFIX}\nToday's date and current time is {now}."
return prompt

def debug_log(self, output: str) -> None:
Expand All @@ -110,6 +128,11 @@ def request_chat_model(self, prompt: str):
raise HTTPException(status_code=500, detail=f"Error invoking agent: {err}")
return model_response

def confirmation_response(self, function_name, function_params):
if function_name == "insert_ticket":
return f"Booking ticket on {function_params.get('airline')} {function_params.get('flight_number')}"
return ""

async def request_function(self, function_call):
url = function_request(function_call["name"])
params = function_call["args"]
Expand All @@ -122,6 +145,9 @@ async def request_function(self, function_call):
response = await response.json()
return response

async def insert_ticket(self, params: str):
return await insert_ticket(self.client, params)

def reset_memory(self, model: str):
"""reinitiate chat model to reset memory."""
del self.chat
Expand All @@ -144,6 +170,11 @@ def kind(cls):
def user_session_exist(self, uuid: str) -> bool:
return uuid in self._user_sessions

async def user_session_insert_ticket(self, uuid: str, params: str) -> Any:
user_session = self.get_user_session(uuid)
response = await user_session.insert_ticket(params)
return response

async def user_session_create(self, session: dict[str, Any]):
"""Create and load an agent executor with tools and LLM."""
print("Initializing agent..")
Expand All @@ -155,12 +186,13 @@ async def user_session_create(self, session: dict[str, Any]):
client = await self.create_client_session()
chat = UserChatModel.initialize_chat_model(client, self.MODEL)
self._user_sessions[id] = chat
self.client = client

async def user_session_invoke(self, uuid: str, prompt: str) -> str:
async def user_session_invoke(self, uuid: str, prompt: str) -> dict[str, Any]:
user_session = self.get_user_session(uuid)
# Send prompt to LLM
response = await user_session.invoke(prompt)
return response["output"]
return response

def user_session_reset(self, session: dict[str, Any], uuid: str):
user_session = self.get_user_session(uuid)
Expand Down Expand Up @@ -205,16 +237,16 @@ def close_clients(self):

PREFIX = """The Cymbal Air Customer Service Assistant helps customers of Cymbal Air with their travel needs.
Cymbal Air (airline unique two letter identifier as CY) is a passenger airline offering convenient flights to many cities around the world from its
Cymbal Air (airline unique two letter identifier as CY) is a passenger airline offering convenient flights to many cities around the world from its
hub in San Francisco. Cymbal Air takes pride in using the latest technology to offer the best customer
service!
Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist
with a wide range of tasks, from answering simple questions to complex multi-query questions that
require passing results from one query to another. Using the latest AI models, Assistant is able to
Cymbal Air Customer Service Assistant (or just "Assistant" for short) is designed to assist
with a wide range of tasks, from answering simple questions to complex multi-query questions that
require passing results from one query to another. Using the latest AI models, Assistant is able to
generate human-like text based on the input it receives, allowing it to engage in natural-sounding
conversations and provide responses that are coherent and relevant to the topic at hand.
Assistant is a powerful tool that can help answer a wide range of questions pertaining to travel on Cymbal Air
as well as ammenities of San Francisco Airport.
as well as ammenities of San Francisco Airport.
"""
Loading

0 comments on commit 199a67c

Please sign in to comment.