Skip to content

Commit

Permalink
fix: tests, examples
Browse files Browse the repository at this point in the history
  • Loading branch information
amirai21 committed Aug 21, 2024
1 parent a66d4e1 commit 59cc5f8
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 83 deletions.
2 changes: 1 addition & 1 deletion examples/studio/chat/chat_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
shnokel = DocumentSchema(
id=str(uuid.uuid4()),
content="Shnokel Corp. TL;DR Annual Report - 2024. Shnokel Corp., a pioneer in renewable energy solutions, "
content="Shnokel Corp. Annual Report - 2024. Shnokel Corp., a pioneer in renewable energy solutions, "
"reported a 20% increase in revenue this year, reaching $200 million. The successful deployment of "
"our advanced solar panels, SolarFlex, accounted for 40% of our sales. We entered new markets in Europe "
"and have plans to develop wind energy projects next year. Our commitment to reducing environmental "
Expand Down
38 changes: 23 additions & 15 deletions examples/studio/chat/chat_function_calling.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json

from ai21 import AI21Client
from ai21.logger import set_verbose
from ai21.models.chat import ChatMessage, ToolMessage
from ai21.models.chat.function_tool_definition import FunctionToolDefinition
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.models.chat.tool_parameters import ToolParameters
from ai21.models.chat import FunctionToolDefinition
from ai21.models.chat import ToolDefinition
from ai21.models.chat import ToolParameters

set_verbose(True)


def get_order_delivery_date(order_id: str) -> str:
print(f"Getting delivery date from database for order ID: {order_id}...")
print(f"Retrieving the delivery date for order ID: {order_id} from the database...")
return "2025-05-04"


Expand All @@ -27,7 +29,7 @@ def get_order_delivery_date(order_id: str) -> str:
type="function",
function=FunctionToolDefinition(
name="get_order_delivery_date",
description="Get the delivery date for a given order ID",
description="Retrieve the delivery date associated with the specified order ID",
parameters=ToolParameters(
type="object",
properties={"order_id": {"type": "string", "description": "The customer's order ID."}},
Expand All @@ -42,31 +44,37 @@ def get_order_delivery_date(order_id: str) -> str:

response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools)

print(response)
""" AI models can be error-prone, it's crucial to ensure that the tool calls align with the expectations.
The below code snippet demonstrates how to handle tool calls in the response and invoke the tool function
to get the delivery date for the user's order. After retrieving the delivery date, we pass the response back
to the AI model to continue the conversation, using the ToolMessage object. """

assistant_message = response.choices[0].message
tool_calls = assistant_message.tool_calls
messages.append(assistant_message) # Adding the assistant message to the chat history

delivery_date = None
tool_calls = assistant_message.tool_calls
if tool_calls:
tool_call = tool_calls[0]
if tool_call.function.name == "get_order_delivery_date":
func_arguments = tool_call.function.arguments
if "order_id" in func_arguments:
# extract the order ID from the function arguments logic... (in this case it's just 1 argument)
order_id = func_arguments
delivery_date = get_order_delivery_date(order_id)
print(f"Delivery date for order ID {order_id}: {delivery_date}")
func_args_dict = json.loads(func_arguments)

if "order_id" in func_args_dict:
delivery_date = get_order_delivery_date(func_args_dict["order_id"])
else:
print("order_id not found in function arguments")
else:
print(f"Unexpected tool call found - {tool_call.function.name}")
else:
print("No tool calls found.")
print("No tool calls found")


if delivery_date is not None:
"""Continue the conversation by passing the delivery date back to the AI model:"""

tool_message = ToolMessage(role="tool", tool_call_id=tool_calls[0].id, content=delivery_date)
messages.append(assistant_message)
messages.append(tool_message)

response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools)
print(response)
print(response.choices[0].message.content)
55 changes: 34 additions & 21 deletions examples/studio/chat/chat_function_calling_multiple_tools.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
import json

from ai21 import AI21Client
from ai21.logger import set_verbose
from ai21.models.chat import ChatMessage, ToolMessage
from ai21.models.chat.function_tool_definition import FunctionToolDefinition
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.models.chat.tool_parameters import ToolParameters

set_verbose(True)
# set_verbose(True)


def get_weather(place: str, date: str) -> str:
print(f"Getting the expected weather at {place} during {date} from the internet...")
"""
Retrieve the expected weather for a specified location and date.
"""
print(f"Fetching expected weather for {place} on {date}...")
return "32 celsius"


def get_sunset_hour(place: str, date: str):
print(f"Getting the expected sunset hour at {place} during {date} from the internet...")
def get_sunset_hour(place: str, date: str) -> str:
"""
Fetch the expected sunset time for a given location and date.
"""
print(f"Fetching expected sunset time for {place} on {date}...")
return "7 pm"


Expand All @@ -26,7 +31,7 @@ def get_sunset_hour(place: str, date: str):
content="You are a helpful assistant. Use the supplied tools to assist the user.",
),
ChatMessage(
role="user", content="Hi, can you assist me to get info about the weather and expected sunset in Tel Aviv?"
role="user", content="Hello, could you help me find out the weather forecast and sunset time for London?"
),
ChatMessage(role="assistant", content="Hi there! I can help with that. On which date?"),
ChatMessage(role="user", content="At 2024-08-27"),
Expand All @@ -36,12 +41,12 @@ def get_sunset_hour(place: str, date: str):
type="function",
function=FunctionToolDefinition(
name="get_sunset_hour",
description="Search the internet for the sunset hour at a given place on a given date",
description="Fetch the expected sunset time for a given location and date.",
parameters=ToolParameters(
type="object",
properties={
"place": {"type": "string", "description": "The place to look for the weather at"},
"date": {"type": "string", "description": "The date to look for the weather at"},
"place": {"type": "string", "description": "The location for which the weather is being queried."},
"date": {"type": "string", "description": "The date for which the weather is being queried."},
},
required=["place", "date"],
),
Expand All @@ -52,12 +57,12 @@ def get_sunset_hour(place: str, date: str):
type="function",
function=FunctionToolDefinition(
name="get_weather",
description="Search the internet for the weather at a given place on a given date",
description="Retrieve the expected weather for a specified location and date.",
parameters=ToolParameters(
type="object",
properties={
"place": {"type": "string", "description": "The place to look for the weather at"},
"date": {"type": "string", "description": "The date to look for the weather at"},
"place": {"type": "string", "description": "The location for which the weather is being queried."},
"date": {"type": "string", "description": "The date for which the weather is being queried."},
},
required=["place", "date"],
),
Expand All @@ -66,49 +71,57 @@ def get_sunset_hour(place: str, date: str):

tools = [get_sunset_tool, get_weather_tool]

client = AI21Client()
client = AI21Client(api_host="https://api-stage.ai21.com", api_key="F6iFeKlMsisusyhtoy1ZUj4bRPhEd6sf")

response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools)

print(response.choices[0].message)
""" AI models can be error-prone, it's crucial to ensure that the tool calls align with the expectations.
The below code snippet demonstrates how to handle tool calls in the response and invoke the tool function
to get the delivery date for the user's order. After retrieving the delivery date, we pass the response back
to the AI model to continue the conversation, using the ToolMessage object. """

assistant_message = response.choices[0].message
messages.append(assistant_message)
messages.append(assistant_message) # Adding the assistant message to the chat history

tool_calls = assistant_message.tool_calls

too_call_id_to_result = {}
tool_calls = assistant_message.tool_calls
if tool_calls:
for tool_call in tool_calls:
if tool_call.function.name == "get_weather":
"""Verify get_weather tool call arguments and invoke the function to get the weather forecast:"""
func_arguments = tool_call.function.arguments
args = json.loads(func_arguments)

if "place" in args and "date" in args:
result = get_weather(args["place"], args["date"])
too_call_id_to_result[tool_call.id] = result
else:
print(f"Got unexpected arguments in function call - {args}")

elif tool_call.function.name == "get_sunset_hour":
"""Verify get_sunset_hour tool call arguments and invoke the function to get the weather forecast:"""
func_arguments = tool_call.function.arguments
args = json.loads(func_arguments)

if "place" in args and "date" in args:
result = get_sunset_hour(args["place"], args["date"])
too_call_id_to_result[tool_call.id] = result
else:
print(f"Got unexpected arguments in function call - {args}")

else:
print(f"Unexpected tool call found - {tool_call.function.name}")
else:
print("No tool calls found.")
print("No tool calls found")


if too_call_id_to_result:
"""Continue the conversation by passing the sunset and weather back to the AI model:"""

for tool_id_called, result in too_call_id_to_result.items():
tool_message = ToolMessage(role="tool", tool_call_id=tool_id_called, content=str(result))
messages.append(tool_message)

for message in messages:
print(message)

response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools)
print(response.choices[0].message)
print(response.choices[0].message.content)
5 changes: 4 additions & 1 deletion examples/studio/chat/chat_response_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ZooTicketsOrder(BaseModel):
)
]


client = AI21Client(api_host="https://api-stage.ai21.com", api_key="F6iFeKlMsisusyhtoy1ZUj4bRPhEd6sf")

response = client.chat.completions.create(
Expand All @@ -42,9 +43,11 @@ class ZooTicketsOrder(BaseModel):
response_format=ResponseFormat(type="json_object"),
)

print(response)

try:
order = ZooTicketsOrder.model_validate_json(response.choices[0].message.content)
print("Here is the order:")
print("Zoo tickets order details JSON:")
print(order.model_dump_json(indent=4))
except ValidationError as exc:
print(exc)
4 changes: 4 additions & 0 deletions tests/integration_tests/clients/test_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
("chat/chat_completions.py",),
("chat/chat_completions_jamba_instruct.py",),
("chat/stream_chat_completions.py",),
("chat/chat_documents.py",),
("chat/chat_function_calling.py",),
("chat/chat_function_calling_multiple_tools.py",),
("chat/chat_response_format.py",),
# ("custom_model.py", ),
# ('custom_model_completion.py', ),
# ("dataset.py", ),
Expand Down
Loading

0 comments on commit 59cc5f8

Please sign in to comment.