-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
33f05d0
commit 09937ae
Showing
3 changed files
with
69 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,133 +13,87 @@ | |
) | ||
|
||
from chatsky.slots import RegexpSlot, GroupSlot | ||
|
||
from chatsky.slots.slots import SlotManager, ExtractedValueSlot, ExtractedGroupSlot | ||
from chatsky.core import Message, Context | ||
|
||
from chatsky.utils.testing import ( | ||
check_happy_path, | ||
) | ||
|
||
SLOTS = { | ||
"person": GroupSlot( | ||
username=RegexpSlot( | ||
regexp=r"username is ([a-zA-Z]+)", | ||
import pytest | ||
|
||
test_slot = GroupSlot( | ||
person=GroupSlot( | ||
username=RegexpSlot( | ||
regexp=r"([a-z]+_[a-z]+)", | ||
match_group_idx=1, | ||
), | ||
email=RegexpSlot( | ||
regexp=r"email is ([a-z@\.A-Z]+)", | ||
email=RegexpSlot( | ||
regexp=r"([a-z]+@[a-z]+\.[a-z]+)", | ||
match_group_idx=1, | ||
), | ||
allow_partially_extracted=True | ||
) | ||
) | ||
|
||
|
||
extracted_slot_values_turn_1 = { | ||
"person.username": ExtractedValueSlot.model_construct( | ||
is_slot_extracted=True, extracted_value="test_name", default_value=None | ||
), | ||
"friend": GroupSlot( | ||
first_name=RegexpSlot(regexp=r"^[A-Z][a-z]+?(?= )", default_value="default_name"), | ||
last_name=RegexpSlot(regexp=r"(?<= )[A-Z][a-z]+", default_value="default_surname"), | ||
allow_partially_extracted=True, | ||
"person.email": ExtractedValueSlot.model_construct( | ||
is_slot_extracted=True, extracted_value="[email protected]", default_value=None | ||
), | ||
} | ||
|
||
script = { | ||
GLOBAL: {TRANSITIONS: [Tr(dst=("username_flow", "ask"), cnd=cnd.Regexp(r"^[sS]tart"))]}, | ||
"username_flow": { | ||
LOCAL: { | ||
PRE_TRANSITION: {"get_slot": proc.Extract("person.username")}, | ||
TRANSITIONS: [ | ||
Tr( | ||
dst=("email_flow", "ask"), | ||
cnd=cnd.SlotsExtracted("person.username"), | ||
priority=1.2, | ||
), | ||
Tr(dst=("username_flow", "repeat_question"), priority=0.8), | ||
], | ||
}, | ||
"ask": { | ||
RESPONSE: "Write your username (my username is ...):", | ||
}, | ||
"repeat_question": { | ||
RESPONSE: "Please, type your username again (my username is ...):", | ||
}, | ||
}, | ||
"email_flow": { | ||
LOCAL: { | ||
PRE_TRANSITION: {"get_slot": proc.Extract("person.email")}, | ||
TRANSITIONS: [ | ||
Tr( | ||
dst=("friend_flow", "ask"), | ||
cnd=cnd.SlotsExtracted("person.username", "person.email"), | ||
priority=1.2, | ||
), | ||
Tr(dst=("email_flow", "repeat_question"), priority=0.8), | ||
], | ||
}, | ||
"ask": { | ||
RESPONSE: "Write your email (my email is ...):", | ||
}, | ||
"repeat_question": { | ||
RESPONSE: "Please, write your email again (my email is ...):", | ||
}, | ||
}, | ||
"friend_flow": { | ||
LOCAL: { | ||
PRE_TRANSITION: {"get_slots": proc.Extract("friend", success_only=False)}, | ||
TRANSITIONS: [ | ||
Tr( | ||
dst=("root", "utter"), | ||
cnd=cnd.SlotsExtracted("friend.first_name", "friend.last_name", mode="any"), | ||
priority=1.2, | ||
), | ||
Tr(dst=("friend_flow", "repeat_question"), priority=0.8), | ||
], | ||
}, | ||
"ask": {RESPONSE: "Please, name me one of your friends: (John Doe)"}, | ||
"repeat_question": {RESPONSE: "Please, name me one of your friends again: (John Doe)"}, | ||
}, | ||
"root": { | ||
"start": { | ||
TRANSITIONS: [Tr(dst=("username_flow", "ask"))], | ||
}, | ||
"fallback": { | ||
RESPONSE: "Finishing query", | ||
TRANSITIONS: [Tr(dst=("username_flow", "ask"))], | ||
}, | ||
"utter": { | ||
RESPONSE: rsp.FilledTemplate("Your friend is {friend.first_name} {friend.last_name}"), | ||
TRANSITIONS: [Tr(dst=("root", "utter_alternative"))], | ||
}, | ||
"utter_alternative": { | ||
RESPONSE: "Your username is {person.username}. " "Your email is {person.email}.", | ||
PRE_RESPONSE: {"fill": proc.FillTemplate()}, | ||
}, | ||
}, | ||
} | ||
|
||
HAPPY_PATH = [ | ||
("hi", "Write your username (my username is ...):"), | ||
("my username is groot", "Write your email (my email is ...):"), | ||
( | ||
"my email is [email protected]", | ||
"Please, name me one of your friends: (John Doe)", | ||
extracted_slot_values_turn_2 = { | ||
"person.username": ExtractedValueSlot.model_construct( | ||
is_slot_extracted=True, extracted_value="new_name", default_value=None | ||
), | ||
("Bob Page", "Your friend is Bob Page"), | ||
("ok", "Your username is groot. Your email is [email protected]."), | ||
("ok", "Finishing query"), | ||
("again", "Write your username (my username is ...):"), | ||
("my username is groot", "Write your email (my email is ...):"), | ||
( | ||
"my email is [email protected]", | ||
"Please, name me one of your friends: (John Doe)", | ||
"person.email": ExtractedValueSlot.model_construct( | ||
is_slot_extracted=True, extracted_value="[email protected]", default_value=None | ||
), | ||
("Jim ", "Your friend is Jim Page"), | ||
("ok", "Your username is groot. Your email is [email protected]."), | ||
("ok", "Finishing query"), | ||
] | ||
} | ||
|
||
# %% | ||
pipeline = Pipeline( | ||
script=script, | ||
start_label=("root", "start"), | ||
fallback_label=("root", "fallback"), | ||
slots=SLOTS, | ||
) | ||
@pytest.fixture(scope="function") | ||
def context_with_request_1(context): | ||
new_ctx = context.model_copy(deep=True) | ||
new_ctx.add_request(Message(text="I am test_name. My email is [email protected]")) | ||
return new_ctx | ||
|
||
@pytest.fixture(scope="function") | ||
def context_with_request_2(context): | ||
context.add_request(Message(text="I am new_name.")) | ||
return context | ||
|
||
@pytest.fixture(scope="function") | ||
def empty_slot_manager(): | ||
manager = SlotManager() | ||
manager.set_root_slot(test_slot) | ||
return manager | ||
|
||
def test_happy_path(): | ||
check_happy_path(pipeline, HAPPY_PATH, printout=True) # This is a function for automatic tutorial running | ||
@pytest.mark.parametrize( | ||
"slot_name,expected_slot_storage_1,expected_slot_storage_2", | ||
[ | ||
( | ||
"person", | ||
ExtractedGroupSlot( | ||
person=ExtractedGroupSlot( | ||
username=extracted_slot_values_turn_1["person.username"], | ||
email=extracted_slot_values_turn_1["person.email"], | ||
) | ||
), | ||
ExtractedGroupSlot( | ||
person=ExtractedGroupSlot( | ||
username=extracted_slot_values_turn_2["person.username"], | ||
email=extracted_slot_values_turn_2["person.email"], | ||
) | ||
), | ||
), | ||
], | ||
) | ||
async def test_slot_extraction(slot_name, expected_slot_storage_1, expected_slot_storage_2, empty_slot_manager, context_with_request_1, context_with_request_2): | ||
await empty_slot_manager.extract_slot(slot_name, context_with_request_1, success_only=False) | ||
assert empty_slot_manager.slot_storage == expected_slot_storage_1 | ||
await empty_slot_manager.extract_slot(slot_name, context_with_request_2, success_only=False) | ||
assert empty_slot_manager.slot_storage == expected_slot_storage_2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters