From 4293e64177cd69ac054809cc115bf101eb63ff90 Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Fri, 12 Jul 2024 10:46:27 +0100 Subject: [PATCH] feat: inherit all Enums from str to make JSON serialization possible --- aidial_sdk/chat_completion/enums.py | 4 ++-- aidial_sdk/chat_completion/request.py | 2 +- tests/test_serialization.py | 19 +++++++++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) create mode 100644 tests/test_serialization.py diff --git a/aidial_sdk/chat_completion/enums.py b/aidial_sdk/chat_completion/enums.py index 2ceeb30..481ffc4 100644 --- a/aidial_sdk/chat_completion/enums.py +++ b/aidial_sdk/chat_completion/enums.py @@ -1,7 +1,7 @@ from enum import Enum -class FinishReason(Enum): +class FinishReason(str, Enum): STOP = "stop" LENGTH = "length" FUNCTION_CALL = "function_call" @@ -9,6 +9,6 @@ class FinishReason(Enum): CONTENT_FILTER = "content_filter" -class Status(Enum): +class Status(str, Enum): COMPLETED = "completed" FAILED = "failed" diff --git a/aidial_sdk/chat_completion/request.py b/aidial_sdk/chat_completion/request.py index a7a4f54..706e94a 100644 --- a/aidial_sdk/chat_completion/request.py +++ b/aidial_sdk/chat_completion/request.py @@ -48,7 +48,7 @@ class ToolCall(ExtraForbidModel): function: FunctionCall -class Role(Enum): +class Role(str, Enum): SYSTEM = "system" USER = "user" ASSISTANT = "assistant" diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..3c3907d --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,19 @@ +import json + +from aidial_sdk.chat_completion import Message, Role + + +def test_message_ser(): + msg_obj = Message(role=Role.SYSTEM, content="test") + actual_dict = msg_obj.dict(exclude_none=True) + expected_dict = {"role": "system", "content": "test"} + + assert json.loads(json.dumps(actual_dict)) == expected_dict + + +def test_message_deser(): + msg_dict = {"role": "system", "content": "test"} + actual_obj = Message.parse_raw(json.dumps(msg_dict)) + expected_obj = Message(role=Role.SYSTEM, content="test") + + assert actual_obj == expected_obj