diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dea4142eb..5a6442e64 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -39,6 +39,7 @@ repos:
                 --disable-error-code=import-untyped,
                 --disable-error-code=truthy-function,
                 --follow-imports=skip,
+                --disable-error-code=override,
           ]
 #  - repo: https://github.com/numpy/numpydoc
 #    rev: v1.6.0
diff --git a/src/agentscope/agents/learnable_agent.py b/src/agentscope/agents/learnable_agent.py
new file mode 100644
index 000000000..40c3cffb4
--- /dev/null
+++ b/src/agentscope/agents/learnable_agent.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+"""LearnableAgent: an agent that archives valuable chat messages."""
+from abc import ABC
+from typing import Optional, Union, Any, Callable, Type
+from loguru import logger
+
+from agentscope.message import Msg
+from agentscope.memory import MemoryBase, TemporaryMemory
+from agentscope.agents.agent import AgentBase
+from agentscope.service.retrieval.similarity import cos_sim
+
+
+# Prompt asking the model for a 'yes'/'no' verdict on whether a record is
+# worth storing; "{record}" is replaced with the message content.
+VALUE_ASSESSMENT_PROMPT = (
+    "Please carefully consider the following record and assess whether it "
+    "contains information of sufficient value to be suitable for storage in "
+    "a knowledge base. "
+    "\nExample:\n"
+    "'The dragon is the only creature in the Chinese Zodiac that is "
+    "considered a divine animal.' → Answer 'yes' (because this is basic "
+    "knowledge about Chinese culture with widespread reference value for "
+    "understanding related topics)\n"
+    "Following these guidelines, please respond with 'yes' or 'no' to the "
+    "following record:\n\n"
+    "{record}"
+)
+
+# Prompt asking the model to condense a record into a short summary;
+# "{record}" is replaced with the message content.
+EXTRACTION_SUMMARY_PROMPT = (
+    "Please read the following record, extract key knowledge points or "
+    "question-answer pairs, and provide a concise and clear summary. "
+    "\nExample:\n"
+    "Record: 'Due to the rotation of the Earth, we experience the "
+    "alternation of day and night. "
+    "The Earth completes one rotation every 24 hours.'\n"
+    "Summary: 'The Earth rotates once every 24 hours, which leads to the "
+    "phenomenon of day and night alternation.'\n\n"
+    "{record}"
+)
+
+
+class LearnableAgent(AgentBase, ABC):
+    """Base class for agents that learn from their conversations.
+
+    Messages in the agent's memory are screened by the model; those judged
+    valuable are summarized and added to a separate store (`self.vdb`),
+    which `close` exports.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        vdb_path: str,
+        vdb_cls: Type[MemoryBase] = TemporaryMemory,
+        config: Optional[dict] = None,
+        sys_prompt: Optional[str] = None,
+        model: Optional[Union[Callable[..., Any], str]] = None,
+        embedding_model: Optional[Union[str, Callable]] = None,
+        metric: Callable = cos_sim,
+        assess_prompt: str = VALUE_ASSESSMENT_PROMPT,
+        extract_prompt: str = EXTRACTION_SUMMARY_PROMPT,
+    ) -> None:
+        """Initialize the agent and build its knowledge store.
+
+        Args:
+            name (str): Forwarded to `AgentBase`.
+            vdb_path (str): Path handed to the knowledge store as its
+                `mem_path`.
+            vdb_cls (Type[MemoryBase]): Class of the knowledge store. It is
+                called with `config` plus the `embedding_model` and
+                `mem_path` keywords, so it must accept those.
+            config (Optional[dict]): Forwarded to both `AgentBase` and the
+                knowledge store.
+            sys_prompt (Optional[str]): Forwarded to `AgentBase`.
+            model (Optional[Union[Callable[..., Any], str]]): Forwarded to
+                `AgentBase`.
+            embedding_model (Optional[Union[str, Callable]]): Forwarded to
+                the knowledge store.
+            metric (Callable): Similarity function; `self.metric` returns
+                the `content` field of its result.
+            assess_prompt (str): Template with a `{record}` placeholder,
+                used to ask whether a message is worth storing.
+            extract_prompt (str): Template with a `{record}` placeholder,
+                used to summarize a message before it is stored.
+        """
+        super().__init__(name, config, sys_prompt, model)
+        # Notice: [Memory] is for short-term, current conversation, and will
+        # not persist after the agent is closed.
+        # [Vector database] is considered long-term, will be reloaded whenever
+        # agent is invoked
+        # Build vector database for saving knowledge. The path goes in as
+        # `mem_path`: that is the keyword `TemporaryMemory.__init__` takes,
+        # and it rejects unknown keywords such as `vdb_path`.
+        self.vdb = vdb_cls(
+            config,
+            embedding_model=embedding_model,
+            mem_path=vdb_path,
+        )
+        # Expose the metric's `content` field directly.
+        self.metric = lambda x, y: metric(x, y).content
+        self.assess_prompt = assess_prompt
+        self.extract_prompt = extract_prompt
+
+    def reply(self, x: Optional[dict] = None) -> dict:
+        """Reply to `x`; concrete agents must override this method.
+
+        Raises:
+            NotImplementedError: Always, in this base class.
+        """
+        # defer the forward function implementation to example agents
+        raise NotImplementedError
+
+    def learn_from_chat(self) -> None:
+        """
+        Iterates through the messages in the learner's memory and processes
+        each message to potentially learn from it. Messages originating
+        from the learner itself are ignored. The memory is reset after
+        processing.
+
+        This function calls the `archive_valuable_msg` method on each message
+        to decide whether to store the message information into the
+        knowledge base.
+        """
+        if self.memory.size() > 0:
+            for msg in self.memory:
+                # Ignore messages from the agent itself to avoid duplication
+                if msg.get("name") != self.name:
+                    self.archive_valuable_msg(msg)
+        # NOTE(review): the MemoryBase interface touched by this patch
+        # declares `clear()`; confirm the memory class provides `reset()`.
+        self.memory.reset()
+
+    def archive_valuable_msg(self, msg: dict) -> None:
+        """
+        Evaluates a single message to determine whether it should be stored
+        in the knowledge base. The method generates prompts to assess the
+        value of the message and to extract a summary if the message is
+        deemed valuable.
+
+        Args:
+            msg (dict): A dictionary representing the message to be
+                considered for storage. Its `content` attribute is the
+                text inserted into the prompts.
+        """
+        # Both prompt templates are filled from the same mapping
+        record = {"record": msg.content}
+
+        # Consider whether to deposit message into the knowledge base
+        prompt = self.assess_prompt.format_map(record)
+        # NOTE(review): `res` is treated as a `str` below (`res.lower()`);
+        # confirm that this is what the model wrapper returns.
+        res = self.model([Msg(self.name, prompt)])
+
+        logger.info(
+            f"{self.name}:\n {msg.content} \n " f"accessing results: {res}.",
+        )
+
+        if "yes" in res.lower():
+            prompt = self.extract_prompt.format_map(record)
+            res = self.model([Msg(self.name, prompt)])
+            # NOTE(review): `_openai_embedding` is not defined in this class;
+            # confirm that `AgentBase` or the concrete subclass provides it.
+            emb = self._openai_embedding(res)
+            # The embedding is supplied explicitly, hence `embed=False`.
+            self.vdb.add(Msg(self.name, res, embedding=emb), embed=False)
+            logger.info(f"Saving {res} in {self.name}'s vdb.")
+
+    def close(self) -> None:
+        """
+        Saves the current state of the vector database (vdb) to a memory file.
+        This method should be called before the termination of the program
+        to ensure that learned information is not lost.
+        """
+        self.vdb.export()
diff --git a/src/agentscope/memory/memory.py b/src/agentscope/memory/memory.py
index a536f511b..ae31b7c2d 100644
--- a/src/agentscope/memory/memory.py
+++ b/src/agentscope/memory/memory.py
@@ -7,10 +7,7 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import Iterable
-from typing import Optional
-from typing import Union
-from typing import Callable
+from typing import Iterable, Optional, Union, Callable, Any
 
 
 class MemoryBase(ABC):
@@ -21,6 +18,7 @@ class MemoryBase(ABC):
     def __init__(
         self,
         config: Optional[dict] = None,
+        **kwargs: Any,
     ) -> None:
         """MemoryBase is a base class for memory of agents.
 
