From 09937ae76fa91892433530c8977f25c4da7ac2fb Mon Sep 17 00:00:00 2001 From: askatasuna Date: Wed, 9 Oct 2024 12:40:39 +0300 Subject: [PATCH] Switched to unit tests --- tests/llm/test_llm.py | 1 + tests/slots/test_slot_partial_extraction.py | 178 +++++++----------- ...group_slots.py => 2_partial_extraction.py} | 4 +- 3 files changed, 69 insertions(+), 114 deletions(-) create mode 100644 tests/llm/test_llm.py rename tutorials/slots/{2_group_slots.py => 2_partial_extraction.py} (97%) diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/llm/test_llm.py @@ -0,0 +1 @@ + diff --git a/tests/slots/test_slot_partial_extraction.py b/tests/slots/test_slot_partial_extraction.py index f95efa3c7..b81c481b6 100644 --- a/tests/slots/test_slot_partial_extraction.py +++ b/tests/slots/test_slot_partial_extraction.py @@ -13,133 +13,87 @@ ) from chatsky.slots import RegexpSlot, GroupSlot - +from chatsky.slots.slots import SlotManager, ExtractedValueSlot, ExtractedGroupSlot +from chatsky.core import Message, Context from chatsky.utils.testing import ( check_happy_path, ) -SLOTS = { - "person": GroupSlot( - username=RegexpSlot( - regexp=r"username is ([a-zA-Z]+)", +import pytest + +test_slot = GroupSlot( + person=GroupSlot( + username=RegexpSlot( + regexp=r"([a-z]+_[a-z]+)", match_group_idx=1, ), - email=RegexpSlot( - regexp=r"email is ([a-z@\.A-Z]+)", + email=RegexpSlot( + regexp=r"([a-z]+@[a-z]+\.[a-z]+)", match_group_idx=1, ), + allow_partially_extracted=True + ) + ) + + +extracted_slot_values_turn_1 = { + "person.username": ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="test_name", default_value=None ), - "friend": GroupSlot( - first_name=RegexpSlot(regexp=r"^[A-Z][a-z]+?(?= )", default_value="default_name"), - last_name=RegexpSlot(regexp=r"(?<= )[A-Z][a-z]+", default_value="default_surname"), - allow_partially_extracted=True, + "person.email": ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="test@email.com", default_value=None ), } -script = { - GLOBAL: {TRANSITIONS: [Tr(dst=("username_flow", "ask"), cnd=cnd.Regexp(r"^[sS]tart"))]}, - "username_flow": { - LOCAL: { - PRE_TRANSITION: {"get_slot": proc.Extract("person.username")}, - TRANSITIONS: [ - Tr( - dst=("email_flow", "ask"), - cnd=cnd.SlotsExtracted("person.username"), - priority=1.2, - ), - Tr(dst=("username_flow", "repeat_question"), priority=0.8), - ], - }, - "ask": { - RESPONSE: "Write your username (my username is ...):", - }, - "repeat_question": { - RESPONSE: "Please, type your username again (my username is ...):", - }, - }, - "email_flow": { - LOCAL: { - PRE_TRANSITION: {"get_slot": proc.Extract("person.email")}, - TRANSITIONS: [ - Tr( - dst=("friend_flow", "ask"), - cnd=cnd.SlotsExtracted("person.username", "person.email"), - priority=1.2, - ), - Tr(dst=("email_flow", "repeat_question"), priority=0.8), - ], - }, - "ask": { - RESPONSE: "Write your email (my email is ...):", - }, - "repeat_question": { - RESPONSE: "Please, write your email again (my email is ...):", - }, - }, - "friend_flow": { - LOCAL: { - PRE_TRANSITION: {"get_slots": proc.Extract("friend", success_only=False)}, - TRANSITIONS: [ - Tr( - dst=("root", "utter"), - cnd=cnd.SlotsExtracted("friend.first_name", "friend.last_name", mode="any"), - priority=1.2, - ), - Tr(dst=("friend_flow", "repeat_question"), priority=0.8), - ], - }, - "ask": {RESPONSE: "Please, name me one of your friends: (John Doe)"}, - "repeat_question": {RESPONSE: "Please, name me one of your friends again: (John Doe)"}, - }, - "root": { - "start": { - TRANSITIONS: [Tr(dst=("username_flow", "ask"))], - }, - "fallback": { - RESPONSE: "Finishing query", - TRANSITIONS: [Tr(dst=("username_flow", "ask"))], - }, - "utter": { - RESPONSE: rsp.FilledTemplate("Your friend is {friend.first_name} {friend.last_name}"), - TRANSITIONS: [Tr(dst=("root", "utter_alternative"))], - }, - "utter_alternative": { - RESPONSE: "Your username is {person.username}. " "Your email is {person.email}.", - PRE_RESPONSE: {"fill": proc.FillTemplate()}, - }, - }, -} - -HAPPY_PATH = [ - ("hi", "Write your username (my username is ...):"), - ("my username is groot", "Write your email (my email is ...):"), - ( - "my email is groot@gmail.com", - "Please, name me one of your friends: (John Doe)", +extracted_slot_values_turn_2 = { + "person.username": ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="new_name", default_value=None ), - ("Bob Page", "Your friend is Bob Page"), - ("ok", "Your username is groot. Your email is groot@gmail.com."), - ("ok", "Finishing query"), - ("again", "Write your username (my username is ...):"), - ("my username is groot", "Write your email (my email is ...):"), - ( - "my email is groot@gmail.com", - "Please, name me one of your friends: (John Doe)", + "person.email": ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="test@email.com", default_value=None ), - ("Jim ", "Your friend is Jim Page"), - ("ok", "Your username is groot. Your email is groot@gmail.com."), - ("ok", "Finishing query"), -] +} -# %% -pipeline = Pipeline( - script=script, - start_label=("root", "start"), - fallback_label=("root", "fallback"), - slots=SLOTS, -) +@pytest.fixture(scope="function") +def context_with_request_1(context): + new_ctx = context.model_copy(deep=True) + new_ctx.add_request(Message(text="I am test_name. My email is test@email.com")) + return new_ctx + +@pytest.fixture(scope="function") +def context_with_request_2(context): + context.add_request(Message(text="I am new_name.")) + return context +@pytest.fixture(scope="function") +def empty_slot_manager(): + manager = SlotManager() + manager.set_root_slot(test_slot) + return manager -def test_happy_path(): - check_happy_path(pipeline, HAPPY_PATH, printout=True) # This is a function for automatic tutorial running +@pytest.mark.parametrize( + "slot_name,expected_slot_storage_1,expected_slot_storage_2", + [ + ( + "person", + ExtractedGroupSlot( + person=ExtractedGroupSlot( + username=extracted_slot_values_turn_1["person.username"], + email=extracted_slot_values_turn_1["person.email"], + ) + ), + ExtractedGroupSlot( + person=ExtractedGroupSlot( + username=extracted_slot_values_turn_2["person.username"], + email=extracted_slot_values_turn_2["person.email"], + ) + ), + ), + ], +) +async def test_slot_extraction(slot_name, expected_slot_storage_1, expected_slot_storage_2, empty_slot_manager, context_with_request_1, context_with_request_2): + await empty_slot_manager.extract_slot(slot_name, context_with_request_1, success_only=False) + assert empty_slot_manager.slot_storage == expected_slot_storage_1 + await empty_slot_manager.extract_slot(slot_name, context_with_request_2, success_only=False) + assert empty_slot_manager.slot_storage == expected_slot_storage_2 \ No newline at end of file diff --git a/tutorials/slots/2_group_slots.py b/tutorials/slots/2_partial_extraction.py similarity index 97% rename from tutorials/slots/2_group_slots.py rename to tutorials/slots/2_partial_extraction.py index f2715eb78..92de9022f 100644 --- a/tutorials/slots/2_group_slots.py +++ b/tutorials/slots/2_partial_extraction.py @@ -1,6 +1,6 @@ # %% [markdown] """ -# 2. Group slots usage +# 2. Partial slot extraction This tutorial will show more advanced way of using slots by utilizing `GroupSlot` and different parameters it provides us with. By using Group slots you can extract multiple slots at once if they are placed in one group. @@ -66,7 +66,7 @@ TRANSITIONS: [ Tr( dst=("root", "utter_user"), - cnd=cnd.SlotsExtracted("person.username", "person.email", mode="any"), + cnd=cnd.SlotsExtracted("person", mode="any"), priority=1.2, ), Tr(dst=("user_flow", "repeat_question"), priority=0.8),