Skip to content

Commit

Permalink
fix: review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Oct 9, 2024
1 parent 1a6ff87 commit cfe8d89
Show file tree
Hide file tree
Showing 11 changed files with 54 additions and 45 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class EchoApplication(ChatCompletion):
# Generate response with a single choice
with response.create_single_choice() as choice:
# Fill the content of the response with the last user's content
choice.append_content(last_user_message.content or "")
choice.append_content(last_user_message.text)


# DIALApp extends FastAPI to provide a user-friendly interface for routing requests to your applications
Expand Down
17 changes: 17 additions & 0 deletions aidial_sdk/chat_completion/request.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

from typing_extensions import assert_never

from aidial_sdk.chat_completion.enums import Status
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.pydantic_v1 import (
ConstrainedFloat,
ConstrainedInt,
Expand Down Expand Up @@ -85,6 +88,20 @@ class Message(ExtraForbidModel):
tool_call_id: Optional[StrictStr] = None
function_call: Optional[FunctionCall] = None

@property
def text(self) -> str:
def _error_message(actual: str) -> str:
return f"Unable to retrieve text content of the message: the actual content is {actual}."

if self.content is None:
raise InvalidRequestError(_error_message("null or missing"))
elif isinstance(self.content, str):
return self.content
elif isinstance(self.content, list):
raise InvalidRequestError(_error_message("a list of content parts"))
else:
assert_never(self.content)


class Addon(ExtraForbidModel):
name: Optional[StrictStr] = None
Expand Down
6 changes: 1 addition & 5 deletions examples/echo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@ async def chat_completion(
# Generate response with a single choice
with response.create_single_choice() as choice:
# Fill the content of the response with the last user's content
choice.append_content(
last_message.content
if isinstance(last_message.content, str)
else ""
)
choice.append_content(last_message.text)

if last_message.custom_content is not None:
for attachment in last_message.custom_content.attachments or []:
Expand Down
2 changes: 1 addition & 1 deletion examples/langchain_rag/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def chat_completion(

with response.create_single_choice() as choice:
message = request.messages[-1]
user_query = message.content or ""
user_query = message.text

file_url = get_last_attachment_url(request.messages)
file_abs_url = urljoin(f"{DIAL_URL}/v1/", file_url)
Expand Down
5 changes: 2 additions & 3 deletions examples/render_text/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ async def chat_completion(self, request: Request, response: Response):
# Create a single choice
with response.create_single_choice() as choice:
# Get the last message content
content = request.messages[-1].content
content_text = content if isinstance(content, str) else ""
content = request.messages[-1].text

# The image may be returned either as base64 string or as URL
# The content specifies the mode of return: 'base64' or 'url'
try:
command, text = content_text.split(",", 1)
command, text = content.split(",", 1)
if command not in ["base64", "url"]:
raise DIALException(
message="The command must be either 'base64' or 'url'",
Expand Down
3 changes: 1 addition & 2 deletions tests/applications/broken_immediately.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from aidial_sdk import HTTPException as DIALException
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from tests.utils.request import get_message_text_content


def raise_exception(exception_type: str):
Expand All @@ -28,4 +27,4 @@ class BrokenApplication(ChatCompletion):
async def chat_completion(
self, request: Request, response: Response
) -> None:
raise_exception(get_message_text_content(request.messages[0]))
raise_exception(request.messages[0].text)
3 changes: 1 addition & 2 deletions tests/applications/broken_in_runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from tests.applications.broken_immediately import raise_exception
from tests.utils.request import get_message_text_content


class RuntimeBrokenApplication(ChatCompletion):
Expand All @@ -18,4 +17,4 @@ async def chat_completion(
choice.append_content("Test content")
await response.aflush()

raise_exception(get_message_text_content(request.messages[0]))
raise_exception(request.messages[0].text)
3 changes: 1 addition & 2 deletions tests/applications/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
TruncatePromptRequest,
TruncatePromptResponse,
)
from tests.utils.request import get_message_text_content
from tests.utils.tokenization import (
default_truncate_prompt,
make_batched_tokenize,
Expand All @@ -28,7 +27,7 @@ async def chat_completion(
response.set_response_id("test_id")
response.set_created(0)

content = get_message_text_content(request.messages[-1])
content = request.messages[-1].text

with response.create_single_choice() as choice:
choice.append_content(content)
Expand Down
37 changes: 28 additions & 9 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@
}
},
),
(
None,
400,
{
"error": {
"message": "Unable to retrieve text content of the message: the actual content is null or missing.",
"type": "invalid_request_error",
"code": "400",
}
},
),
(
[{"type": "text", "text": "hello"}],
400,
{
"error": {
"message": "Unable to retrieve text content of the message: the actual content is a list of content parts.",
"type": "invalid_request_error",
"code": "400",
}
},
),
]


Expand All @@ -72,10 +94,8 @@ def test_error(type, response_status_code, response_content):
headers={"Api-Key": "TEST_API_KEY"},
)

assert (
response.status_code == response_status_code
and response.json() == response_content
)
assert response.status_code == response_status_code
assert response.json() == response_content


@pytest.mark.parametrize(
Expand All @@ -96,10 +116,8 @@ def test_streaming_error(type, response_status_code, response_content):
headers={"Api-Key": "TEST_API_KEY"},
)

assert (
response.status_code == response_status_code
and response.json() == response_content
)
assert response.status_code == response_status_code
assert response.json() == response_content


@pytest.mark.parametrize(
Expand Down Expand Up @@ -184,4 +202,5 @@ def test_no_api_key():
},
)

assert response.status_code == 400 and response.json() == API_KEY_IS_MISSING
assert response.status_code == 400
assert response.json() == API_KEY_IS_MISSING
18 changes: 0 additions & 18 deletions tests/utils/request.py

This file was deleted.

3 changes: 1 addition & 2 deletions tests/utils/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
TruncatePromptResult,
TruncatePromptSuccess,
)
from tests.utils.request import get_message_text_content


def word_count_string(string: str) -> int:
return len(string.split())


def word_count_message(message: Message) -> int:
return word_count_string(get_message_text_content(message))
return word_count_string(message.text)


def word_count_request(request: ChatCompletionRequest) -> int:
Expand Down

0 comments on commit cfe8d89

Please sign in to comment.