@@ -29,6 +27,7 @@ def __init__(
                 Configuration of this memory.
         """
         self.config = {} if config is None else config
+        self.kwargs = kwargs
 
     def update_config(self, config: dict) -> None:
         """
@@ -48,13 +47,13 @@ def get_memory(
         """
 
     @abstractmethod
-    def add(self, memories: Union[list[dict], dict]) -> None:
+    def add(self, memories: Union[list[dict], dict], **kwargs: Any) -> None:
         """
         Adding new memory fragment, depending on how the memory are stored
         """
 
     @abstractmethod
-    def delete(self, index: Union[Iterable, int]) -> None:
+    def delete(self, index: Union[Iterable, int], **kwargs: Any) -> None:
         """
         Delete memory fragment, depending on how the memory are stored
         and matched
@@ -65,6 +64,7 @@ def load(
         self,
         memories: Union[str, dict, list],
         overwrite: bool = False,
+        **kwargs: Any,
     ) -> None:
         """
         Load memory, depending on how the memory are passed, design to load
@@ -76,14 +76,15 @@ def export(
         self,
         to_mem: bool = False,
         file_path: Optional[str] = None,
+        **kwargs: Any,
     ) -> Optional[list]:
         """Export memory, depending on how the memory are stored"""
 
     @abstractmethod
-    def clear(self) -> None:
+    def clear(self, **kwargs: Any) -> None:
         """Clean memory, depending on how the memory are stored"""
 
     @abstractmethod
-    def size(self) -> int:
+    def size(self, **kwargs: Any) -> int:
         """Returns the number of memory segments in memory."""
         raise NotImplementedError
diff --git a/src/agentscope/memory/temporary_memory.py b/src/agentscope/memory/temporary_memory.py
index d95c8fad8..be987ea10 100644
--- a/src/agentscope/memory/temporary_memory.py
+++ b/src/agentscope/memory/temporary_memory.py
@@ -27,11 +27,16 @@ def __init__(
         self,
         config: Optional[dict] = None,
         embedding_model: Union[str, Callable] = None,
+        mem_path: Optional[str] = None,
     ) -> None:
         super().__init__(config)
 
         self._content = []
 
+        self.mem_path = mem_path
+        if self.mem_path is not None:
+            self.load()
+
         # prepare embedding model if needed
         if isinstance(embedding_model, str):
             self.embedding_model = load_model_by_name(embedding_model)
@@ -105,6 +110,7 @@ def export(
         if to_mem:
             return self._content
 
+        file_path = file_path or self.mem_path
         if to_mem is False and file_path is not None:
             with open(file_path, "w", encoding="utf-8") as f:
                 json.dump(self._content, f, indent=4)
@@ -117,9 +123,15 @@ def export(
 
     def load(
         self,
-        memories: Union[str, dict, list],
+        memories: Union[str, dict, list] = None,
         overwrite: bool = False,
     ) -> None:
+        if memories is None:
+            if self.mem_path is not None:
+                memories = self.mem_path
+            else:
+                return
+
         if isinstance(memories, str):
             if os.path.isfile(memories):
                 with open(memories, "r", encoding="utf-8") as f: