Skip to content

Commit

Permalink
Switched to unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NotBioWaste905 committed Oct 9, 2024
1 parent 33f05d0 commit 09937ae
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 114 deletions.
1 change: 1 addition & 0 deletions tests/llm/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

178 changes: 66 additions & 112 deletions tests/slots/test_slot_partial_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,133 +13,87 @@
)

from chatsky.slots import RegexpSlot, GroupSlot

from chatsky.slots.slots import SlotManager, ExtractedValueSlot, ExtractedGroupSlot
from chatsky.core import Message, Context

from chatsky.utils.testing import (
check_happy_path,
)

SLOTS = {
"person": GroupSlot(
username=RegexpSlot(
regexp=r"username is ([a-zA-Z]+)",
import pytest

test_slot = GroupSlot(
person=GroupSlot(
username=RegexpSlot(
regexp=r"([a-z]+_[a-z]+)",
match_group_idx=1,
),
email=RegexpSlot(
regexp=r"email is ([a-z@\.A-Z]+)",
email=RegexpSlot(
regexp=r"([a-z]+@[a-z]+\.[a-z]+)",
match_group_idx=1,
),
allow_partially_extracted=True
)
)


extracted_slot_values_turn_1 = {
"person.username": ExtractedValueSlot.model_construct(
is_slot_extracted=True, extracted_value="test_name", default_value=None
),
"friend": GroupSlot(
first_name=RegexpSlot(regexp=r"^[A-Z][a-z]+?(?= )", default_value="default_name"),
last_name=RegexpSlot(regexp=r"(?<= )[A-Z][a-z]+", default_value="default_surname"),
allow_partially_extracted=True,
"person.email": ExtractedValueSlot.model_construct(
is_slot_extracted=True, extracted_value="[email protected]", default_value=None
),
}

script = {
GLOBAL: {TRANSITIONS: [Tr(dst=("username_flow", "ask"), cnd=cnd.Regexp(r"^[sS]tart"))]},
"username_flow": {
LOCAL: {
PRE_TRANSITION: {"get_slot": proc.Extract("person.username")},
TRANSITIONS: [
Tr(
dst=("email_flow", "ask"),
cnd=cnd.SlotsExtracted("person.username"),
priority=1.2,
),
Tr(dst=("username_flow", "repeat_question"), priority=0.8),
],
},
"ask": {
RESPONSE: "Write your username (my username is ...):",
},
"repeat_question": {
RESPONSE: "Please, type your username again (my username is ...):",
},
},
"email_flow": {
LOCAL: {
PRE_TRANSITION: {"get_slot": proc.Extract("person.email")},
TRANSITIONS: [
Tr(
dst=("friend_flow", "ask"),
cnd=cnd.SlotsExtracted("person.username", "person.email"),
priority=1.2,
),
Tr(dst=("email_flow", "repeat_question"), priority=0.8),
],
},
"ask": {
RESPONSE: "Write your email (my email is ...):",
},
"repeat_question": {
RESPONSE: "Please, write your email again (my email is ...):",
},
},
"friend_flow": {
LOCAL: {
PRE_TRANSITION: {"get_slots": proc.Extract("friend", success_only=False)},
TRANSITIONS: [
Tr(
dst=("root", "utter"),
cnd=cnd.SlotsExtracted("friend.first_name", "friend.last_name", mode="any"),
priority=1.2,
),
Tr(dst=("friend_flow", "repeat_question"), priority=0.8),
],
},
"ask": {RESPONSE: "Please, name me one of your friends: (John Doe)"},
"repeat_question": {RESPONSE: "Please, name me one of your friends again: (John Doe)"},
},
"root": {
"start": {
TRANSITIONS: [Tr(dst=("username_flow", "ask"))],
},
"fallback": {
RESPONSE: "Finishing query",
TRANSITIONS: [Tr(dst=("username_flow", "ask"))],
},
"utter": {
RESPONSE: rsp.FilledTemplate("Your friend is {friend.first_name} {friend.last_name}"),
TRANSITIONS: [Tr(dst=("root", "utter_alternative"))],
},
"utter_alternative": {
RESPONSE: "Your username is {person.username}. " "Your email is {person.email}.",
PRE_RESPONSE: {"fill": proc.FillTemplate()},
},
},
}

HAPPY_PATH = [
("hi", "Write your username (my username is ...):"),
("my username is groot", "Write your email (my email is ...):"),
(
"my email is [email protected]",
"Please, name me one of your friends: (John Doe)",
extracted_slot_values_turn_2 = {
"person.username": ExtractedValueSlot.model_construct(
is_slot_extracted=True, extracted_value="new_name", default_value=None
),
("Bob Page", "Your friend is Bob Page"),
("ok", "Your username is groot. Your email is [email protected]."),
("ok", "Finishing query"),
("again", "Write your username (my username is ...):"),
("my username is groot", "Write your email (my email is ...):"),
(
"my email is [email protected]",
"Please, name me one of your friends: (John Doe)",
"person.email": ExtractedValueSlot.model_construct(
is_slot_extracted=True, extracted_value="[email protected]", default_value=None
),
("Jim ", "Your friend is Jim Page"),
("ok", "Your username is groot. Your email is [email protected]."),
("ok", "Finishing query"),
]
}

# %%
pipeline = Pipeline(
script=script,
start_label=("root", "start"),
fallback_label=("root", "fallback"),
slots=SLOTS,
)
@pytest.fixture(scope="function")
def context_with_request_1(context):
new_ctx = context.model_copy(deep=True)
new_ctx.add_request(Message(text="I am test_name. My email is [email protected]"))
return new_ctx

@pytest.fixture(scope="function")
def context_with_request_2(context):
context.add_request(Message(text="I am new_name."))
return context

@pytest.fixture(scope="function")
def empty_slot_manager():
manager = SlotManager()
manager.set_root_slot(test_slot)
return manager

def test_happy_path():
check_happy_path(pipeline, HAPPY_PATH, printout=True) # This is a function for automatic tutorial running
@pytest.mark.parametrize(
"slot_name,expected_slot_storage_1,expected_slot_storage_2",
[
(
"person",
ExtractedGroupSlot(
person=ExtractedGroupSlot(
username=extracted_slot_values_turn_1["person.username"],
email=extracted_slot_values_turn_1["person.email"],
)
),
ExtractedGroupSlot(
person=ExtractedGroupSlot(
username=extracted_slot_values_turn_2["person.username"],
email=extracted_slot_values_turn_2["person.email"],
)
),
),
],
)
async def test_slot_extraction(slot_name, expected_slot_storage_1, expected_slot_storage_2, empty_slot_manager, context_with_request_1, context_with_request_2):
await empty_slot_manager.extract_slot(slot_name, context_with_request_1, success_only=False)
assert empty_slot_manager.slot_storage == expected_slot_storage_1
await empty_slot_manager.extract_slot(slot_name, context_with_request_2, success_only=False)
assert empty_slot_manager.slot_storage == expected_slot_storage_2
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# %% [markdown]
"""
# 2. Group slots usage
# 2. Partial slot extraction
This tutorial will show more advanced way of using slots by utilizing `GroupSlot` and different parameters it provides us with.
By using Group slots you can extract multiple slots at once if they are placed in one group.
Expand Down Expand Up @@ -66,7 +66,7 @@
TRANSITIONS: [
Tr(
dst=("root", "utter_user"),
cnd=cnd.SlotsExtracted("person.username", "person.email", mode="any"),
cnd=cnd.SlotsExtracted("person", mode="any"),
priority=1.2,
),
Tr(dst=("user_flow", "repeat_question"), priority=0.8),
Expand Down

0 comments on commit 09937ae

Please sign in to comment.