diff --git a/README.md b/README.md index 9472ba8..0bbf608 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ class EchoApplication(ChatCompletion): # Generate response with a single choice with response.create_single_choice() as choice: # Fill the content of the response with the last user's content - choice.append_content(last_user_message.content or "") + choice.append_content(last_user_message.text) # DIALApp extends FastAPI to provide a user-friendly interface for routing requests to your applications diff --git a/aidial_sdk/chat_completion/request.py b/aidial_sdk/chat_completion/request.py index fe77596..a3821a1 100644 --- a/aidial_sdk/chat_completion/request.py +++ b/aidial_sdk/chat_completion/request.py @@ -1,8 +1,11 @@ from enum import Enum from typing import Any, Dict, List, Literal, Mapping, Optional, Union +from typing_extensions import assert_never + from aidial_sdk.chat_completion.enums import Status from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin +from aidial_sdk.exceptions import InvalidRequestError from aidial_sdk.pydantic_v1 import ( ConstrainedFloat, ConstrainedInt, @@ -85,6 +88,20 @@ class Message(ExtraForbidModel): tool_call_id: Optional[StrictStr] = None function_call: Optional[FunctionCall] = None + @property + def text(self) -> str: + def _error_message(actual: str) -> str: + return f"Unable to retrieve text content of the message: the actual content is {actual}." + + if self.content is None: + raise InvalidRequestError(_error_message("null or missing")) + elif isinstance(self.content, str): + return self.content + elif isinstance(self.content, list): + raise InvalidRequestError(_error_message("a list of content parts")) + else: + assert_never(self.content) + class Addon(ExtraForbidModel): name: Optional[StrictStr] = None diff --git a/examples/echo/app.py b/examples/echo/app.py index 8304e25..efd55d6 100644 --- a/examples/echo/app.py +++ b/examples/echo/app.py @@ -20,11 +20,7 @@ async def chat_completion( # Generate response with a single choice with response.create_single_choice() as choice: # Fill the content of the response with the last user's content - choice.append_content( - last_message.content - if isinstance(last_message.content, str) - else "" - ) + choice.append_content(last_message.text) if last_message.custom_content is not None: for attachment in last_message.custom_content.attachments or []: diff --git a/examples/langchain_rag/app.py b/examples/langchain_rag/app.py index ee2a979..75ce290 100644 --- a/examples/langchain_rag/app.py +++ b/examples/langchain_rag/app.py @@ -61,7 +61,7 @@ async def chat_completion( with response.create_single_choice() as choice: message = request.messages[-1] - user_query = message.content or "" + user_query = message.text file_url = get_last_attachment_url(request.messages) file_abs_url = urljoin(f"{DIAL_URL}/v1/", file_url) diff --git a/examples/render_text/app/main.py b/examples/render_text/app/main.py index 88afa89..f7349ee 100644 --- a/examples/render_text/app/main.py +++ b/examples/render_text/app/main.py @@ -23,13 +23,12 @@ async def chat_completion(self, request: Request, response: Response): # Create a single choice with response.create_single_choice() as choice: # Get the last message content - content = request.messages[-1].content - content_text = content if isinstance(content, str) else "" + content = request.messages[-1].text # The image may be returned either as base64 string or as URL # The content specifies the mode of return: 'base64' or 'url' try: - command, text = content_text.split(",", 1) + command, text = content.split(",", 1) if command not in ["base64", "url"]: raise DIALException( message="The command must be either 'base64' or 'url'", diff --git a/tests/applications/broken_immediately.py b/tests/applications/broken_immediately.py index 46edea6..e6d8851 100644 --- a/tests/applications/broken_immediately.py +++ b/tests/applications/broken_immediately.py @@ -2,7 +2,6 @@ from aidial_sdk import HTTPException as DIALException from aidial_sdk.chat_completion import ChatCompletion, Request, Response -from tests.utils.request import get_message_text_content def raise_exception(exception_type: str): @@ -28,4 +27,4 @@ class BrokenApplication(ChatCompletion): async def chat_completion( self, request: Request, response: Response ) -> None: - raise_exception(get_message_text_content(request.messages[0])) + raise_exception(request.messages[0].text) diff --git a/tests/applications/broken_in_runtime.py b/tests/applications/broken_in_runtime.py index 62422db..4e22ef6 100644 --- a/tests/applications/broken_in_runtime.py +++ b/tests/applications/broken_in_runtime.py @@ -1,6 +1,5 @@ from aidial_sdk.chat_completion import ChatCompletion, Request, Response from tests.applications.broken_immediately import raise_exception -from tests.utils.request import get_message_text_content class RuntimeBrokenApplication(ChatCompletion): @@ -18,4 +17,4 @@ async def chat_completion( choice.append_content("Test content") await response.aflush() - raise_exception(get_message_text_content(request.messages[0])) + raise_exception(request.messages[0].text) diff --git a/tests/applications/echo.py b/tests/applications/echo.py index 390e47c..169eeb5 100644 --- a/tests/applications/echo.py +++ b/tests/applications/echo.py @@ -6,7 +6,6 @@ TruncatePromptRequest, TruncatePromptResponse, ) -from tests.utils.request import get_message_text_content from tests.utils.tokenization import ( default_truncate_prompt, make_batched_tokenize, @@ -28,7 +27,7 @@ async def chat_completion( response.set_response_id("test_id") response.set_created(0) - content = get_message_text_content(request.messages[-1]) + content = request.messages[-1].text with response.create_single_choice() as choice: choice.append_content(content) diff --git a/tests/test_errors.py b/tests/test_errors.py index a0caa94..37468cb 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -51,6 +51,28 @@ } }, ), + ( + None, + 400, + { + "error": { + "message": "Unable to retrieve text content of the message: the actual content is null or missing.", + "type": "invalid_request_error", + "code": "400", + } + }, + ), + ( + [{"type": "text", "text": "hello"}], + 400, + { + "error": { + "message": "Unable to retrieve text content of the message: the actual content is a list of content parts.", + "type": "invalid_request_error", + "code": "400", + } + }, + ), ] @@ -72,10 +94,8 @@ def test_error(type, response_status_code, response_content): headers={"Api-Key": "TEST_API_KEY"}, ) - assert ( - response.status_code == response_status_code - and response.json() == response_content - ) + assert response.status_code == response_status_code + assert response.json() == response_content @pytest.mark.parametrize( @@ -96,10 +116,8 @@ def test_streaming_error(type, response_status_code, response_content): headers={"Api-Key": "TEST_API_KEY"}, ) - assert ( - response.status_code == response_status_code - and response.json() == response_content - ) + assert response.status_code == response_status_code + assert response.json() == response_content @pytest.mark.parametrize( @@ -184,4 +202,5 @@ def test_no_api_key(): }, ) - assert response.status_code == 400 and response.json() == API_KEY_IS_MISSING + assert response.status_code == 400 + assert response.json() == API_KEY_IS_MISSING diff --git a/tests/utils/request.py b/tests/utils/request.py deleted file mode 100644 index 514f38c..0000000 --- a/tests/utils/request.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import List - -from aidial_sdk.chat_completion.request import Message, MessageContentTextPart - - -def get_message_text_content(message: Message) -> str: - texts: List[str] = [] - - content = message.content - - if isinstance(content, str): - return content - elif isinstance(content, list): - for part in content: - if isinstance(part, MessageContentTextPart): - texts.append(part.text) - - return "\n".join(texts) diff --git a/tests/utils/tokenization.py b/tests/utils/tokenization.py index 87bde2a..6fc7014 100644 --- a/tests/utils/tokenization.py +++ b/tests/utils/tokenization.py @@ -19,7 +19,6 @@ TruncatePromptResult, TruncatePromptSuccess, ) -from tests.utils.request import get_message_text_content def word_count_string(string: str) -> int: @@ -27,7 +26,7 @@ def word_count_string(string: str) -> int: def word_count_message(message: Message) -> int: - return word_count_string(get_message_text_content(message)) + return word_count_string(message.text) def word_count_request(request: ChatCompletionRequest) -> int: