diff --git a/README.md b/README.md index b10e828ae..c450685d8 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,7 @@ services and third-party model APIs. | ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_chat_template.json) | llama3, llama2, Mistral, ... | | | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_embedding_template.json) | llama2, Mistral, ... | | | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_generate_template.json) | llama2, Mistral, ... | +| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#litellm-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/litellm_chat_template.json) | [models supported by litellm](https://docs.litellm.ai/docs/)... | | Post Request based API | - | [`PostAPIModelWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#post-request-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/postapi_model_config_template.json) | - | **Supported Local Model Deployment** diff --git a/README_ZH.md b/README_ZH.md index 3320b445f..47ad48b9e 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -76,6 +76,7 @@ AgentScope提供了一系列`ModelWrapper`来支持本地模型服务和第三 | ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_chat_template.json) | llama3, llama2, Mistral, ... | | | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_embedding_template.json) | llama2, Mistral, ... | | | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#ollama-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/ollama_generate_template.json) | llama2, Mistral, ... | +| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#litellm-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/litellm_chat_template.json) | [models supported by litellm](https://docs.litellm.ai/docs/)... | | Post Request based API | - | [`PostAPIModelWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | [guidance](https://modelscope.github.io/agentscope/en/tutorial/203-model.html#post-request-api)
[template](https://github.com/modelscope/agentscope/blob/main/examples/model_configs_template/postapi_model_config_template.json) | - | **支持的本地模型部署** diff --git a/docs/sphinx_doc/en/source/tutorial/203-model.md b/docs/sphinx_doc/en/source/tutorial/203-model.md index 08ef18dc5..d6e153d0f 100644 --- a/docs/sphinx_doc/en/source/tutorial/203-model.md +++ b/docs/sphinx_doc/en/source/tutorial/203-model.md @@ -16,6 +16,7 @@ Currently, AgentScope supports the following model service APIs: - Gemini API, including chat and embedding. - ZhipuAI API, including chat and embedding. - Ollama API, including chat, embedding and generation. +- LiteLLM API, including chat, with various model APIs. - Post Request API, model inference services based on Post requests, including Huggingface/ModelScope Inference API and various post request based model APIs. @@ -87,6 +88,7 @@ In the current AgentScope, the supported `model_type` types, the corresponding | ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_chat"` | llama2, ... | | | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_embedding"` | llama2, ... | | | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... | +| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - | | Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - | | | Chat | [`PostAPIChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... | @@ -440,6 +442,26 @@ Here we provide example configurations for different model wrappers.
+ +#### LiteLLM Chat API + +
+LiteLLM Chat API (agentscope.models.LiteLLMChatModelWrapper) + +```python +{ + "config_name": "lite_llm_openai_chat_gpt-3.5-turbo", + "model_type": "litellm_chat", + "model_name": "gpt-3.5-turbo" # You should note that for different models, you should set the corresponding environment variables, such as OPENAI_API_KEY, etc. You may refer to https://docs.litellm.ai/docs/ for this. +}, +``` + +
+ +
+ + #### Post Request Chat API
diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md b/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md index 0528abae8..7b912cbf2 100644 --- a/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md +++ b/docs/sphinx_doc/zh_CN/source/tutorial/203-model.md @@ -13,6 +13,7 @@ AgentScope中,模型的部署和调用是通过`ModelWrapper`来解耦开的 - Gemini API,包括对话(Chat)和嵌入(Embedding)。 - ZhipuAi API,包括对话(Chat)和嵌入(Embedding)。 - Ollama API,包括对话(Chat),嵌入(Embedding)和生成(Generation)。 +- LiteLLM API, 包括对话(Chat), 支持各种模型的API. - Post请求API,基于Post请求实现的模型推理服务,包括Huggingface/ModelScope Inference API和各种符合Post请求格式的API。 @@ -107,6 +108,7 @@ API如下: | ollama | Chat | [`OllamaChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_chat"` | llama2, ... | | | Embedding | [`OllamaEmbeddingWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_embedding"` | llama2, ... | | | Generation | [`OllamaGenerationWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/ollama_model.py) | `"ollama_generate"` | llama2, ... | +| LiteLLM API | Chat | [`LiteLLMChatWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/litellm_model.py) | `"litellm_chat"` | - | | Post Request based API | - | [`PostAPIModelWrapperBase`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api"` | - | | | Chat | [`PostAPIChatModelWrapper`](https://github.com/modelscope/agentscope/blob/main/src/agentscope/models/post_model.py) | `"post_api_chat"` | meta-llama/Meta-Llama-3-8B-Instruct, ... | @@ -435,6 +437,26 @@ API如下:
+ +#### LiteLLM Chat API + +
+LiteLLM Chat API (agentscope.models.LiteLLMChatModelWrapper) + +```python +{ + "config_name": "lite_llm_openai_chat_gpt-3.5-turbo", + "model_type": "litellm_chat", + "model_name": "gpt-3.5-turbo" # You should note that for different models, you should set the corresponding environment variables, such as OPENAI_API_KEY, etc. You may refer to https://docs.litellm.ai/docs/ for this. +}, +``` + +
+ +
+ + #### Post Request API
diff --git a/examples/model_configs_template/litellm_chat_template.json b/examples/model_configs_template/litellm_chat_template.json new file mode 100644 index 000000000..f1711dca9 --- /dev/null +++ b/examples/model_configs_template/litellm_chat_template.json @@ -0,0 +1,11 @@ +[{ + "config_name": "lite_llm_openai_chat_gpt-3.5-turbo", + "model_type": "litellm_chat", + "model_name": "gpt-3.5-turbo" +}, +{ + "config_name": "lite_llm_claude3", + "model_type": "litellm_chat", + "model_name": "claude-3-opus-20240229" +} +] diff --git a/setup.py b/setup.py index 7dca2181b..2259f592f 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ "ollama>=0.1.7", "google-generativeai>=0.4.0", "zhipuai", + "litellm", ] distribute_requires = minimal_requires + rpc_requires diff --git a/src/agentscope/_init.py b/src/agentscope/_init.py index 7d1f44d7b..dff68e585 100644 --- a/src/agentscope/_init.py +++ b/src/agentscope/_init.py @@ -25,7 +25,7 @@ def init( save_dir: str = _DEFAULT_DIR, save_log: bool = True, save_code: bool = True, - save_api_invoke: bool = True, + save_api_invoke: bool = False, use_monitor: bool = True, logger_level: LOG_LEVEL = _DEFAULT_LOG_LEVEL, runtime_id: Optional[str] = None, diff --git a/src/agentscope/models/__init__.py b/src/agentscope/models/__init__.py index 1e607c0e4..832829993 100644 --- a/src/agentscope/models/__init__.py +++ b/src/agentscope/models/__init__.py @@ -37,6 +37,9 @@ ZhipuAIChatWrapper, ZhipuAIEmbeddingWrapper, ) +from .litellm_model import ( + LiteLLMChatWrapper, +) __all__ = [ @@ -59,6 +62,7 @@ "GeminiEmbeddingWrapper", "ZhipuAIChatWrapper", "ZhipuAIEmbeddingWrapper", + "LiteLLMChatWrapper", "load_model_by_config_name", "read_model_configs", "clear_model_configs", diff --git a/src/agentscope/models/litellm_model.py b/src/agentscope/models/litellm_model.py new file mode 100644 index 000000000..242830a38 --- /dev/null +++ b/src/agentscope/models/litellm_model.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- +"""Model wrapper based on litellm https://docs.litellm.ai/docs/""" +from abc import ABC +from typing import Union, Any, List, Sequence + +from loguru import logger + +from .model import ModelWrapperBase, ModelResponse +from ..message import MessageBase +from ..utils.tools import _convert_to_str + +try: + import litellm +except ImportError: + litellm = None + + +class LiteLLMWrapperBase(ModelWrapperBase, ABC): + """The model wrapper based on LiteLLM API.""" + + def __init__( + self, + config_name: str, + model_name: str = None, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """ + To use the LiteLLM wrapper, environent variables must be set. + Different model_name could be using different environment variables. + For example: + - for model_name: "gpt-3.5-turbo", you need to set "OPENAI_API_KEY" + ``` + os.environ["OPENAI_API_KEY"] = "your-api-key" + ``` + - for model_name: "claude-2", you need to set "ANTHROPIC_API_KEY" + - for Azure OpenAI, you need to set "AZURE_API_KEY", + "AZURE_API_BASE", "AZURE_API_VERSION" + You should refer to the docs in https://docs.litellm.ai/docs/ . + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in OpenAI API. + generate_args (`dict`, default `None`): + The extra keyword arguments used in litellm api generation, + e.g. `temperature`, `seed`. + For generate_args, please refer to + https://docs.litellm.ai/docs/completion/input + for more detailes. + + """ + + if model_name is None: + model_name = config_name + logger.warning("model_name is not set, use config_name instead.") + + super().__init__(config_name=config_name) + + if litellm is None: + raise ImportError( + "Cannot import litellm package in current python environment." + "You should try:" + "1. Install litellm by `pip install litellm`" + "2. If you still have import error, you should try to " + "update the openai to higher version, e.g. " + "by runing `pip install openai==1.25.1", + ) + + self.model_name = model_name + self.generate_args = generate_args or {} + self._register_default_metrics() + + def format( + self, + *args: Union[MessageBase, Sequence[MessageBase]], + ) -> Union[List[dict], str]: + raise RuntimeError( + f"Model Wrapper [{type(self).__name__}] doesn't " + f"need to format the input. Please try to use the " + f"model wrapper directly.", + ) + + +class LiteLLMChatWrapper(LiteLLMWrapperBase): + """The model wrapper based on litellm chat API. + To use the LiteLLM wrapper, environent variables must be set. + Different model_name could be using different environment variables. + For example: + - for model_name: "gpt-3.5-turbo", you need to set "OPENAI_API_KEY" + ``` + os.environ["OPENAI_API_KEY"] = "your-api-key" + ``` + - for model_name: "claude-2", you need to set "ANTHROPIC_API_KEY" + - for Azure OpenAI, you need to set "AZURE_API_KEY", + "AZURE_API_BASE", "AZURE_API_VERSION" + You should refer to the docs in https://docs.litellm.ai/docs/ . + """ + + model_type: str = "litellm_chat" + + def _register_default_metrics(self) -> None: + # Set monitor accordingly + # TODO: set quota to the following metrics + self.monitor.register( + self._metric("call_counter"), + metric_unit="times", + ) + self.monitor.register( + self._metric("prompt_tokens"), + metric_unit="token", + ) + self.monitor.register( + self._metric("completion_tokens"), + metric_unit="token", + ) + self.monitor.register( + self._metric("total_tokens"), + metric_unit="token", + ) + + def __call__( + self, + messages: list, + **kwargs: Any, + ) -> ModelResponse: + """ + Args: + messages (`list`): + A list of messages to process. + **kwargs (`Any`): + The keyword arguments to litellm chat completions API, + e.g. `temperature`, `max_tokens`, `top_p`, etc. Please refer to + https://docs.litellm.ai/docs/completion/input + for more detailed arguments. + + Returns: + `ModelResponse`: + The response text in text field, and the raw response in + raw field. + """ + + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: checking messages + if not isinstance(messages, list): + raise ValueError( + "LiteLLM `messages` field expected type `list`, " + f"got `{type(messages)}` instead.", + ) + if not all("role" in msg and "content" in msg for msg in messages): + raise ValueError( + "Each message in the 'messages' list must contain a 'role' " + "and 'content' key for LiteLLM API.", + ) + + # step3: forward to generate response + response = litellm.completion( + model=self.model_name, + messages=messages, + **kwargs, + ) + + # step4: record the api invocation if needed + self._save_model_invocation( + arguments={ + "model": self.model_name, + "messages": messages, + **kwargs, + }, + response=response.model_dump(), + ) + + # step5: update monitor accordingly + self.update_monitor(call_counter=1, **response.usage.model_dump()) + + # step6: return response + return ModelResponse( + text=response.choices[0].message.content, + raw=response.model_dump(), + ) + + def format( + self, + *args: Union[MessageBase, Sequence[MessageBase]], + ) -> List[dict]: + """Format the input string and dictionary into the unified format. + Note that the format function might not be the optimal way to contruct + prompt for every model, but a common way to do so. + Developers are encouraged to implement their own prompt + engineering strategies if have strong performance concerns. + + Args: + args (`Union[MessageBase, Sequence[MessageBase]]`): + The input arguments to be formatted, where each argument + should be a `Msg` object, or a list of `Msg` objects. + In distribution, placeholder is also allowed. + Returns: + `List[dict]`: + The formatted messages in the format that anthropic Chat API + required. + """ + + # Parse all information into a list of messages + input_msgs = [] + for _ in args: + if _ is None: + continue + if isinstance(_, MessageBase): + input_msgs.append(_) + elif isinstance(_, list) and all( + isinstance(__, MessageBase) for __ in _ + ): + input_msgs.extend(_) + else: + raise TypeError( + f"The input should be a Msg object or a list " + f"of Msg objects, got {type(_)}.", + ) + + # record dialog history as a list of strings + system_content_template = [] + dialogue = [] + for i, unit in enumerate(input_msgs): + if i == 0 and unit.role == "system": + # system prompt + system_prompt = _convert_to_str(unit.content) + if not system_prompt.endswith("\n"): + system_prompt += "\n" + system_content_template.append(system_prompt) + else: + # Merge all messages into a dialogue history prompt + dialogue.append( + f"{unit.name}: {_convert_to_str(unit.content)}", + ) + + if len(dialogue) != 0: + system_content_template.extend( + ["## Dialogue History", "{dialogue_history}"], + ) + + dialogue_history = "\n".join(dialogue) + + system_content_template = "\n".join(system_content_template) + + messages = [ + { + "role": "user", + "content": system_content_template.format( + dialogue_history=dialogue_history, + ), + }, + ] + + return messages diff --git a/tests/format_test.py b/tests/format_test.py index 3226deab4..949e9a422 100644 --- a/tests/format_test.py +++ b/tests/format_test.py @@ -12,6 +12,7 @@ ZhipuAIChatWrapper, DashScopeChatWrapper, DashScopeMultiModalWrapper, + LiteLLMChatWrapper, ) @@ -204,6 +205,32 @@ def test_zhipuai_chat(self) -> None: with self.assertRaises(TypeError): model.format(*self.wrong_inputs) # type: ignore[arg-type] + def test_litellm_chat(self) -> None: + """Unit test for the format function in litellm chat api wrapper.""" + model = LiteLLMChatWrapper( + config_name="", + model_name="gpt-3.5-turbo", + api_key="xxx", + ) + + ground_truth = [ + { + "role": "user", + "content": ( + "You are a helpful assistant\n\n" + "## Dialogue History\nuser: What is the weather today?\n" + "assistant: It is sunny today" + ), + }, + ] + + prompt = model.format(*self.inputs) + self.assertListEqual(prompt, ground_truth) + + # wrong format + with self.assertRaises(TypeError): + model.format(*self.wrong_inputs) # type: ignore[arg-type] + def test_dashscope_multimodal_image(self) -> None: """Unit test for the format function in dashscope multimodal conversation api wrapper for image.""" diff --git a/tests/litellm_test.py b/tests/litellm_test.py new file mode 100644 index 000000000..3ee4a8503 --- /dev/null +++ b/tests/litellm_test.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +"""litellm test""" +import unittest +from unittest.mock import patch, MagicMock + +import agentscope +from agentscope.models import load_model_by_config_name + + +class TestLiteLLMChatWrapper(unittest.TestCase): + """Test LiteLLM Chat Wrapper""" + + def setUp(self) -> None: + self.api_key = "test_api_key.secret_key" + self.messages = [ + {"role": "user", "content": "Hello, litellm!"}, + {"role": "assistant", "content": "How can I assist you?"}, + ] + + @patch("agentscope.models.litellm_model.litellm") + def test_chat(self, mock_litellm: MagicMock) -> None: + """ + Test chat""" + mock_response = MagicMock() + mock_response.model_dump.return_value = { + "choices": [ + {"message": {"content": "Hello, this is a mocked response!"}}, + ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 5, + "total_tokens": 105, + }, + } + mock_response.choices[ + 0 + ].message.content = "Hello, this is a mocked response!" + + mock_litellm.completion.return_value = mock_response + + agentscope.init( + model_configs={ + "config_name": "test_config", + "model_type": "litellm_chat", + "model_name": "ollama/llama3:8b", + "api_key": self.api_key, + }, + ) + + model = load_model_by_config_name("test_config") + + response = model( + messages=self.messages, + api_base="http://localhost:11434", + ) + + self.assertEqual(response.text, "Hello, this is a mocked response!") + + +if __name__ == "__main__": + unittest.main()