Skip to content

Commit

Permalink
feat: expected type of the request user message content with a list o…
Browse files Browse the repository at this point in the history
…f content part
  • Loading branch information
adubovik committed Oct 1, 2024
1 parent 6196b88 commit c5c5d29
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 14 deletions.
20 changes: 19 additions & 1 deletion aidial_sdk/chat_completion/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,27 @@ class Role(str, Enum):
TOOL = "tool"


class ImageURL(ExtraForbidModel):
url: StrictStr
detail: Literal["auto", "low", "high"]


class MessageContentImagePart(ExtraForbidModel):
type: Literal["image_url"]
image_url: ImageURL


class MessageContentTextPart(ExtraForbidModel):
type: Literal["text"]
text: StrictStr


MessageContentPart = Union[MessageContentTextPart, MessageContentImagePart]


class Message(ExtraForbidModel):
role: Role
content: Optional[StrictStr] = None
content: Optional[Union[StrictStr, List[MessageContentPart]]] = None
custom_content: Optional[CustomContent] = None
name: Optional[StrictStr] = None
tool_calls: Optional[List[ToolCall]] = None
Expand Down
16 changes: 9 additions & 7 deletions examples/echo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@ async def chat_completion(
self, request: Request, response: Response
) -> None:
# Get last message (the newest) from the history
last_user_message = request.messages[-1]
last_message = request.messages[-1]

# Generate response with a single choice
with response.create_single_choice() as choice:
# Fill the content of the response with the last user's content
choice.append_content(last_user_message.content or "")

if last_user_message.custom_content is not None:
for attachment in (
last_user_message.custom_content.attachments or []
):
choice.append_content(
last_message.content
if isinstance(last_message.content, str)
else ""
)

if last_message.custom_content is not None:
for attachment in last_message.custom_content.attachments or []:
# Add the same attachment to the response
choice.add_attachment(**attachment.dict())

Expand Down
5 changes: 3 additions & 2 deletions examples/render_text/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ async def chat_completion(self, request: Request, response: Response):
# Create a single choice
with response.create_single_choice() as choice:
# Get the last message content
content = request.messages[-1].content or ""
content = request.messages[-1].content
content_text = content if isinstance(content, str) else ""

# The image may be returned either as base64 string or as URL
# The content specifies the mode of return: 'base64' or 'url'
try:
command, text = content.split(",", 1)
command, text = content_text.split(",", 1)
if command not in ["base64", "url"]:
raise DIALException(
message="The command must be either 'base64' or 'url'",
Expand Down
3 changes: 2 additions & 1 deletion tests/applications/broken_immediately.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from aidial_sdk import HTTPException as DIALException
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from tests.utils.request import get_message_text_content


def raise_exception(exception_type: str):
Expand All @@ -27,4 +28,4 @@ class BrokenApplication(ChatCompletion):
async def chat_completion(
self, request: Request, response: Response
) -> None:
raise_exception(request.messages[0].content or "")
raise_exception(get_message_text_content(request.messages[0]))
3 changes: 2 additions & 1 deletion tests/applications/broken_in_runtime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from tests.applications.broken_immediately import raise_exception
from tests.utils.request import get_message_text_content


class RuntimeBrokenApplication(ChatCompletion):
Expand All @@ -17,4 +18,4 @@ async def chat_completion(
choice.append_content("Test content")
await response.aflush()

raise_exception(request.messages[0].content or "")
raise_exception(get_message_text_content(request.messages[0]))
3 changes: 2 additions & 1 deletion tests/applications/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TruncatePromptRequest,
TruncatePromptResponse,
)
from tests.utils.request import get_message_text_content
from tests.utils.tokenization import (
default_truncate_prompt,
make_batched_tokenize,
Expand All @@ -27,7 +28,7 @@ async def chat_completion(
response.set_response_id("test_id")
response.set_created(0)

content = request.messages[-1].content or ""
content = get_message_text_content(request.messages[-1])

with response.create_single_choice() as choice:
choice.append_content(content)
Expand Down
18 changes: 18 additions & 0 deletions tests/utils/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import List

from aidial_sdk.chat_completion.request import Message, MessageContentTextPart


def get_message_text_content(message: Message) -> str:
texts: List[str] = []

content = message.content

if isinstance(content, str):
return content
elif isinstance(content, list):
for part in content:
if isinstance(part, MessageContentTextPart):
texts.append(part.text)

return "\n".join(texts)
3 changes: 2 additions & 1 deletion tests/utils/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
TruncatePromptResult,
TruncatePromptSuccess,
)
from tests.utils.request import get_message_text_content


def word_count_string(string: str) -> int:
return len(string.split())


def word_count_message(message: Message) -> int:
return word_count_string(message.content or "")
return word_count_string(get_message_text_content(message))


def word_count_request(request: ChatCompletionRequest) -> int:
Expand Down

0 comments on commit c5c5d29

Please sign in to comment.