Skip to content

Commit

Permalink
Fix for inconsistent messaging in tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
SailorJoe6 committed Sep 12, 2024
1 parent e22b27a commit a350670
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions samples/apps/autogen-studio/autogenstudio/workflowmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,12 @@ def __init__(
self.a_human_input_timeout = a_human_input_timeout
self.connection_id = connection_id

def shorten_text(self, text: str, n: int) -> str:
# If the text is longer than n, shorten it and add an ellipsis
if len(text) > n:
return text[:n] + "..."
return text

def receive(
self,
message: Union[Dict, str],
Expand All @@ -1040,7 +1046,7 @@ def receive(
tool_calls = [
{
"name": func["name"],
"arguments": ', '.join(f'{key}: {value}' for key, value in json.loads(func["arguments"]).items())
"arguments": ', '.join(f'{key}: {json.dumps(self.shorten_text(value,20))}' for key, value in json.loads(func["arguments"]).items())
}
for func in [tc.get("function") for tc in message["tool_calls"]]
]
Expand All @@ -1067,12 +1073,11 @@ async def a_receive(
tool_calls = [
{
"name": func["name"],
#"arguments": ', '.join(f'{key}: {value}' for key, value in json.loads(func["arguments"]).items())
"arguments": ', '.join(f'{key}: {json.dumps(self.shorten_text(value,20))}' for key, value in json.loads(func["arguments"]).items())
}
for func in [tc.get("function") for tc in message["tool_calls"]]
]
#tool_call_msgs = [f"requested tool call: {func.get("name")}({func.get("arguments")})" for func in tool_calls]
tool_call_msgs = [f"requested tool call: {func.get("name")}" for func in tool_calls]
tool_call_msgs = [f"requested tool call: {func.get("name")}({func.get("arguments")})" for func in tool_calls]
if not message.get("content") == None:
tool_call_msgs.insert(0, message.get("content"))
new_message = copy.deepcopy(message)
Expand Down Expand Up @@ -1160,7 +1165,7 @@ def receive(
self.message_processor(sender, self, message, request_reply, silent, sender_type="groupchat")
super().receive(message, sender, request_reply, silent)

def shorten_text(text: str, n: int) -> str:
def shorten_text(self, text: str, n: int) -> str:
# If the text is longer than n, shorten it and add an ellipsis
if len(text) > n:
return text[:n] + "..."
Expand Down Expand Up @@ -1188,7 +1193,7 @@ async def a_receive(
tool_call_msgs.insert(0, message.get("content"))
new_message = copy.deepcopy(message)
new_message["content"] = "\n".join(tool_call_msgs)
await self.a_message_processor(sender, self, message, request_reply, silent, sender_type="groupchat")
await self.a_message_processor(sender, self, new_message, request_reply, silent, sender_type="groupchat")
else:
await self.a_message_processor(sender, self, message, request_reply, silent, sender_type="groupchat")
elif self.message_processor:
Expand Down

0 comments on commit a350670

Please sign in to comment.