Skip to content

Commit

Permalink
refactor: add react agent as abstraction of existing example agents (#84
Browse files Browse the repository at this point in the history
)

* refactor agent

* fix: invalid names of sdxl_turbo
  • Loading branch information
dongyuanjushi authored Jul 12, 2024
1 parent 2f52554 commit d7f0e6d
Show file tree
Hide file tree
Showing 41 changed files with 1,096 additions and 1,633 deletions.
2 changes: 2 additions & 0 deletions pyopenagi/agents/agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def activate_agent(self, agent_name, task_input):

agent_class = self.load_agent_instance(agent_name)

print(type(agent_class))

agent = agent_class(
agent_name = agent_name,
task_input = task_input,
Expand Down
91 changes: 52 additions & 39 deletions pyopenagi/agents/base.py → pyopenagi/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

from ..utils.chat_template import Query

from abc import ABC, abstractmethod
import importlib

class CustomizedThread(Thread):
def __init__(self, target, args=()):
super().__init__()
Expand All @@ -36,7 +37,7 @@ def join(self):
super().join()
return self.result

class BaseAgent(ABC):
class BaseAgent:
def __init__(self,
agent_name,
task_input,
Expand All @@ -47,11 +48,15 @@ def __init__(self,
):
self.agent_name = agent_name
self.config = self.load_config()
self.tools = self.config["tools"]
self.tool_names = self.config["tools"]


self.llm = llm
self.agent_process_queue = agent_process_queue
self.agent_process_factory = agent_process_factory
self.tool_list: dict = None
self.tool_list = dict()
self.tools = []
self.load_tools(self.tool_names)

self.start_time = None
self.end_time = None
Expand All @@ -69,7 +74,6 @@ def __init__(self,
self.set_status("active")
self.set_created_time(time.time())

self.build_system_instruction()

def run(self):
'''Execute each step to finish the task.'''
Expand All @@ -79,65 +83,74 @@ def run(self):
def build_system_instruction(self):
pass

def check_workflow(self, message):
try:
workflow = json.loads(message)
if not isinstance(workflow, list):
return None

for step in workflow:
if "message" not in step or "tool_use" not in step:
return None

return workflow

except json.JSONDecodeError:
return None

def automatic_workflow(self):
for i in range(self.plan_max_fail_times):
try:
response, start_times, end_times, waiting_times, turnaround_times = self.get_response(
query = Query(
messages = self.messages,
tools = None
)
response, start_times, end_times, waiting_times, turnaround_times = self.get_response(
query = Query(
messages = self.messages,
tools = None
)
)

if self.rounds == 0:
self.set_start_time(start_times[0])

if self.rounds == 0:
self.set_start_time(start_times[0])
self.request_waiting_times.extend(waiting_times)
self.request_turnaround_times.extend(turnaround_times)

self.request_waiting_times.extend(waiting_times)
self.request_turnaround_times.extend(turnaround_times)
workflow = self.check_workflow(response.response_message)

workflow = json.loads(response.response_message)
self.rounds += 1
self.rounds += 1

if workflow:
return workflow

except Exception:
else:
self.messages.append(
{
"role": "assistant",
"content": f"Fail {i+1} times to generate a valid plan. I need to regenerate a plan"
}
)
continue
return None

def manual_workflow(self):
pass

def call_tools(self, tool_calls):
self.logger.log(f"***** It starts to call external tools *****\n", level="info")
tool_call_responses = None
for tool_call in tool_calls:
function_name = tool_call["name"]
function_to_call = self.tool_list[function_name]
function_params = tool_call["parameters"]
def snake_to_camel(self, snake_str):
components = snake_str.split('_')
return ''.join(x.title() for x in components)

def load_tools(self, tool_names):
for tool_name in tool_names:
org, name = tool_name.split("/")

module_name = ".".join(["pyopenagi", "tools", org, name])

try:
function_response = function_to_call.run(function_params)
if tool_call_responses is None:
tool_call_responses = f"I will call the {function_name} with the params as {function_params} to solve this. The tool response is {function_response}\n"
else:
tool_call_responses += f"I will call the {function_name} with the params as {function_params} to solve this. The tool response is {function_response}\n"
class_name = self.snake_to_camel(name)

except Exception:
continue
tool_module = importlib.import_module(module_name)

if tool_call_responses:
self.logger.log(f"At current step, {tool_call_responses}", level="info")
tool_class = getattr(tool_module, class_name)

else:
self.logger.log("At current step, I fail to call any tools.")
self.tool_list[name] = tool_class()

return tool_call_responses
self.tools.append(tool_class().get_tool_call_format())

def setup_logger(self):
logger = AgentLogger(self.agent_name, self.log_mode)
Expand Down
180 changes: 6 additions & 174 deletions pyopenagi/agents/example/academic_agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,6 @@
from ...react_agent import ReactAgent

from ...base import BaseAgent

import time

from ...agent_process import (
AgentProcess
)

import numpy as np

import argparse

from concurrent.futures import as_completed

from ....utils.chat_template import Query

from ....tools.online.arxiv import Arxiv

from ....tools.online.wikipedia import Wikipedia

import json

class AcademicAgent(BaseAgent):
class AcademicAgent(ReactAgent):
def __init__(self,
agent_name,
task_input,
Expand All @@ -30,158 +9,11 @@ def __init__(self,
agent_process_factory,
log_mode: str
):
BaseAgent.__init__(self, agent_name, task_input, llm, agent_process_queue, agent_process_factory, log_mode)

self.tool_list = {
"arxiv": Arxiv()
}

self.plan_max_fail_times = 3
self.tool_call_max_fail_times = 3

def build_system_instruction(self):
prefix = "".join(
[
"".join(self.config["description"]),
f'You are given the available tools from the tool list: {json.dumps(self.tools)} to help you solve problems.'
]
)
plan_instruction = "".join(
[
'Generate a plan of steps you need to take.',
'The plan must follow the json format as: '
'[{"message1": "message_value1","tool_use": [tool_name1, tool_name2,...]}'
'{"message2": "message_value2", "tool_use": [tool_name1, tool_name2,...]}'
'...]',
'In each step of the planned workflow, you must select the most related tool to use'
'An plan example can be:'
'[{"message": "Gather information from arxiv", "tool_use": ["arxiv"]}]'
'{"message", "Based on the gathered information, write a summarization", "tool_use": None'
]
)
exection_instruction = "".join(
[
'To execute each step, you need to output as the following json format:',
'{"observation": "What you have observed from the environment",',
'"thinking": "Your thought of the current situation",',
'"action": "Your action to take in current step",',
'"result": "The result of what you have done."}'
]
)
if self.workflow_mode == "manual":
self.messages.append(
{"role": "system", "content": prefix + exection_instruction}
)
else:
assert self.workflow_mode == "automatic"
self.messages.append(
{"role": "system", "content": prefix + plan_instruction + exection_instruction}
)

def automatic_workflow(self):
return super().automatic_workflow()
ReactAgent.__init__(self, agent_name, task_input, llm, agent_process_queue, agent_process_factory, log_mode)
self.workflow_mode = "automatic"

def manual_workflow(self):
workflow = [
{
"message": "use the arxiv tool to gather information",
"tool_use": ["arxiv"]
},
{
"message": "postprocess gathered information to fulfill the user's requrements",
"tool_use": None
}
]
return workflow
pass

def run(self):
task_input = self.task_input

self.messages.append(
{"role": "user", "content": task_input}
)
self.logger.log(f"{task_input}\n", level="info")

workflow = None

# generate plan
# workflow = self.automatic_workflow() # generate workflow by llm
workflow = self.manual_workflow() # define workflow manually

self.messages.append(
{"role": "assistant", "content": f"The workflow is {json.dumps(workflow)}"}
)

self.logger.log(f"Workflow: {workflow}\n", level="info")

final_result = ""

for i, step in enumerate(workflow):
message = step["message"]
tool_use = step["tool_use"]

prompt = f"\nAt step {self.rounds + 1}, you need to {message}. Focus on current step and do not be verbose!"
self.messages.append({
"role": "user",
"content": prompt
})

used_tools = self.tools if tool_use else None

response, start_times, end_times, waiting_times, turnaround_times = self.get_response(
query = Query(
messages = self.messages,
tools = used_tools
)
)
if self.rounds == 0:
self.set_start_time(start_times[0])

# execute action
response_message = response.response_message

tool_calls = response.tool_calls

self.request_waiting_times.extend(waiting_times)
self.request_turnaround_times.extend(turnaround_times)

if tool_calls:
for i in range(self.plan_max_fail_times):
tool_call_responses = self.call_tools(tool_calls=tool_calls)
if tool_call_responses:
if response_message is None:
response_message = tool_call_responses
break
else:
self.messages.append(
{
"role": "assistant",
"content": "Fail to call tools correctly. I need to redo the selection of tool and parameters"
}
)
continue

self.messages.append({
"role": "assistant",
"content": response_message
})

if i == len(workflow) - 1:
final_result = response_message

self.logger.log(f"At step {self.rounds + 1}, {response_message}\n", level="info")

self.rounds += 1

self.set_status("done")
self.set_end_time(time=time.time())

return {
"agent_name": self.agent_name,
"result": final_result,
"rounds": self.rounds,
"agent_waiting_time": self.start_time - self.created_time,
"agent_turnaround_time": self.end_time - self.created_time,
"request_waiting_times": self.request_waiting_times,
"request_turnaround_times": self.request_turnaround_times,
}
return super().run()
20 changes: 1 addition & 19 deletions pyopenagi/agents/example/academic_agent/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,7 @@
"You are an expert who is good at looking up and summarizing academic articles. "
],
"tools": [
{
"type": "function",
"function": {
"name": "arxiv",
"description": "Query articles or topics in arxiv",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Input query that describes what to search in arxiv"
}
},
"required": [
"query"
]
}
}
}
"arxiv/arxiv"
],
"meta": {
"author": "example",
Expand Down
Loading

0 comments on commit d7f0e6d

Please sign in to comment.