diff --git a/docs/sphinx_doc/en/source/tutorial/201-agent.md b/docs/sphinx_doc/en/source/tutorial/201-agent.md index d28838497..1a90bf589 100644 --- a/docs/sphinx_doc/en/source/tutorial/201-agent.md +++ b/docs/sphinx_doc/en/source/tutorial/201-agent.md @@ -70,7 +70,6 @@ Below is a table summarizing the functionality of some of the key agents availab | `DialogAgent` | Manages dialogues by understanding context and generating coherent responses. | Customer service bots, virtual assistants. | | `DictDialogAgent` | Manages dialogues by understanding context and generating coherent responses, and the responses are in json format. | Customer service bots, virtual assistants. | | `UserAgent` | Interacts with the user to collect input, generating messages that may include URLs or additional specifics based on required keys. | Collecting user input for agents | -| `TextToImageAgent` | An agent that convert user input text to image. | Converting text to image | | `ReActAgent` | An agent class that implements the ReAct algorithm. | Solving complex tasks | | *More to Come* | AgentScope is continuously expanding its pool with more specialized agents for diverse applications. | | diff --git a/docs/sphinx_doc/en/source/tutorial/206-prompt.md b/docs/sphinx_doc/en/source/tutorial/206-prompt.md index 47d459527..dc98d6070 100644 --- a/docs/sphinx_doc/en/source/tutorial/206-prompt.md +++ b/docs/sphinx_doc/en/source/tutorial/206-prompt.md @@ -551,67 +551,4 @@ print(prompt) ] ``` -## Prompt Engine (Will be deprecated in the future) - -AgentScope provides the `PromptEngine` class to simplify the process of crafting -prompts for large language models (LLMs). - -## About `PromptEngine` Class - -The `PromptEngine` class provides a structured way to combine different components of a prompt, such as instructions, hints, conversation history, and user inputs, into a format that is suitable for the underlying language model. - -### Key Features of PromptEngine - -- **Model Compatibility**: It works with any `ModelWrapperBase` subclass. -- **Prompt Type**: It supports both string and list-style prompts, aligning with the model's preferred input format. - -### Initialization - -When creating an instance of `PromptEngine`, you can specify the target model and, optionally, the shrinking policy, the maximum length of the prompt, the prompt type, and a summarization model (could be the same as the target model). - -```python -model = OpenAIChatWrapper(...) -engine = PromptEngine(model) -``` - -### Joining Prompt Components - -The `join` method of `PromptEngine` provides a unified interface to handle an arbitrary number of components for constructing the final prompt. - -#### Output String Type Prompt - -If the model expects a string-type prompt, components are joined with a newline character: - -```python -system_prompt = "You're a helpful assistant." -memory = ... # can be dict, list, or string -hint_prompt = "Please respond in JSON format." - -prompt = engine.join(system_prompt, memory, hint_prompt) -# the result will be [ "You're a helpful assistant.", {"name": "user", "content": "What's the weather like today?"}] -``` - -#### Output List Type Prompt - -For models that work with list-type prompts,e.g., OpenAI and Huggingface chat models, the components can be converted to Message objects, whose type is list of dict: - -```python -system_prompt = "You're a helpful assistant." 
-user_messages = [{"name": "user", "content": "What's the weather like today?"}] - -prompt = engine.join(system_prompt, user_messages) -# the result should be: [{"role": "assistant", "content": "You're a helpful assistant."}, {"name": "user", "content": "What's the weather like today?"}] -``` - -#### Formatting Prompts in Dynamic Way - -The `PromptEngine` supports dynamic prompts using the `format_map` parameter, allowing you to flexibly inject various variables into the prompt components for different scenarios: - -```python -variables = {"location": "London"} -hint_prompt = "Find the weather in {location}." - -prompt = engine.join(system_prompt, user_input, hint_prompt, format_map=variables) -``` - [[Return to the top]](#206-prompt-en) diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/201-agent.md b/docs/sphinx_doc/zh_CN/source/tutorial/201-agent.md index 2e15490ad..01f4bf6ef 100644 --- a/docs/sphinx_doc/zh_CN/source/tutorial/201-agent.md +++ b/docs/sphinx_doc/zh_CN/source/tutorial/201-agent.md @@ -71,7 +71,6 @@ class AgentBase(Operator): | `DialogAgent` | 通过理解上下文和生成连贯的响应来管理对话。 | 客户服务机器人,虚拟助手。 | | `DictDialogAgent` | 通过理解上下文和生成连贯的响应来管理对话,返回的消息为 Json 格式。 | 客户服务机器人,虚拟助手。 | | `UserAgent` | 与用户互动以收集输入,生成可能包括URL或基于所需键的额外具体信息的消息。 | 为agent收集用户输入 | -| `TextToImageAgent` | 将用户输入的文本转化为图片 | 提供文生图功能 | | `ReActAgent` | 实现了 ReAct 算法的 Agent,能够自动调用工具处理较为复杂的任务。 | 借助工具解决复杂任务 | | *更多agent* | AgentScope 正在不断扩大agent池,加入更多专门化的agent,以适应多样化的应用。 | | diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/206-prompt.md b/docs/sphinx_doc/zh_CN/source/tutorial/206-prompt.md index ed38bad54..12a70cb44 100644 --- a/docs/sphinx_doc/zh_CN/source/tutorial/206-prompt.md +++ b/docs/sphinx_doc/zh_CN/source/tutorial/206-prompt.md @@ -485,62 +485,4 @@ print(prompt) ] ``` -## 关于`PromptEngine`类 (将会在未来版本弃用) - -`PromptEngine`类提供了一种结构化的方式来合并不同的提示组件,比如指令、提示、对话历史和用户输入,以适合底层语言模型的格式。 - -### 提示工程的关键特性 - -- **模型兼容性**:可以与任何 `ModelWrapperBase` 的子类一起工作。 -- **提示类型**:支持字符串和列表风格的提示,与模型首选的输入格式保持一致。 - -### 初始化 - -当创建 `PromptEngine` 的实例时,您可以指定目标模型,以及(可选的)缩减原则、提示的最大长度、提示类型和总结模型(可以与目标模型相同)。 - -```python -model = OpenAIChatWrapper(...) -engine = PromptEngine(model) -``` - -### 合并提示组件 - -`PromptEngine` 的 `join` 方法提供了一个统一的接口来处理任意数量的组件,以构建最终的提示。 - -#### 输出字符串类型提示 - -如果模型期望的是字符串类型的提示,组件会通过换行符连接: - -```python -system_prompt = "You're a helpful assistant." -memory = ... # 可以是字典、列表或字符串 -hint_prompt = "Please respond in JSON format." - -prompt = engine.join(system_prompt, memory, hint_prompt) -# 结果将会是 ["You're a helpful assistant.", {"name": "user", "content": "What's the weather like today?"}] -``` - -#### 输出列表类型提示 - -对于使用列表类型提示的模型,比如 OpenAI 和 Huggingface 聊天模型,组件可以转换为 `Message` 对象,其类型是字典列表: - -```python -system_prompt = "You're a helpful assistant." -user_messages = [{"name": "user", "content": "What's the weather like today?"}] - -prompt = engine.join(system_prompt, user_messages) -# 结果将会是: [{"role": "assistant", "content": "You're a helpful assistant."}, {"name": "user", "content": "What's the weather like today?"}] -``` - -#### 动态格式化提示 - -`PromptEngine` 支持使用 `format_map` 参数动态提示,允许您灵活地将各种变量注入到不同场景的提示组件中: - -```python -variables = {"location": "London"} -hint_prompt = "Find the weather in {location}." 
- -prompt = engine.join(system_prompt, user_input, hint_prompt, format_map=variables) -``` - [[返回顶端]](#206-prompt-zh) diff --git a/src/agentscope/agents/__init__.py b/src/agentscope/agents/__init__.py index e50efa66f..65d86b278 100644 --- a/src/agentscope/agents/__init__.py +++ b/src/agentscope/agents/__init__.py @@ -5,7 +5,6 @@ from .dialog_agent import DialogAgent from .dict_dialog_agent import DictDialogAgent from .user_agent import UserAgent -from .text_to_image_agent import TextToImageAgent from .rpc_agent import RpcAgent from .react_agent import ReActAgent from .rag_agent import LlamaIndexAgent @@ -16,7 +15,6 @@ "Operator", "DialogAgent", "DictDialogAgent", - "TextToImageAgent", "UserAgent", "ReActAgent", "DistConf", diff --git a/src/agentscope/agents/text_to_image_agent.py b/src/agentscope/agents/text_to_image_agent.py deleted file mode 100644 index f66d75b32..000000000 --- a/src/agentscope/agents/text_to_image_agent.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -"""An agent that convert text to image.""" - -from typing import Optional, Union, Sequence - -from loguru import logger - -from .agent import AgentBase -from ..message import Msg - - -class TextToImageAgent(AgentBase): - """ - A agent used to perform text to image tasks. - - TODO: change the agent into a service. - """ - - def __init__( - self, - name: str, - model_config_name: str, - use_memory: bool = True, - ) -> None: - """Initialize the text to image agent. - - Arguments: - name (`str`): - The name of the agent. - model_config_name (`str`, defaults to None): - The name of the model config, which is used to load model from - configuration. - use_memory (`bool`, defaults to `True`): - Whether the agent has memory. - """ - super().__init__( - name=name, - sys_prompt="", - model_config_name=model_config_name, - use_memory=use_memory, - ) - - logger.warning( - "The `TextToImageAgent` will be deprecated in v0.0.6, " - "please use `text_to_image` service and `ReActAgent` instead.", - ) - - def reply(self, x: Optional[Union[Msg, Sequence[Msg]]] = None) -> Msg: - if self.memory: - self.memory.add(x) - if x is None: - # get the last message from memory - if self.memory and self.memory.size() > 0: - x = self.memory.get_memory()[-1] - else: - return Msg( - self.name, - content="Please provide a text prompt to generate image.", - role="assistant", - ) - image_urls = self.model(x.content).image_urls - # TODO: optimize the construction of content - msg = Msg( - self.name, - content="This is the generated image", - role="assistant", - url=image_urls, - ) - - self.speak(msg) - - if self.memory: - self.memory.add(msg) - - return msg diff --git a/src/agentscope/file_manager.py b/src/agentscope/file_manager.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/agentscope/logging.py b/src/agentscope/logging.py index 163ba0577..951de472a 100644 --- a/src/agentscope/logging.py +++ b/src/agentscope/logging.py @@ -6,10 +6,11 @@ from loguru import logger -from .utils.tools import _guess_type_by_extension + from .message import Msg from .serialize import serialize from .studio._client import _studio_client +from .utils.common import _guess_type_by_extension from .web.gradio.utils import ( generate_image_from_name, send_msg, diff --git a/src/agentscope/manager/_file.py b/src/agentscope/manager/_file.py index 007669452..8fe93b171 100644 --- a/src/agentscope/manager/_file.py +++ b/src/agentscope/manager/_file.py @@ -8,11 +8,13 @@ import numpy as np from PIL import Image -from agentscope.utils.tools import 
_download_file -from agentscope.utils.tools import _hash_string -from agentscope.utils.tools import _get_timestamp -from agentscope.utils.tools import _generate_random_code -from agentscope.constants import ( +from ..utils.common import ( + _download_file, + _hash_string, + _get_timestamp, + _generate_random_code, +) +from ..constants import ( _DEFAULT_SUBDIR_CODE, _DEFAULT_SUBDIR_FILE, _DEFAULT_SUBDIR_INVOKE, @@ -32,7 +34,13 @@ def _get_text_embedding_record_hash( if isinstance(embedding_model, dict): # Format the dict to avoid duplicate keys embedding_model = json.dumps(embedding_model, sort_keys=True) - embedding_model_hash = _hash_string(embedding_model, hash_method) + elif isinstance(embedding_model, str): + embedding_model_hash = _hash_string(embedding_model, hash_method) + else: + raise RuntimeError( + f"The embedding model must be a string or a dict, got " + f"{type(embedding_model)}.", + ) # Calculate the embedding id by hashing the hash codes of the # original data and the embedding model diff --git a/src/agentscope/manager/_manager.py b/src/agentscope/manager/_manager.py index 318f2efce..fdf6e37a5 100644 --- a/src/agentscope/manager/_manager.py +++ b/src/agentscope/manager/_manager.py @@ -9,7 +9,7 @@ from ._file import FileManager from ._model import ModelManager from ..logging import LOG_LEVEL, setup_logger -from ..utils.tools import ( +from ..utils.common import ( _generate_random_code, _get_process_creation_time, _get_timestamp, diff --git a/src/agentscope/manager/_monitor.py b/src/agentscope/manager/_monitor.py index a6cad05f9..19edc7a9a 100644 --- a/src/agentscope/manager/_monitor.py +++ b/src/agentscope/manager/_monitor.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import sessionmaker from ._file import FileManager -from ..utils.tools import _is_windows +from ..utils.common import _is_windows from ..constants import ( _DEFAULT_SQLITE_DB_NAME, _DEFAULT_TABLE_NAME_FOR_CHAT_AND_EMBEDDING, diff --git a/src/agentscope/message/msg.py b/src/agentscope/message/msg.py index 342e86dda..1f3e99dd3 100644 --- a/src/agentscope/message/msg.py +++ b/src/agentscope/message/msg.py @@ -13,7 +13,7 @@ from loguru import logger from ..serialize import is_serializable -from ..utils.tools import ( +from ..utils.common import ( _map_string_to_color_mark, _get_timestamp, ) diff --git a/src/agentscope/message/placeholder.py b/src/agentscope/message/placeholder.py index 73da3d231..b657bb444 100644 --- a/src/agentscope/message/placeholder.py +++ b/src/agentscope/message/placeholder.py @@ -9,7 +9,7 @@ from .msg import Msg from ..rpc import RpcAgentClient, ResponseStub, call_in_thread from ..serialize import deserialize, is_serializable, serialize -from ..utils.tools import _is_web_url +from ..utils.common import _is_web_url class PlaceholderMessage(Msg): diff --git a/src/agentscope/models/dashscope_model.py b/src/agentscope/models/dashscope_model.py index 0058486ce..ba50b9f40 100644 --- a/src/agentscope/models/dashscope_model.py +++ b/src/agentscope/models/dashscope_model.py @@ -10,7 +10,7 @@ from ..manager import FileManager from ..message import Msg -from ..utils.tools import _convert_to_str, _guess_type_by_extension +from ..utils.common import _convert_to_str, _guess_type_by_extension try: import dashscope diff --git a/src/agentscope/models/gemini_model.py b/src/agentscope/models/gemini_model.py index e5315212b..3eaa301fb 100644 --- a/src/agentscope/models/gemini_model.py +++ b/src/agentscope/models/gemini_model.py @@ -7,9 +7,9 @@ from loguru import logger -from agentscope.message import Msg -from 
agentscope.models import ModelWrapperBase, ModelResponse -from agentscope.utils.tools import _convert_to_str +from ..message import Msg +from ..models import ModelWrapperBase, ModelResponse +from ..utils.common import _convert_to_str try: import google.generativeai as genai diff --git a/src/agentscope/models/model.py b/src/agentscope/models/model.py index 8d20a108f..429d34d7a 100644 --- a/src/agentscope/models/model.py +++ b/src/agentscope/models/model.py @@ -68,7 +68,7 @@ from ..manager import FileManager from ..manager import MonitorManager from ..message import Msg -from ..utils.tools import _get_timestamp, _convert_to_str +from ..utils.common import _get_timestamp, _convert_to_str from ..constants import _DEFAULT_MAX_RETRIES from ..constants import _DEFAULT_RETRY_INTERVAL diff --git a/src/agentscope/models/ollama_model.py b/src/agentscope/models/ollama_model.py index 7d65cafd0..ec87f219f 100644 --- a/src/agentscope/models/ollama_model.py +++ b/src/agentscope/models/ollama_model.py @@ -3,9 +3,9 @@ from abc import ABC from typing import Sequence, Any, Optional, List, Union, Generator -from agentscope.message import Msg -from agentscope.models import ModelWrapperBase, ModelResponse -from agentscope.utils.tools import _convert_to_str +from ..message import Msg +from ..models import ModelWrapperBase, ModelResponse +from ..utils.common import _convert_to_str try: import ollama diff --git a/src/agentscope/models/openai_model.py b/src/agentscope/models/openai_model.py index 0a87ae381..772b43c09 100644 --- a/src/agentscope/models/openai_model.py +++ b/src/agentscope/models/openai_model.py @@ -21,7 +21,7 @@ from .model import ModelWrapperBase, ModelResponse from ..manager import FileManager from ..message import Msg -from ..utils.tools import _convert_to_str, _to_openai_image_url +from ..utils.common import _convert_to_str, _to_openai_image_url from ..utils.token_utils import get_openai_max_length diff --git a/src/agentscope/models/response.py b/src/agentscope/models/response.py index 3019257e0..60a140ca3 100644 --- a/src/agentscope/models/response.py +++ b/src/agentscope/models/response.py @@ -3,7 +3,7 @@ import json from typing import Optional, Sequence, Any, Generator, Union, Tuple -from agentscope.utils.tools import _is_json_serializable +from ..utils.common import _is_json_serializable class ModelResponse: @@ -56,6 +56,11 @@ def text(self) -> str: self._text += chunk return self._text + @text.setter + def text(self, value: str) -> None: + """Set the text field.""" + self._text = value + @property def stream(self) -> Union[None, Generator[Tuple[bool, str], None, None]]: """Return the stream generator if it exists.""" diff --git a/src/agentscope/parsers/json_object_parser.py b/src/agentscope/parsers/json_object_parser.py index 970828639..441af8286 100644 --- a/src/agentscope/parsers/json_object_parser.py +++ b/src/agentscope/parsers/json_object_parser.py @@ -8,16 +8,16 @@ from loguru import logger from pydantic import BaseModel -from agentscope.exception import ( +from ..exception import ( TagNotFoundError, JsonParsingError, JsonTypeError, RequiredFieldNotFoundError, ) -from agentscope.models import ModelResponse -from agentscope.parsers import ParserBase -from agentscope.parsers.parser_base import DictFilterMixin -from agentscope.utils.tools import _join_str_with_comma_and +from ..models import ModelResponse +from ..parsers import ParserBase +from ..parsers.parser_base import DictFilterMixin +from ..utils.common import _join_str_with_comma_and class MarkdownJsonObjectParser(ParserBase): @@ 
-166,9 +166,9 @@ def __init__( self, content_hint: Optional[Any] = None, required_keys: List[str] = None, - keys_to_memory: Optional[Union[str, bool, Sequence[str]]] = True, - keys_to_content: Optional[Union[str, bool, Sequence[str]]] = True, - keys_to_metadata: Optional[Union[str, bool, Sequence[str]]] = False, + keys_to_memory: Union[str, bool, Sequence[str]] = True, + keys_to_content: Union[str, bool, Sequence[str]] = True, + keys_to_metadata: Union[str, bool, Sequence[str]] = False, ) -> None: """Initialize the parser with the content hint. diff --git a/src/agentscope/prompt/__init__.py b/src/agentscope/prompt/__init__.py index 1fb694ff9..dcd15d4b3 100644 --- a/src/agentscope/prompt/__init__.py +++ b/src/agentscope/prompt/__init__.py @@ -6,11 +6,9 @@ from ._prompt_generator_en import EnglishSystemPromptGenerator from ._prompt_comparer import SystemPromptComparer from ._prompt_optimizer import SystemPromptOptimizer -from ._prompt_engine import PromptEngine __all__ = [ - "PromptEngine", "SystemPromptGeneratorBase", "ChineseSystemPromptGenerator", "EnglishSystemPromptGenerator", diff --git a/src/agentscope/prompt/_prompt_engine.py b/src/agentscope/prompt/_prompt_engine.py deleted file mode 100644 index 8d66a16f5..000000000 --- a/src/agentscope/prompt/_prompt_engine.py +++ /dev/null @@ -1,179 +0,0 @@ -# -*- coding: utf-8 -*- -"""Prompt engineering module.""" -from typing import Any, Optional, Union -from enum import IntEnum - -from loguru import logger - -from agentscope.models import OpenAIWrapperBase, ModelWrapperBase -from agentscope.constants import ShrinkPolicy -from agentscope.utils.tools import to_openai_dict, to_dialog_str - - -class PromptType(IntEnum): - """Enum for prompt types.""" - - STRING = 0 - LIST = 1 - - -class PromptEngine: - """Prompt engineering module for both list and string prompt""" - - def __init__( - self, - model: ModelWrapperBase, - shrink_policy: ShrinkPolicy = ShrinkPolicy.TRUNCATE, - max_length: Optional[int] = None, - prompt_type: Optional[PromptType] = None, - max_summary_length: int = 200, - summarize_model: Optional[ModelWrapperBase] = None, - ) -> None: - """Init PromptEngine. - - Args: - model (`ModelWrapperBase`): - The target model for prompt engineering. - shrink_policy (`ShrinkPolicy`, defaults to - `ShrinkPolicy.TRUNCATE`): - The shrink policy for prompt engineering, defaults to - `ShrinkPolicy.TRUNCATE`. - max_length (`Optional[int]`, defaults to `None`): - The max length of context, if it is None, it will be set to the - max length of the model. - prompt_type (`Optional[MsgType]`, defaults to `None`): - The type of prompt, if it is None, it will be set according to - the model. - max_summary_length (`int`, defaults to `200`): - The max length of summary, if it is None, it will be set to the - max length of the model. - summarize_model (`Optional[ModelWrapperBase]`, defaults to `None`): - The model used for summarization, if it is None, it will be - set to `model`. - - Note: - - 1. TODO: Shrink function is still under development. - - 2. If the argument `max_length` and `prompt_type` are not given, - they will be set according to the given model. - - 3. `shrink_policy` is used when the prompt is too long, it can - be set to `ShrinkPolicy.TRUNCATE` or `ShrinkPolicy.SUMMARIZE`. - - a. `ShrinkPolicy.TRUNCATE` will truncate the prompt to the - desired length. - - b. `ShrinkPolicy.SUMMARIZE` will summarize partial of the - dialog history to save space. The summarization model - defaults to `model` if not given. 
- - Example: - - With prompt engine, we encapsulate different operations for - string- and list-style prompt, and block the prompt engineering - process from the user. - As a user, you can just combine you prompt as follows. - - .. code-block:: python - - # prepare the component - system_prompt = "You're a helpful assistant ..." - hint_prompt = "You should response in Json format." - prefix = "assistant: " - - # initialize the prompt engine and join the prompt - engine = PromptEngine(model) - prompt = engine.join(system_prompt, memory.get_memory(), - hint_prompt, prefix) - """ - self.model = model - self.shrink_policy = shrink_policy - self.max_length = max_length - - if prompt_type is None: - if isinstance(model, OpenAIWrapperBase): - self.prompt_type = PromptType.LIST - else: - self.prompt_type = PromptType.STRING - else: - self.prompt_type = prompt_type - - self.max_summary_length = max_summary_length - - if summarize_model is None: - self.summarize_model = model - - logger.warning( - "The prompt engine will be deprecated in the future. " - "Please use the `format` function in model wrapper object " - "instead. More details refer to ", - "https://modelscope.github.io/agentscope/en/tutorial/206-prompt" - ".html", - ) - - def join( - self, - *args: Any, - format_map: Optional[dict] = None, - ) -> Union[str, list[dict]]: - """Join prompt components according to its type. The join function can - accept any number and type of arguments. If prompt type is - `PromptType.STRING`, the arguments will be joined by `"\\\\n"`. If - prompt type is `PromptType.LIST`, the string arguments will be - converted to `Msg` from `system`. - """ - # TODO: achieve the summarize function - - # Filter `None` - args = [_ for _ in args if _ is not None] - - if self.prompt_type == PromptType.STRING: - return self.join_to_str(*args, format_map=format_map) - elif self.prompt_type == PromptType.LIST: - return self.join_to_list(*args, format_map=format_map) - else: - raise RuntimeError("Invalid prompt type.") - - def join_to_str(self, *args: Any, format_map: Union[dict, None]) -> str: - """Join prompt components to a string.""" - prompt = [] - for item in args: - if isinstance(item, list): - items_str = self.join_to_str(*item, format_map=None) - prompt += [items_str] - elif isinstance(item, dict): - prompt.append(to_dialog_str(item)) - else: - prompt.append(str(item)) - prompt_str = "\n".join(prompt) - - if format_map is not None: - prompt_str = prompt_str.format_map(format_map) - - return prompt_str - - def join_to_list(self, *args: Any, format_map: Union[dict, None]) -> list: - """Join prompt components to a list of `Msg` objects.""" - prompt = [] - for item in args: - if isinstance(item, list): - # nested processing - prompt.extend(self.join_to_list(*item, format_map=None)) - elif isinstance(item, dict): - prompt.append(to_openai_dict(item)) - else: - prompt.append(to_openai_dict({"content": str(item)})) - - if format_map is not None: - format_prompt = [] - for msg in prompt: - format_prompt.append( - { - k.format_map(format_map): v.format_map(format_map) - for k, v in msg.items() - }, - ) - prompt = format_prompt - - return prompt diff --git a/src/agentscope/rpc/__init__.py b/src/agentscope/rpc/__init__.py index 42d3b5fe5..2f061c85f 100644 --- a/src/agentscope/rpc/__init__.py +++ b/src/agentscope/rpc/__init__.py @@ -8,7 +8,7 @@ from .rpc_agent_pb2_grpc import RpcAgentStub from .rpc_agent_pb2_grpc import add_RpcAgentServicer_to_server except ImportError as import_error: - from agentscope.utils.tools import 
ImportErrorReporter + from agentscope.utils.common import ImportErrorReporter RpcMsg = ImportErrorReporter(import_error, "distribute") # type: ignore[misc] RpcAgentServicer = ImportErrorReporter(import_error, "distribute") diff --git a/src/agentscope/rpc/rpc_agent_client.py b/src/agentscope/rpc/rpc_agent_client.py index 4e4bdbe45..480bbafbc 100644 --- a/src/agentscope/rpc/rpc_agent_client.py +++ b/src/agentscope/rpc/rpc_agent_client.py @@ -18,7 +18,7 @@ from agentscope.rpc.rpc_agent_pb2_grpc import RpcAgentStub import agentscope.rpc.rpc_agent_pb2 as agent_pb2 except ImportError as import_error: - from agentscope.utils.tools import ImportErrorReporter + from agentscope.utils.common import ImportErrorReporter dill = ImportErrorReporter(import_error, "distribute") grpc = ImportErrorReporter(import_error, "distribute") @@ -26,7 +26,7 @@ RpcAgentStub = ImportErrorReporter(import_error, "distribute") RpcError = ImportError -from ..utils.tools import generate_id_from_seed +from ..utils.common import _generate_id_from_seed from ..exception import AgentServerNotAliveError from ..constants import _DEFAULT_RPC_OPTIONS from ..exception import AgentCallError @@ -333,7 +333,7 @@ def download_file(self, path: str) -> str: file_manager = FileManager.get_instance() local_filename = ( - f"{generate_id_from_seed(path, 5)}_{os.path.basename(path)}" + f"{_generate_id_from_seed(path, 5)}_{os.path.basename(path)}" ) def _generator() -> Generator[bytes, None, None]: diff --git a/src/agentscope/rpc/rpc_agent_pb2_grpc.py b/src/agentscope/rpc/rpc_agent_pb2_grpc.py index 0234d55f2..1c506c176 100644 --- a/src/agentscope/rpc/rpc_agent_pb2_grpc.py +++ b/src/agentscope/rpc/rpc_agent_pb2_grpc.py @@ -5,7 +5,7 @@ import grpc from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 except ImportError as import_error: - from agentscope.utils.tools import ImportErrorReporter + from agentscope.utils.common import ImportErrorReporter grpc = ImportErrorReporter(import_error, "distribute") google_dot_protobuf_dot_empty__pb2 = ImportErrorReporter( diff --git a/src/agentscope/server/launcher.py b/src/agentscope/server/launcher.py index f65ced242..0b826c835 100644 --- a/src/agentscope/server/launcher.py +++ b/src/agentscope/server/launcher.py @@ -18,7 +18,7 @@ add_RpcAgentServicer_to_server, ) except ImportError as import_error: - from agentscope.utils.tools import ImportErrorReporter + from agentscope.utils.common import ImportErrorReporter grpc = ImportErrorReporter(import_error, "distribute") add_RpcAgentServicer_to_server = ImportErrorReporter( @@ -29,7 +29,7 @@ from ..server.servicer import AgentServerServicer from ..manager import ASManager from ..agents.agent import AgentBase -from ..utils.tools import check_port, generate_id_from_seed +from ..utils.common import _check_port, _generate_id_from_seed from ..constants import _DEFAULT_RPC_OPTIONS @@ -191,7 +191,7 @@ async def shutdown_signal_handler() -> None: ) while True: try: - port = check_port(port) + port = _check_port(port) servicer.port = port server = grpc.aio.server( futures.ThreadPoolExecutor(max_workers=None), @@ -333,7 +333,7 @@ def __init__( The url of the agentscope studio. 
""" self.host = host - self.port = check_port(port) + self.port = _check_port(port) self.max_pool_size = max_pool_size self.max_timeout_seconds = max_timeout_seconds self.local_mode = local_mode @@ -354,7 +354,7 @@ def __init__( @classmethod def generate_server_id(cls, host: str, port: int) -> str: """Generate server id""" - return generate_id_from_seed(f"{host}:{port}:{time.time()}", length=8) + return _generate_id_from_seed(f"{host}:{port}:{time.time()}", length=8) def _launch_in_main(self) -> None: """Launch agent server in main-process""" diff --git a/src/agentscope/server/servicer.py b/src/agentscope/server/servicer.py index 1404d9adc..154c04b70 100644 --- a/src/agentscope/server/servicer.py +++ b/src/agentscope/server/servicer.py @@ -18,7 +18,7 @@ from google.protobuf.empty_pb2 import Empty from expiringdict import ExpiringDict except ImportError as import_error: - from agentscope.utils.tools import ImportErrorReporter + from agentscope.utils.common import ImportErrorReporter dill = ImportErrorReporter(import_error, "distribute") psutil = ImportErrorReporter(import_error, "distribute") diff --git a/src/agentscope/service/execute_code/exec_notebook.py b/src/agentscope/service/execute_code/exec_notebook.py index bbd697121..f296c41b0 100644 --- a/src/agentscope/service/execute_code/exec_notebook.py +++ b/src/agentscope/service/execute_code/exec_notebook.py @@ -13,7 +13,7 @@ from nbclient.exceptions import CellTimeoutError, DeadKernelError import nbformat except ImportError as import_error: - from agentscope.utils.tools import ImportErrorReporter + from agentscope.utils.common import ImportErrorReporter nbclient = ImportErrorReporter(import_error) nbformat = ImportErrorReporter(import_error) diff --git a/src/agentscope/service/execute_code/exec_python.py b/src/agentscope/service/execute_code/exec_python.py index c2491f3eb..2cde33740 100644 --- a/src/agentscope/service/execute_code/exec_python.py +++ b/src/agentscope/service/execute_code/exec_python.py @@ -27,10 +27,10 @@ except (ModuleNotFoundError, ImportError): resource = None -from agentscope.utils.common import create_tempdir, timer -from agentscope.service.service_status import ServiceExecStatus -from agentscope.service.service_response import ServiceResponse -from agentscope.constants import ( +from ...utils.common import create_tempdir, timer +from ..service_status import ServiceExecStatus +from ..service_response import ServiceResponse +from ...constants import ( _DEFAULT_PYPI_MIRROR, _DEFAULT_TRUSTED_HOST, ) diff --git a/src/agentscope/service/file/text.py b/src/agentscope/service/file/text.py index 725d08a56..e0e031b0d 100644 --- a/src/agentscope/service/file/text.py +++ b/src/agentscope/service/file/text.py @@ -2,7 +2,6 @@ """ Operators for txt file and directory. 
""" import os -from agentscope.utils.common import write_file from agentscope.service.service_response import ServiceResponse from agentscope.service.service_status import ServiceExecStatus @@ -59,4 +58,17 @@ def write_text_file( status=ServiceExecStatus.ERROR, content="FileExistsError: The file already exists.", ) - return write_file(content, file_path) + + try: + with open(file_path, "w", encoding="utf-8") as file: + file.write(content) + return ServiceResponse( + status=ServiceExecStatus.SUCCESS, + content="Success", + ) + except Exception as e: + error_message = f"{e.__class__.__name__}: {e}" + return ServiceResponse( + status=ServiceExecStatus.ERROR, + content=error_message, + ) diff --git a/src/agentscope/service/multi_modality/dashscope_services.py b/src/agentscope/service/multi_modality/dashscope_services.py index 04774f588..d3963bbc7 100644 --- a/src/agentscope/service/multi_modality/dashscope_services.py +++ b/src/agentscope/service/multi_modality/dashscope_services.py @@ -20,11 +20,11 @@ # SpeechSynthesizerWrapper is current not available -from agentscope.service.service_response import ( +from ..service_response import ( ServiceResponse, ServiceExecStatus, ) -from agentscope.utils.tools import _download_file +from ...utils.common import _download_file def dashscope_text_to_image( diff --git a/src/agentscope/service/multi_modality/openai_services.py b/src/agentscope/service/multi_modality/openai_services.py index 16aca5a58..b5fd799b1 100644 --- a/src/agentscope/service/multi_modality/openai_services.py +++ b/src/agentscope/service/multi_modality/openai_services.py @@ -13,21 +13,16 @@ import requests -from openai import OpenAI -from openai.types import ImagesResponse -from openai._types import NOT_GIVEN, NotGiven -from agentscope.service.service_response import ( +from ..service_response import ( ServiceResponse, ServiceExecStatus, ) -from agentscope.models.openai_model import ( +from ...models.openai_model import ( OpenAIDALLEWrapper, OpenAIChatWrapper, ) -from agentscope.utils.tools import _download_file - - -from agentscope.message import Msg +from ...utils.common import _download_file +from ...message import Msg def _url_to_filename(url: str) -> str: @@ -52,11 +47,10 @@ def _url_to_filename(url: str) -> str: def _handle_openai_img_response( - response: ImagesResponse, + raw_response: dict, save_dir: Optional[str] = None, ) -> Union[str, Sequence[str]]: """Handle the response from OpenAI image generation API.""" - raw_response = response.model_dump() if "data" not in raw_response: if "error" in raw_response: error_msg = raw_response["error"]["message"] @@ -278,19 +272,32 @@ def openai_edit_image( 'EDITED_IMAGE_URL2']} > } """ - client = OpenAI(api_key=api_key) + try: + import openai + except ImportError as e: + raise ImportError( + "The `openai` library is not installed. 
Please install it by " + "running `pip install openai`.", + ) from e + + client = openai.OpenAI(api_key=api_key) # _parse_url handles both local and web URLs and returns BytesIO image = _parse_url(image_url) try: - response = client.images.edit( - model="dall-e-2", - image=image, - mask=_parse_url(mask_url) if mask_url else NOT_GIVEN, - prompt=prompt, - n=n, - size=size, - ) - urls = _handle_openai_img_response(response, save_dir) + kwargs = { + "model": "dall-e-2", + "image": image, + "prompt": prompt, + "n": n, + "size": size, + } + + if mask_url: + kwargs["mask"] = _parse_url(mask_url) + + response = client.images.edit(**kwargs) + + urls = _handle_openai_img_response(response.model_dump(), save_dir) return ServiceResponse( ServiceExecStatus.SUCCESS, {"image_urls": urls}, @@ -352,7 +359,15 @@ def openai_create_image_variation( > 'content': {'image_urls': ['VARIATION_URL1', 'VARIATION_URL2']} > } """ - client = OpenAI(api_key=api_key) + try: + import openai + except ImportError as e: + raise ImportError( + "The `openai` library is not installed. Please install it by " + "running `pip install openai`.", + ) from e + + client = openai.OpenAI(api_key=api_key) # _parse_url handles both local and web URLs and returns BytesIO image = _parse_url(image_url) try: @@ -362,7 +377,7 @@ def openai_create_image_variation( n=n, size=size, ) - urls = _handle_openai_img_response(response, save_dir) + urls = _handle_openai_img_response(response.model_dump(), save_dir) return ServiceResponse( ServiceExecStatus.SUCCESS, {"image_urls": urls}, @@ -375,7 +390,7 @@ def openai_create_image_variation( def openai_image_to_text( - image_urls: Union[str, Sequence[str]], + image_urls: Union[str, list[str]], api_key: str, prompt: str = "Describe the image", model: Literal["gpt-4o", "gpt-4-turbo"] = "gpt-4o", @@ -385,7 +400,7 @@ def openai_image_to_text( return the generated text. Args: - image_urls (`Union[str, Sequence[str]]`): + image_urls (`Union[str, list[str]]`): The URL or list of URLs pointing to the images that need to be described. api_key (`str`): @@ -502,7 +517,15 @@ def openai_text_to_audio( > 'content': {'audio_path': './audio_files/Hello,_welco.mp3'} > } """ - client = OpenAI(api_key=api_key) + try: + import openai + except ImportError as e: + raise ImportError( + "The `openai` library is not installed. Please install it by " + "running `pip install openai`.", + ) from e + + client = openai.OpenAI(api_key=api_key) save_name = _audio_filename(text) if os.path.isabs(save_dir): save_path = os.path.join(save_dir, f"{save_name}.{res_format}") @@ -535,7 +558,7 @@ def openai_text_to_audio( def openai_audio_to_text( audio_file_url: str, api_key: str, - language: Union[str, NotGiven] = NOT_GIVEN, + language: str = "en", temperature: float = 0.2, ) -> ServiceResponse: """ @@ -547,9 +570,10 @@ def openai_audio_to_text( transcribed. api_key (`str`): The API key for the OpenAI API. - language (`Union[str, NotGiven]`, defaults to `NotGiven()`): - The language of the audio. If not specified, the language will - be auto-detected. + language (`str`, defaults to `"en"`): + The language of the input audio. Supplying the input language in + [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) + format will improve accuracy and latency. temperature (`float`, defaults to `0.2`): The temperature for the transcription, which affects the randomness of the output. 
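A minimal usage sketch (an illustration, not lines from the patch): with `language` now a plain string defaulting to `"en"` instead of `NOT_GIVEN`, the updated transcription service could be invoked as below. The audio path and API key are placeholders.

```python
from agentscope.service.multi_modality.openai_services import (
    openai_audio_to_text,
)

# Hypothetical audio file and key, for illustration only.
res = openai_audio_to_text(
    audio_file_url="./recordings/demo.mp3",  # local path, placeholder
    api_key="YOUR_OPENAI_API_KEY",           # placeholder
    language="en",       # ISO-639-1 code; replaces the old NotGiven default
    temperature=0.2,
)
# ServiceResponse: `content` holds the transcription on success.
print(res.status, res.content)
```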
@@ -575,7 +599,15 @@ def openai_audio_to_text( the audio file.'} > } """ - client = OpenAI(api_key=api_key) + try: + import openai + except ImportError as e: + raise ImportError( + "The `openai` library is not installed. Please install it by " + "running `pip install openai`.", + ) from e + + client = openai.OpenAI(api_key=api_key) audio_file_url = os.path.abspath(audio_file_url) with open(audio_file_url, "rb") as audio_file: try: diff --git a/src/agentscope/service/web/dblp.py b/src/agentscope/service/web/dblp.py index 7d6ab9c1c..91ed9aac8 100644 --- a/src/agentscope/service/web/dblp.py +++ b/src/agentscope/service/web/dblp.py @@ -7,7 +7,7 @@ ServiceResponse, ServiceExecStatus, ) -from agentscope.utils.common import requests_get +from ...utils.common import _requests_get def dblp_search_publications( @@ -92,7 +92,7 @@ def dblp_search_publications( "f": start, "c": num_completion, } - search_results = requests_get(url, params) + search_results = _requests_get(url, params) if isinstance(search_results, str): return ServiceResponse(ServiceExecStatus.ERROR, search_results) @@ -204,7 +204,7 @@ def dblp_search_authors( "f": start, "c": num_completion, } - search_results = requests_get(url, params) + search_results = _requests_get(url, params) if isinstance(search_results, str): return ServiceResponse(ServiceExecStatus.ERROR, search_results) hits = search_results.get("result", {}).get("hits", {}).get("hit", []) @@ -297,7 +297,7 @@ def dblp_search_venues( "f": start, "c": num_completion, } - search_results = requests_get(url, params) + search_results = _requests_get(url, params) if isinstance(search_results, str): return ServiceResponse(ServiceExecStatus.ERROR, search_results) diff --git a/src/agentscope/service/web/search.py b/src/agentscope/service/web/search.py index b5ff7e59f..c748a3cbc 100644 --- a/src/agentscope/service/web/search.py +++ b/src/agentscope/service/web/search.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """Search question in the web""" from typing import Any -from agentscope.service.service_response import ServiceResponse -from agentscope.utils.common import requests_get -from agentscope.service.service_status import ServiceExecStatus +from ..service_response import ServiceResponse +from ...utils.common import _requests_get +from ..service_status import ServiceExecStatus def bing_search( @@ -85,7 +85,7 @@ def bing_search( headers = {"Ocp-Apim-Subscription-Key": api_key} - search_results = requests_get( + search_results = _requests_get( bing_search_url, params, headers, @@ -173,7 +173,7 @@ def google_search( if kwargs: params.update(**kwargs) - search_results = requests_get(google_search_url, params) + search_results = _requests_get(google_search_url, params) if isinstance(search_results, str): return ServiceResponse(ServiceExecStatus.ERROR, search_results) diff --git a/src/agentscope/studio/_app.py b/src/agentscope/studio/_app.py index 06d3f7762..81ed58b61 100644 --- a/src/agentscope/studio/_app.py +++ b/src/agentscope/studio/_app.py @@ -33,7 +33,7 @@ FILE_COUNT_LIMIT, ) from ._studio_utils import _check_and_convert_id_type -from ..utils.tools import ( +from ..utils.common import ( _is_process_alive, _is_windows, _generate_new_runtime_id, diff --git a/src/agentscope/studio/static/html-drag-components/agent-texttoimageagent.html b/src/agentscope/studio/static/html-drag-components/agent-texttoimageagent.html deleted file mode 100644 index d3ff12c51..000000000 --- a/src/agentscope/studio/static/html-drag-components/agent-texttoimageagent.html +++ /dev/null @@ -1,28 +0,0 @@ -
-[28 deleted HTML lines: the drag-and-drop node template for TextToImageAgent; recoverable text: node title "TextToImageAgent", description "Agent for text to image generation", field "Node ID: ID_PLACEHOLDER"; the tag markup was lost in extraction]
\ No newline at end of file diff --git a/src/agentscope/studio/static/js/workstation.js b/src/agentscope/studio/static/js/workstation.js index 3323b55b1..2c35adcad 100644 --- a/src/agentscope/studio/static/js/workstation.js +++ b/src/agentscope/studio/static/js/workstation.js @@ -20,7 +20,6 @@ let nameToHtmlFile = { 'Message': 'message-msg.html', 'DialogAgent': 'agent-dialogagent.html', 'UserAgent': 'agent-useragent.html', - 'TextToImageAgent': 'agent-texttoimageagent.html', 'DictDialogAgent': 'agent-dictdialogagent.html', 'ReActAgent': 'agent-reactagent.html', 'Placeholder': 'pipeline-placeholder.html', @@ -605,22 +604,6 @@ async function addNodeToDrawFlow(name, pos_x, pos_y) { } break; - case 'TextToImageAgent': - const TextToImageAgentID = - editor.addNode('TextToImageAgent', 1, - 1, pos_x, pos_y, - 'TextToImageAgent', { - "args": { - "name": '', - "model_config_name": '' - } - }, htmlSourceCode); - var nodeElement = document.querySelector(`#node-${TextToImageAgentID} .node-id`); - if (nodeElement) { - nodeElement.textContent = TextToImageAgentID; - } - break; - case 'DictDialogAgent': const DictDialogAgentID = editor.addNode('DictDialogAgent', 1, 1, pos_x, pos_y, diff --git a/src/agentscope/studio/templates/workstation.html b/src/agentscope/studio/templates/workstation.html index cd1897f48..9685bccab 100644 --- a/src/agentscope/studio/templates/workstation.html +++ b/src/agentscope/studio/templates/workstation.html @@ -146,11 +146,6 @@ draggable="true" ondragstart="drag(event)"> UserAgent -
-[deleted HTML lines: the "TextToImageAgent" entry in the workstation drag-component sidebar; the tag markup was lost in extraction]
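The deprecation warning in the removed `TextToImageAgent` directs users to the `text_to_image` service together with `ReActAgent`. A minimal sketch of the direct service call (an illustration, not lines from the patch; the prompt and key are placeholders, and only the `prompt` and `api_key` keywords are assumed from the service signature):

```python
from agentscope.service.multi_modality.dashscope_services import (
    dashscope_text_to_image,
)

# Hypothetical prompt and key, for illustration only.
res = dashscope_text_to_image(
    prompt="A watercolor lighthouse at dawn",
    api_key="YOUR_DASHSCOPE_API_KEY",
)
# ServiceResponse: `status` is a ServiceExecStatus; `content` carries the
# generated image url(s) on success.
print(res.status, res.content)
```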
  • diff --git a/src/agentscope/test.py b/src/agentscope/test.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/agentscope/utils/common.py b/src/agentscope/utils/common.py index b7ebe3a15..372d9ca66 100644 --- a/src/agentscope/utils/common.py +++ b/src/agentscope/utils/common.py @@ -1,18 +1,25 @@ # -*- coding: utf-8 -*- """ Common utils.""" - +import base64 import contextlib +import datetime +import hashlib +import json import os +import random import re +import secrets import signal +import socket +import string import sys import tempfile import threading -from typing import Any, Generator, Optional, Union -import requests +from typing import Any, Generator, Optional, Union, Tuple, Literal, List +from urllib.parse import urlparse -from agentscope.service.service_response import ServiceResponse -from agentscope.service.service_status import ServiceExecStatus +import psutil +import requests @contextlib.contextmanager @@ -59,12 +66,12 @@ def create_tempdir() -> Generator: https://github.com/openai/human-eval/blob/master/human_eval/execution.py """ with tempfile.TemporaryDirectory() as dirname: - with chdir(dirname): + with _chdir(dirname): yield dirname @contextlib.contextmanager -def chdir(path: str) -> Generator: +def _chdir(path: str) -> Generator: """ A context manager that changes the current working directory to the given path. @@ -84,44 +91,7 @@ def chdir(path: str) -> Generator: os.chdir(cwd) -def write_file(content: str, file_path: str) -> ServiceResponse: - """ - Write content to a file. - - Args: - content (str): The content to be written to the file. - file_path (str): The path to the file where the content will be - written. - - Returns: - ServiceResponse: where the boolean indicates the success of the - operation, and the str contains an empty string if successful or an - error message if any, including the error type. - - This function attempts to open the file in write mode and write the - provided content to it. If the file does not exist, it will be created. - If the file exists, its content will be overwritten. If a - PermissionError occurs, indicating a lack of necessary permissions, - or an IOError occurs, signaling additional issues such as an invalid - file path or hardware-related I/O error, the function will catch the - exception and return `False` along with the error message. 
- """ - try: - with open(file_path, "w", encoding="utf-8") as file: - file.write(content) - return ServiceResponse( - status=ServiceExecStatus.SUCCESS, - content="Success", - ) - except Exception as e: - error_message = f"{e.__class__.__name__}: {e}" - return ServiceResponse( - status=ServiceExecStatus.ERROR, - content=error_message, - ) - - -def requests_get( +def _requests_get( url: str, params: dict, headers: Optional[dict] = None, @@ -178,3 +148,452 @@ def _if_change_database(sql_query: str) -> bool: if pattern_unsafe_sql.search(sql_query): return False return True + + +def _get_timestamp( + format_: str = "%Y-%m-%d %H:%M:%S", + time: datetime.datetime = None, +) -> str: + """Get current timestamp.""" + if time is None: + return datetime.datetime.now().strftime(format_) + else: + return time.strftime(format_) + + +def to_openai_dict(item: dict) -> dict: + """Convert `Msg` to `dict` for OpenAI API.""" + clean_dict = {} + + if "name" in item: + clean_dict["name"] = item["name"] + + if "role" in item: + clean_dict["role"] = item["role"] + else: + clean_dict["role"] = "assistant" + + if "content" in item: + clean_dict["content"] = _convert_to_str(item["content"]) + else: + raise ValueError("The content of the message is missing.") + + return clean_dict + + +def _find_available_port() -> int: + """Get an unoccupied socket port number.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _check_port(port: Optional[int] = None) -> int: + """Check if the port is available. + + Args: + port (`int`): + the port number being checked. + + Returns: + `int`: the port number that passed the check. If the port is found + to be occupied, an available port number will be automatically + returned. + """ + if port is None: + new_port = _find_available_port() + return new_port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + if s.connect_ex(("localhost", port)) == 0: + raise RuntimeError("Port is occupied.") + except Exception: + new_port = _find_available_port() + return new_port + return port + + +def _guess_type_by_extension( + url: str, +) -> Literal["image", "audio", "video", "file"]: + """Guess the type of the file by its extension.""" + extension = url.split(".")[-1].lower() + + if extension in [ + "bmp", + "dib", + "icns", + "ico", + "jfif", + "jpe", + "jpeg", + "jpg", + "j2c", + "j2k", + "jp2", + "jpc", + "jpf", + "jpx", + "apng", + "png", + "bw", + "rgb", + "rgba", + "sgi", + "tif", + "tiff", + "webp", + ]: + return "image" + elif extension in [ + "amr", + "wav", + "3gp", + "3gpp", + "aac", + "mp3", + "flac", + "ogg", + ]: + return "audio" + elif extension in [ + "mp4", + "webm", + "mkv", + "flv", + "avi", + "mov", + "wmv", + "rmvb", + ]: + return "video" + else: + return "file" + + +def _to_openai_image_url(url: str) -> str: + """Convert an image url to openai format. If the given url is a local + file, it will be converted to base64 format. Otherwise, it will be + returned directly. + + Args: + url (`str`): + The local or public url of the image. + """ + # See https://platform.openai.com/docs/guides/vision for details of + # support image extensions. 
+ support_image_extensions = ( + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ) + + parsed_url = urlparse(url) + + lower_url = url.lower() + + # Web url + if parsed_url.scheme != "": + if any(lower_url.endswith(_) for _ in support_image_extensions): + return url + + # Check if it is a local file + elif os.path.exists(url) and os.path.isfile(url): + if any(lower_url.endswith(_) for _ in support_image_extensions): + with open(url, "rb") as image_file: + base64_image = base64.b64encode(image_file.read()).decode( + "utf-8", + ) + extension = parsed_url.path.lower().split(".")[-1] + mime_type = f"image/{extension}" + return f"data:{mime_type};base64,{base64_image}" + + raise TypeError(f"{url} should be end with {support_image_extensions}.") + + +def _download_file(url: str, path_file: str, max_retries: int = 3) -> bool: + """Download file from the given url and save it to the given path. + + Args: + url (`str`): + The url of the file. + path_file (`str`): + The path to save the file. + max_retries (`int`, defaults to `3`) + The maximum number of retries when fail to download the file. + """ + for n_retry in range(1, max_retries + 1): + response = requests.get(url, stream=True) + if response.status_code == requests.codes.ok: + with open(path_file, "wb") as file: + for chunk in response.iter_content(1024): + file.write(chunk) + return True + else: + raise RuntimeError( + f"Failed to download file from {url} (status code: " + f"{response.status_code}). Retry {n_retry}/{max_retries}.", + ) + return False + + +def _generate_random_code( + length: int = 6, + uppercase: bool = True, + lowercase: bool = True, + digits: bool = True, +) -> str: + """Get random code.""" + characters = "" + if uppercase: + characters += string.ascii_uppercase + if lowercase: + characters += string.ascii_lowercase + if digits: + characters += string.digits + return "".join(secrets.choice(characters) for i in range(length)) + + +def _generate_id_from_seed(seed: str, length: int = 8) -> str: + """Generate random id from seed str. + + Args: + seed (`str`): seed string. + length (`int`): generated id length. + """ + hasher = hashlib.sha256() + hasher.update(seed.encode("utf-8")) + hash_digest = hasher.hexdigest() + + random.seed(hash_digest) + id_chars = [ + random.choice(string.ascii_letters + string.digits) + for _ in range(length) + ] + return "".join(id_chars) + + +def _is_web_url(url: str) -> bool: + """Whether the url is accessible from the Web. + + Args: + url (`str`): + The url to check. + + Note: + This function is not perfect, it only checks if the URL starts with + common web protocols, e.g., http, https, ftp, oss. + """ + parsed_url = urlparse(url) + return parsed_url.scheme in ["http", "https", "ftp", "oss"] + + +def _is_json_serializable(obj: Any) -> bool: + """Check if the given object is json serializable.""" + try: + json.dumps(obj) + return True + except TypeError: + return False + + +def _convert_to_str(content: Any) -> str: + """Convert the content to string. + + Note: + For prompt engineering, simply calling `str(content)` or + `json.dumps(content)` is not enough. + + - For `str(content)`, if `content` is a dictionary, it will turn double + quotes to single quotes. When this string is fed into prompt, the LLMs + may learn to use single quotes instead of double quotes (which + cannot be loaded by `json.loads` API). + + - For `json.dumps(content)`, if `content` is a string, it will add + double quotes to the string. 
LLMs may learn to use double quotes to + wrap strings, which leads to the same issue as `str(content)`. + + To avoid these issues, we use this function to safely convert the + content to a string used in prompt. + + Args: + content (`Any`): + The content to be converted. + + Returns: + `str`: The converted string. + """ + + if isinstance(content, str): + return content + elif isinstance(content, (dict, list, int, float, bool, tuple)): + return json.dumps(content, ensure_ascii=False) + else: + return str(content) + + +def _join_str_with_comma_and(elements: List[str]) -> str: + """Return the JSON string with comma, and use " and " between the last two + elements.""" + + if len(elements) == 0: + return "" + elif len(elements) == 1: + return elements[0] + elif len(elements) == 2: + return " and ".join(elements) + else: + return ", ".join(elements[:-1]) + f", and {elements[-1]}" + + +class ImportErrorReporter: + """Used as a placeholder for missing packages. + When called, an ImportError will be raised, prompting the user to install + the specified extras requirement. + """ + + def __init__(self, error: ImportError, extras_require: str = None) -> None: + """Init the ImportErrorReporter. + + Args: + error (`ImportError`): the original ImportError. + extras_require (`str`): the extras requirement. + """ + self.error = error + self.extras_require = extras_require + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self._raise_import_error() + + def __getattr__(self, name: str) -> Any: + return self._raise_import_error() + + def __getitem__(self, __key: Any) -> Any: + return self._raise_import_error() + + def _raise_import_error(self) -> Any: + """Raise the ImportError""" + err_msg = f"ImportError occorred: [{self.error.msg}]." + if self.extras_require is not None: + err_msg += ( + f" Please install [{self.extras_require}] version" + " of agentscope." + ) + raise ImportError(err_msg) + + +def _hash_string( + data: str, + hash_method: Literal["sha256", "md5", "sha1"], +) -> str: + """Hash the string data.""" + hash_func = getattr(hashlib, hash_method)() + hash_func.update(data.encode()) + return hash_func.hexdigest() + + +def _get_process_creation_time() -> datetime.datetime: + """Get the creation time of the process.""" + pid = os.getpid() + # Find the process by pid + current_process = psutil.Process(pid) + # Obtain the process creation time + create_time = current_process.create_time() + # Change the timestamp to a readable format + return datetime.datetime.fromtimestamp(create_time) + + +def _is_process_alive( + pid: int, + create_time_str: str, + create_time_format: str = "%Y-%m-%d %H:%M:%S", + tolerance_seconds: int = 10, +) -> bool: + """Check if the process is alive by comparing the actual creation time of + the process with the given creation time. + + Args: + pid (`int`): + The process id. + create_time_str (`str`): + The given creation time string. + create_time_format (`str`, defaults to `"%Y-%m-%d %H:%M:%S"`): + The format of the given creation time string. + tolerance_seconds (`int`, defaults to `10`): + The tolerance seconds for comparing the actual creation time with + the given creation time. + + Returns: + `bool`: True if the process is alive, False otherwise. 
+ """ + try: + # Try to create a process object by pid + proc = psutil.Process(pid) + # Obtain the actual creation time of the process + actual_create_time_timestamp = proc.create_time() + + # Convert the given creation time string to a datetime object + given_create_time_datetime = datetime.datetime.strptime( + create_time_str, + create_time_format, + ) + + # Calculate the time difference between the actual creation time and + time_difference = abs( + actual_create_time_timestamp + - given_create_time_datetime.timestamp(), + ) + + # Compare the actual creation time with the given creation time + if time_difference <= tolerance_seconds: + return True + + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): + # If the process is not found, access is denied, or the process is a + # zombie process, return False + return False + + return False + + +def _is_windows() -> bool: + """Check if the system is Windows.""" + return os.name == "nt" + + +def _map_string_to_color_mark( + target_str: str, +) -> Tuple[str, str]: + """Map a string into an index within a given length. + + Args: + target_str (`str`): + The string to be mapped. + + Returns: + `Tuple[str, str]`: A color marker tuple + """ + color_marks = [ + ("\033[90m", "\033[0m"), + ("\033[91m", "\033[0m"), + ("\033[92m", "\033[0m"), + ("\033[93m", "\033[0m"), + ("\033[94m", "\033[0m"), + ("\033[95m", "\033[0m"), + ("\033[96m", "\033[0m"), + ("\033[97m", "\033[0m"), + ] + + hash_value = int(hashlib.sha256(target_str.encode()).hexdigest(), 16) + index = hash_value % len(color_marks) + return color_marks[index] + + +def _generate_new_runtime_id() -> str: + """Generate a new random runtime id.""" + _RUNTIME_ID_FORMAT = "run_%Y%m%d-%H%M%S_{}" + return _get_timestamp(_RUNTIME_ID_FORMAT).format( + _generate_random_code(uppercase=False), + ) diff --git a/src/agentscope/utils/tools.py b/src/agentscope/utils/tools.py deleted file mode 100644 index ab060e44d..000000000 --- a/src/agentscope/utils/tools.py +++ /dev/null @@ -1,479 +0,0 @@ -# -*- coding: utf-8 -*- -""" Tools for agentscope """ -import base64 -import datetime -import json -import os.path -import secrets -import string -import socket -import hashlib -import random -from typing import Any, Literal, List, Optional, Tuple - -from urllib.parse import urlparse -import psutil -import requests - - -def _get_timestamp( - format_: str = "%Y-%m-%d %H:%M:%S", - time: datetime.datetime = None, -) -> str: - """Get current timestamp.""" - if time is None: - return datetime.datetime.now().strftime(format_) - else: - return time.strftime(format_) - - -def to_openai_dict(item: dict) -> dict: - """Convert `Msg` to `dict` for OpenAI API.""" - clean_dict = {} - - if "name" in item: - clean_dict["name"] = item["name"] - - if "role" in item: - clean_dict["role"] = item["role"] - else: - clean_dict["role"] = "assistant" - - if "content" in item: - clean_dict["content"] = _convert_to_str(item["content"]) - else: - raise ValueError("The content of the message is missing.") - - return clean_dict - - -def to_dialog_str(item: dict) -> str: - """Convert a dict into string prompt style.""" - speaker = item.get("name", None) or item.get("role", None) - content = item.get("content", None) - - if content is None: - return str(item) - - if speaker is None: - return content - else: - return f"{speaker}: {content}" - - -def find_available_port() -> int: - """Get an unoccupied socket port number.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - 
-
-def check_port(port: Optional[int] = None) -> int:
-    """Check if the port is available.
-
-    Args:
-        port (`int`):
-            the port number being checked.
-
-    Returns:
-        `int`: the port number that passed the check. If the port is found
-        to be occupied, an available port number will be automatically
-        returned.
-    """
-    if port is None:
-        new_port = find_available_port()
-        return new_port
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        try:
-            if s.connect_ex(("localhost", port)) == 0:
-                raise RuntimeError("Port is occupied.")
-        except Exception:
-            new_port = find_available_port()
-            return new_port
-    return port
-
-
-def _guess_type_by_extension(
-    url: str,
-) -> Literal["image", "audio", "video", "file"]:
-    """Guess the type of the file by its extension."""
-    extension = url.split(".")[-1].lower()
-
-    if extension in [
-        "bmp",
-        "dib",
-        "icns",
-        "ico",
-        "jfif",
-        "jpe",
-        "jpeg",
-        "jpg",
-        "j2c",
-        "j2k",
-        "jp2",
-        "jpc",
-        "jpf",
-        "jpx",
-        "apng",
-        "png",
-        "bw",
-        "rgb",
-        "rgba",
-        "sgi",
-        "tif",
-        "tiff",
-        "webp",
-    ]:
-        return "image"
-    elif extension in [
-        "amr",
-        "wav",
-        "3gp",
-        "3gpp",
-        "aac",
-        "mp3",
-        "flac",
-        "ogg",
-    ]:
-        return "audio"
-    elif extension in [
-        "mp4",
-        "webm",
-        "mkv",
-        "flv",
-        "avi",
-        "mov",
-        "wmv",
-        "rmvb",
-    ]:
-        return "video"
-    else:
-        return "file"
-
-
-def _to_openai_image_url(url: str) -> str:
-    """Convert an image url to openai format. If the given url is a local
-    file, it will be converted to base64 format. Otherwise, it will be
-    returned directly.
-
-    Args:
-        url (`str`):
-            The local or public url of the image.
-    """
-    # See https://platform.openai.com/docs/guides/vision for details of
-    # supported image extensions.
-    support_image_extensions = (
-        ".png",
-        ".jpg",
-        ".jpeg",
-        ".gif",
-        ".webp",
-    )
-
-    parsed_url = urlparse(url)
-
-    lower_url = url.lower()
-
-    # Web url
-    if parsed_url.scheme != "":
-        if any(lower_url.endswith(_) for _ in support_image_extensions):
-            return url
-
-    # Check if it is a local file
-    elif os.path.exists(url) and os.path.isfile(url):
-        if any(lower_url.endswith(_) for _ in support_image_extensions):
-            with open(url, "rb") as image_file:
-                base64_image = base64.b64encode(image_file.read()).decode(
-                    "utf-8",
-                )
-            extension = parsed_url.path.lower().split(".")[-1]
-            mime_type = f"image/{extension}"
-            return f"data:{mime_type};base64,{base64_image}"
-
-    raise TypeError(f"{url} should end with {support_image_extensions}.")
-
-
-def _download_file(url: str, path_file: str, max_retries: int = 3) -> bool:
-    """Download file from the given url and save it to the given path.
-
-    Args:
-        url (`str`):
-            The url of the file.
-        path_file (`str`):
-            The path to save the file.
-        max_retries (`int`, defaults to `3`):
-            The maximum number of retries when fail to download the file.
-    """
-    for n_retry in range(1, max_retries + 1):
-        response = requests.get(url, stream=True)
-        if response.status_code == requests.codes.ok:
-            with open(path_file, "wb") as file:
-                for chunk in response.iter_content(1024):
-                    file.write(chunk)
-            return True
-        else:
-            raise RuntimeError(
-                f"Failed to download file from {url} (status code: "
-                f"{response.status_code}). Retry {n_retry}/{max_retries}.",
-            )
-    return False
-
-
-def _generate_random_code(
-    length: int = 6,
-    uppercase: bool = True,
-    lowercase: bool = True,
-    digits: bool = True,
-) -> str:
-    """Get random code."""
-    characters = ""
-    if uppercase:
-        characters += string.ascii_uppercase
-    if lowercase:
-        characters += string.ascii_lowercase
-    if digits:
-        characters += string.digits
-    return "".join(secrets.choice(characters) for i in range(length))
-
-
-def generate_id_from_seed(seed: str, length: int = 8) -> str:
-    """Generate random id from seed str.
-
-    Args:
-        seed (`str`): seed string.
-        length (`int`): generated id length.
-    """
-    hasher = hashlib.sha256()
-    hasher.update(seed.encode("utf-8"))
-    hash_digest = hasher.hexdigest()
-
-    random.seed(hash_digest)
-    id_chars = [
-        random.choice(string.ascii_letters + string.digits)
-        for _ in range(length)
-    ]
-    return "".join(id_chars)
-
-
-def _is_web_url(url: str) -> bool:
-    """Whether the url is accessible from the Web.
-
-    Args:
-        url (`str`):
-            The url to check.
-
-    Note:
-        This function is not perfect; it only checks if the URL starts with
-        common web protocols, e.g., http, https, ftp, oss.
-    """
-    parsed_url = urlparse(url)
-    return parsed_url.scheme in ["http", "https", "ftp", "oss"]
-
-
-def _is_json_serializable(obj: Any) -> bool:
-    """Check if the given object is json serializable."""
-    try:
-        json.dumps(obj)
-        return True
-    except TypeError:
-        return False
-
-
-def _convert_to_str(content: Any) -> str:
-    """Convert the content to string.
-
-    Note:
-        For prompt engineering, simply calling `str(content)` or
-        `json.dumps(content)` is not enough.
-
-        - For `str(content)`, if `content` is a dictionary, it will turn
-        double quotes to single quotes. When this string is fed into the
-        prompt, the LLMs may learn to use single quotes instead of double
-        quotes (which cannot be loaded by the `json.loads` API).
-
-        - For `json.dumps(content)`, if `content` is a string, it will add
-        double quotes to the string. LLMs may learn to use double quotes to
-        wrap strings, which leads to the same issue as `str(content)`.
-
-        To avoid these issues, we use this function to safely convert the
-        content to a string used in prompt.
-
-    Args:
-        content (`Any`):
-            The content to be converted.
-
-    Returns:
-        `str`: The converted string.
-    """
-
-    if isinstance(content, str):
-        return content
-    elif isinstance(content, (dict, list, int, float, bool, tuple)):
-        return json.dumps(content, ensure_ascii=False)
-    else:
-        return str(content)
-
-
-def _join_str_with_comma_and(elements: List[str]) -> str:
-    """Join the strings with commas, using " and " between the last two
-    elements."""
-
-    if len(elements) == 0:
-        return ""
-    elif len(elements) == 1:
-        return elements[0]
-    elif len(elements) == 2:
-        return " and ".join(elements)
-    else:
-        return ", ".join(elements[:-1]) + f", and {elements[-1]}"
-
-
-class ImportErrorReporter:
-    """Used as a placeholder for missing packages.
-    When called, an ImportError will be raised, prompting the user to install
-    the specified extras requirement.
-    """
-
-    def __init__(self, error: ImportError, extras_require: str = None) -> None:
-        """Init the ImportErrorReporter.
-
-        Args:
-            error (`ImportError`): the original ImportError.
-            extras_require (`str`): the extras requirement.
-        """
-        self.error = error
-        self.extras_require = extras_require
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self._raise_import_error()
-
-    def __getattr__(self, name: str) -> Any:
-        return self._raise_import_error()
-
-    def __getitem__(self, __key: Any) -> Any:
-        return self._raise_import_error()
-
-    def _raise_import_error(self) -> Any:
-        """Raise the ImportError"""
-        err_msg = f"ImportError occurred: [{self.error.msg}]."
-        if self.extras_require is not None:
-            err_msg += (
-                f" Please install [{self.extras_require}] version"
-                " of agentscope."
-            )
-        raise ImportError(err_msg)
-
-
-def _hash_string(
-    data: str,
-    hash_method: Literal["sha256", "md5", "sha1"],
-) -> str:
-    """Hash the string data."""
-    hash_func = getattr(hashlib, hash_method)()
-    hash_func.update(data.encode())
-    return hash_func.hexdigest()
-
-
-def _get_process_creation_time() -> datetime.datetime:
-    """Get the creation time of the process."""
-    pid = os.getpid()
-    # Find the process by pid
-    current_process = psutil.Process(pid)
-    # Obtain the process creation time
-    create_time = current_process.create_time()
-    # Change the timestamp to a readable format
-    return datetime.datetime.fromtimestamp(create_time)
-
-
-def _is_process_alive(
-    pid: int,
-    create_time_str: str,
-    create_time_format: str = "%Y-%m-%d %H:%M:%S",
-    tolerance_seconds: int = 10,
-) -> bool:
-    """Check if the process is alive by comparing the actual creation time of
-    the process with the given creation time.
-
-    Args:
-        pid (`int`):
-            The process id.
-        create_time_str (`str`):
-            The given creation time string.
-        create_time_format (`str`, defaults to `"%Y-%m-%d %H:%M:%S"`):
-            The format of the given creation time string.
-        tolerance_seconds (`int`, defaults to `10`):
-            The tolerance seconds for comparing the actual creation time with
-            the given creation time.
-
-    Returns:
-        `bool`: True if the process is alive, False otherwise.
-    """
-    try:
-        # Try to create a process object by pid
-        proc = psutil.Process(pid)
-        # Obtain the actual creation time of the process
-        actual_create_time_timestamp = proc.create_time()
-
-        # Convert the given creation time string to a datetime object
-        given_create_time_datetime = datetime.datetime.strptime(
-            create_time_str,
-            create_time_format,
-        )
-
-        # Calculate the time difference between the actual creation time
-        # and the given creation time
-        time_difference = abs(
-            actual_create_time_timestamp
-            - given_create_time_datetime.timestamp(),
-        )
-
-        # Compare the actual creation time with the given creation time
-        if time_difference <= tolerance_seconds:
-            return True
-
-    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
-        # If the process is not found, access is denied, or the process is a
-        # zombie process, return False
-        return False
-
-    return False
-
-
-def _is_windows() -> bool:
-    """Check if the system is Windows."""
-    return os.name == "nt"
-
-
-def _map_string_to_color_mark(
-    target_str: str,
-) -> Tuple[str, str]:
-    """Map a string to a color mark tuple by hashing it into an index of a
-    fixed ANSI color palette.
-
-    Args:
-        target_str (`str`):
-            The string to be mapped.
-
-    Returns:
-        `Tuple[str, str]`: A color marker tuple
-    """
-    color_marks = [
-        ("\033[90m", "\033[0m"),
-        ("\033[91m", "\033[0m"),
-        ("\033[92m", "\033[0m"),
-        ("\033[93m", "\033[0m"),
-        ("\033[94m", "\033[0m"),
-        ("\033[95m", "\033[0m"),
-        ("\033[96m", "\033[0m"),
-        ("\033[97m", "\033[0m"),
-    ]
-
-    hash_value = int(hashlib.sha256(target_str.encode()).hexdigest(), 16)
-    index = hash_value % len(color_marks)
-    return color_marks[index]
-
-
-def _generate_new_runtime_id() -> str:
-    """Generate a new random runtime id."""
-    _RUNTIME_ID_FORMAT = "run_%Y%m%d-%H%M%S_{}"
-    return _get_timestamp(_RUNTIME_ID_FORMAT).format(
-        _generate_random_code(uppercase=False),
-    )
diff --git a/src/agentscope/web/workstation/workflow_node.py b/src/agentscope/web/workstation/workflow_node.py
index 183f67883..337c97efe 100644
--- a/src/agentscope/web/workstation/workflow_node.py
+++ b/src/agentscope/web/workstation/workflow_node.py
@@ -9,7 +9,6 @@
 from agentscope.agents import (
     DialogAgent,
     UserAgent,
-    TextToImageAgent,
     DictDialogAgent,
     ReActAgent,
 )
@@ -220,36 +219,6 @@ def compile(self) -> dict:
         }
 
 
-class TextToImageAgentNode(WorkflowNode):
-    """
-    A node representing a TextToImageAgent within a workflow.
-    """
-
-    node_type = WorkflowNodeType.AGENT
-
-    def __init__(
-        self,
-        node_id: str,
-        opt_kwargs: dict,
-        source_kwargs: dict,
-        dep_opts: list,
-    ) -> None:
-        super().__init__(node_id, opt_kwargs, source_kwargs, dep_opts)
-        self.pipeline = TextToImageAgent(**self.opt_kwargs)
-
-    def __call__(self, x: dict = None) -> dict:
-        return self.pipeline(x)
-
-    def compile(self) -> dict:
-        return {
-            "imports": "from agentscope.agents import TextToImageAgent",
-            "inits": f"{self.var_name} = TextToImageAgent("
-            f"{kwarg_converter(self.opt_kwargs)})",
-            "execs": f"{DEFAULT_FLOW_VAR} = {self.var_name}"
-            f"({DEFAULT_FLOW_VAR})",
-        }
-
-
 class DictDialogAgentNode(WorkflowNode):
     """
     A node representing a DictDialogAgent within a workflow.
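For context on the class removed above: every workflow node's `compile()` returns `imports` / `inits` / `execs` code strings, and the workstation stitches these fragments into a runnable script. A rough sketch of that assembly, mirroring the dict shape of the removed `TextToImageAgentNode.compile()` (the assembler and the concrete values below are illustrative assumptions, not code from this diff):

```python
# Hypothetical node output, mirroring the compile() dict shape above.
node = {
    "imports": "from agentscope.agents import DialogAgent",
    "inits": 'agent_0 = DialogAgent(name="assistant", '
    'sys_prompt="You are helpful.", model_config_name="gpt-4")',
    "execs": "flow = agent_0(flow)",
}

# Concatenating the three sections yields a minimal generated script.
script = "\n".join([node["imports"], node["inits"], node["execs"]])
print(script)
```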
@@ -840,7 +809,6 @@ def compile(self) -> dict:
     "Message": MsgNode,
     "DialogAgent": DialogAgentNode,
     "UserAgent": UserAgentNode,
-    "TextToImageAgent": TextToImageAgentNode,
     "DictDialogAgent": DictDialogAgentNode,
     "ReActAgent": ReActAgentNode,
     "Placeholder": PlaceHolderNode,
diff --git a/tests/openai_services_test.py b/tests/openai_services_test.py
index 997b5fa6e..d875fc3b1 100644
--- a/tests/openai_services_test.py
+++ b/tests/openai_services_test.py
@@ -4,7 +4,6 @@
 from unittest.mock import patch, MagicMock, mock_open
 import os
 import shutil
-from openai._types import NOT_GIVEN
 
 from agentscope.manager import ASManager
 from agentscope.service.multi_modality.openai_services import (
@@ -177,7 +176,7 @@ def test_openai_text_to_image_service_error(
         # Ensure _download_file is not called in case of service error
         mock_download_file.assert_not_called()
 
-    @patch("agentscope.service.multi_modality.openai_services.OpenAI")
+    @patch("openai.OpenAI")
     @patch(
         "builtins.open",
         new_callable=mock_open,
@@ -212,7 +211,7 @@ def test_openai_audio_to_text_success(
             {"transcription": "This is a test transcription."},
         )
 
-    @patch("agentscope.service.multi_modality.openai_services.OpenAI")
+    @patch("openai.OpenAI")
     @patch("builtins.open", new_callable=mock_open)
     def test_openai_audio_to_text_error(
         self,
@@ -238,7 +237,7 @@ def test_openai_audio_to_text_error(
             result.content,
         )
 
-    @patch("agentscope.service.multi_modality.openai_services.OpenAI")
+    @patch("openai.OpenAI")
    def test_successful_audio_generation(self, mock_openai: MagicMock) -> None:
        """Test the openai_text_to_audio function with a valid text."""
        # Mocking the OpenAI API response
@@ -264,7 +263,7 @@ def test_successful_audio_generation(self, mock_openai: MagicMock) -> None:
             expected_audio_path,
         ) # Check file save
 
-    @patch("agentscope.service.multi_modality.openai_services.OpenAI")
+    @patch("openai.OpenAI")
     def test_api_error_text_to_audio(self, mock_openai: MagicMock) -> None:
         """Test the openai_text_to_audio function with an API error."""
         # Mocking an OpenAI API error
@@ -352,7 +351,7 @@ def test_openai_image_to_text_error(
         self.assertEqual(result.status, ServiceExecStatus.ERROR)
         self.assertEqual(result.content, "API Error")
 
-    @patch("agentscope.service.multi_modality.openai_services.OpenAI")
+    @patch("openai.OpenAI")
     @patch("agentscope.service.multi_modality.openai_services._parse_url")
     @patch(
         (
@@ -411,9 +410,12 @@ def test_openai_edit_image_success(
         )
 
         # Check if _handle_openai_img_response was called
-        mock_handle_response.assert_called_once_with(mock_response, None)
+        mock_handle_response.assert_called_once_with(
+            mock_response.model_dump(),
+            None,
+        )
 
-    @patch("agentscope.service.multi_modality.openai_services.OpenAI")
+    @patch("openai.OpenAI")
     @patch("agentscope.service.multi_modality.openai_services._parse_url")
     def test_openai_edit_image_error(
         self,
@@ -444,13 +446,12 @@ def test_openai_edit_image_error(
         mock_client.images.edit.assert_called_once_with(
             model="dall-e-2",
             image="parsed_original_image.png",
-            mask=NOT_GIVEN,
             prompt="Add a sun to the sky",
             n=1,
             size="256x256",
         )
 
-    @patch("agentscope.service.multi_modality.openai_services.OpenAI")
+    @patch("openai.OpenAI")
     @patch("agentscope.service.multi_modality.openai_services._parse_url")
     @patch(
         (
@@ -464,7 +465,7 @@ def test_openai_create_image_variation_success(
         self,
         mock_parse_url: MagicMock,
         mock_openai: MagicMock,
     ) -> None:
-        """Test the openai_create_image_variation swith a valid image URL."""
+        """Test the openai_create_image_variation with a valid image URL."""
         # Mock OpenAI client
         mock_client = MagicMock()
         mock_openai.return_value = mock_client
@@ -505,9 +506,12 @@
         )
 
         # Check if _handle_openai_img_response was called
-        mock_handle_response.assert_called_once_with(mock_response, None)
+        mock_handle_response.assert_called_once_with(
+            mock_response.model_dump(),
+            None,
+        )
 
-    @patch("agentscope.service.multi_modality.openai_services.OpenAI")
+    @patch("openai.OpenAI")
     @patch("agentscope.service.multi_modality.openai_services._parse_url")
     def test_openai_create_image_variation_error(
         self,
diff --git a/tests/prompt_engine_test.py b/tests/prompt_engine_test.py
deleted file mode 100644
index 046ef40ed..000000000
--- a/tests/prompt_engine_test.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# -*- coding: utf-8 -*-
-"""Unit test for prompt engine."""
-import unittest
-from typing import Any
-
-import agentscope
-from agentscope.manager import ModelManager
-from agentscope.models import ModelResponse
-from agentscope.models import OpenAIWrapperBase
-from agentscope.prompt import PromptEngine
-
-
-class PromptEngineTest(unittest.TestCase):
-    """Unit test for prompt engine."""
-
-    def setUp(self) -> None:
-        """Init for PromptEngineTest."""
-        self.name = "white"
-        self.sys_prompt = (
-            "You're a player in a chess game, and you are playing {name}."
-        )
-        self.dialog_history = [
-            {"name": "white player", "content": "Move to E4."},
-            {"name": "black player", "content": "Okay, I moved to F4."},
-            {"name": "white player", "content": "Move to F5."},
-        ]
-        self.hint = "Now decide your next move."
-        self.prefix = "{name} player: "
-
-        agentscope.init(
-            model_configs=[
-                {
-                    "model_type": "post_api",
-                    "config_name": "open-source",
-                    "api_url": "http://xxx",
-                    "headers": {"Authorization": "Bearer {API_TOKEN}"},
-                    "parameters": {
-                        "temperature": 0.5,
-                    },
-                },
-                {
-                    "model_type": "openai_chat",
-                    "config_name": "gpt-4",
-                    "model_name": "gpt-4",
-                    "api_key": "xxx",
-                    "organization": "xxx",
-                },
-            ],
-            disable_saving=True,
-        )
-
-    def test_list_prompt(self) -> None:
-        """Test for list prompt."""
-
-        class TestModelWrapperBase(OpenAIWrapperBase):
-            """Test model wrapper."""
-
-            def __init__(self) -> None:
-                self.max_length = 1000
-
-            def __call__(
-                self,
-                *args: Any,
-                **kwargs: Any,
-            ) -> ModelResponse:
-                return ModelResponse(text="")
-
-            def _register_default_metrics(self) -> None:
-                pass
-
-        model = TestModelWrapperBase()
-        engine = PromptEngine(model)
-
-        prompt = engine.join(
-            self.sys_prompt,
-            self.dialog_history,
-            self.hint,
-            format_map={"name": self.name},
-        )
-
-        self.assertEqual(
-            [
-                {
-                    "role": "assistant",
-                    "content": "You're a player in a chess game, and you are "
-                    "playing white.",
-                },
-                {
-                    "name": "white player",
-                    "role": "assistant",
-                    "content": "Move to E4.",
-                },
-                {
-                    "name": "black player",
-                    "role": "assistant",
-                    "content": "Okay, I moved to F4.",
-                },
-                {
-                    "name": "white player",
-                    "role": "assistant",
-                    "content": "Move to F5.",
-                },
-                {
-                    "role": "assistant",
-                    "content": "Now decide your next move.",
-                },
-            ],
-            prompt,
-        )
-
-    def test_str_prompt(self) -> None:
-        """Test for string prompt."""
-        model_manager = ModelManager.get_instance()
-        model = model_manager.get_model_by_config_name("open-source")
-        engine = PromptEngine(model)
-
-        prompt = engine.join(
-            self.sys_prompt,
-            self.dialog_history,
-            self.hint,
-            self.prefix,
-            format_map={"name": self.name},
-        )
-
-        self.assertEqual(
-            """You're a player in a chess game, and you are playing white.
-white player: Move to E4.
-black player: Okay, I moved to F4.
-white player: Move to F5.
-Now decide your next move.
-white player: """,
-            prompt,
-        )
-
-
-if __name__ == "__main__":
-    unittest.main()
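With `PromptEngine` and its unit test gone, prompt construction is expected to go through each model wrapper's own `format` method instead. A rough sketch of the replacement pattern; the wrapper arguments below are placeholders and the exact `Msg` signature may vary by version:

```python
from agentscope.message import Msg
from agentscope.models import OpenAIChatWrapper

# Placeholder config values; substitute a real model name and API key.
model = OpenAIChatWrapper(
    config_name="demo",
    model_name="gpt-4",
    api_key="xxx",
)

# format() builds the model-specific prompt (a list of dicts for OpenAI).
prompt = model.format(
    Msg("system", "You're a helpful assistant.", role="system"),
    Msg("user", "What's the weather like today?", role="user"),
)
```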