diff --git a/common/dff_api_v1/integration/condition.py b/common/dff_api_v1/integration/condition.py
new file mode 100644
index 0000000000..d069279fa3
--- /dev/null
+++ b/common/dff_api_v1/integration/condition.py
@@ -0,0 +1,289 @@
+import logging
+import re
+
+from nltk.stem import WordNetLemmatizer
+
+from dff.script import Context
+from dff.pipeline import Pipeline
+
+import common.greeting as common_greeting
+import common.utils as common_utils
+import common.universal_templates as universal_templates
+import common.dff_api_v1.integration.context as int_ctx
+from common.acknowledgements import GENERAL_ACKNOWLEDGEMENTS
+from common.constants import CAN_CONTINUE_SCENARIO, CAN_NOT_CONTINUE
+from .facts_utils import provide_facts_request
+
+logger = logging.getLogger(__name__)
+
+wnl = WordNetLemmatizer()
+
+
+# vars are described in README.md
+
+
+def was_clarification_request(ctx: Context, _) -> bool:
+    flag = ctx.misc["agent"]["clarification_request_flag"] if not ctx.validation else False
+    logger.debug(f"was_clarification_request = {flag}")
+    return bool(flag)
+
+
+def is_opinion_request(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = common_utils.is_opinion_request(int_ctx.get_last_human_utterance(ctx, pipeline))
+    logger.debug(f"is_opinion_request = {flag}")
+    return bool(flag)
+
+
+def is_opinion_expression(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = common_utils.is_opinion_expression(int_ctx.get_last_human_utterance(ctx, pipeline))
+    logger.debug(f"is_opinion_expression = {flag}")
+    return bool(flag)
+
+
+def is_previous_turn_dff_suspended(ctx: Context, _) -> bool:
+    flag = ctx.misc["agent"].get("previous_turn_dff_suspended", False) if not ctx.validation else False
+    logger.debug(f"is_previous_turn_dff_suspended = {flag}")
+    return bool(flag)
+
+
+def is_current_turn_dff_suspended(ctx: Context, _) -> bool:
+    flag = ctx.misc["agent"].get("current_turn_dff_suspended", False) if not ctx.validation else False
+    logger.debug(f"is_current_turn_dff_suspended = {flag}")
+    return bool(flag)
+
+
+def is_switch_topic(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = universal_templates.is_switch_topic(int_ctx.get_last_human_utterance(ctx, pipeline))
+    logger.debug(f"is_switch_topic = {flag}")
+    return bool(flag)
+
+
+def is_question(ctx: Context, pipeline: Pipeline) -> bool:
+    text = int_ctx.get_last_human_utterance(ctx, pipeline)["text"]
+    flag = common_utils.is_question(text)
+    logger.debug(f"is_question = {flag}")
+    return bool(flag)
+
+
+def is_lets_chat_about_topic_human_initiative(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = universal_templates.if_chat_about_particular_topic(
+        int_ctx.get_last_human_utterance(ctx, pipeline), int_ctx.get_last_bot_utterance(ctx, pipeline)
+    )
+    logger.debug(f"is_lets_chat_about_topic_human_initiative = {flag}")
+    return bool(flag)
+
+
+def is_lets_chat_about_topic(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = is_lets_chat_about_topic_human_initiative(ctx, pipeline)
+
+    last_human_uttr = int_ctx.get_last_human_utterance(ctx, pipeline)
+    last_bot_uttr_text = int_ctx.get_last_bot_utterance(ctx, pipeline)["text"]
+    is_bot_initiative = bool(re.search(universal_templates.COMPILE_WHAT_TO_TALK_ABOUT, last_bot_uttr_text))
+    flag = flag or (is_bot_initiative and not common_utils.is_no(last_human_uttr))
+    logger.debug(f"is_lets_chat_about_topic = {flag}")
+    return bool(flag)
+
+
+def is_begin_of_dialog(ctx: Context, pipeline: Pipeline, begin_dialog_n=10) -> bool:
+    flag = int_ctx.get_human_utter_index(ctx, pipeline) < begin_dialog_n
+    logger.debug(f"is_begin_of_dialog = {flag}")
+    return bool(flag)
+
+
+def is_interrupted(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = (
+        int_ctx.get_human_utter_index(ctx, pipeline) - int_ctx.get_previous_human_utter_index(ctx, pipeline)
+    ) != 1 and not was_clarification_request(ctx, pipeline)
+    logger.debug(f"is_interrupted = {flag}")
+    return bool(flag)
+
+
+def is_long_interrupted(ctx: Context, pipeline: Pipeline, how_long=3) -> bool:
+    flag = (
+        int_ctx.get_human_utter_index(ctx, pipeline) - int_ctx.get_previous_human_utter_index(ctx, pipeline)
+    ) > how_long and not was_clarification_request(ctx, pipeline)
+    logger.debug(f"is_long_interrupted = {flag}")
+    return bool(flag)
+
+
+def is_new_human_entity(ctx: Context, pipeline: Pipeline) -> bool:
+    new_entities = int_ctx.get_new_human_labeled_noun_phrase(ctx, pipeline)
+    flag = bool(new_entities)
+    logger.debug(f"is_new_human_entity = {flag}")
+    return bool(flag)
+
+
+def is_last_state(ctx: Context, pipeline: Pipeline, state) -> bool:
+    flag = False
+    if not ctx.validation:
+        history = list(int_ctx.get_history(ctx, pipeline).items())
+        if history:
+            history_sorted = sorted(history, key=lambda x: x[0])
+            last_state = history_sorted[-1][1]
+            if last_state == state:
+                flag = True
+    return bool(flag)
+
+
+def is_first_time_of_state(ctx: Context, pipeline: Pipeline, state) -> bool:
+    flag = state not in list(int_ctx.get_history(ctx, pipeline).values())
+    logger.debug(f"is_first_time_of_state {state} = {flag}")
+    return bool(flag)
+
+
+def if_was_prev_active(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = False
+    skill_uttr_indices = set(int_ctx.get_history(ctx, pipeline).keys())
+    if not ctx.validation:
+        human_uttr_index = str(ctx.misc["agent"]["human_utter_index"] - 1)
+        if human_uttr_index in skill_uttr_indices:
+            flag = True
+    return bool(flag)
+
+
+def is_plural(word) -> bool:
+    # wnl.lemmatize returns the singular noun form; a word that differs from its lemma is plural
+    lemma = wnl.lemmatize(word, "n")
+    return word != lemma
+
+
+def is_first_our_response(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = len(list(int_ctx.get_history(ctx, pipeline).values())) == 0
+    logger.debug(f"is_first_our_response = {flag}")
+    return bool(flag)
+
+
+def is_no_human_abandon(ctx: Context, pipeline: Pipeline) -> bool:
+    """Check that the human utterance does not abandon the dialog.
+    Uses MIDAS hold/abandon classes."""
+    midas_classes = common_utils.get_intents(int_ctx.get_last_human_utterance(ctx, pipeline), which="midas")
+    if "abandon" not in midas_classes:
+        return True
+    return False
+
+
+def no_special_switch_off_requests(ctx: Context, pipeline: Pipeline) -> bool:
+    """Check that the user
+    - didn't ask to switch the topic,
+    - didn't ask to talk about something particular,
+    - didn't request a high-priority intent (like what_is_your_name)
+    """
+    intents_by_catcher = common_utils.get_intents(
+        int_ctx.get_last_human_utterance(ctx, pipeline), probs=False, which="intent_catcher"
+    )
+    is_high_priority_intent = any([intent not in common_utils.service_intents for intent in intents_by_catcher])
+    is_switch = is_switch_topic(ctx, pipeline)
+    is_lets_chat = is_lets_chat_about_topic_human_initiative(ctx, pipeline)
+
+    if not (is_high_priority_intent or is_switch or is_lets_chat):
+        return True
+    return False
+
+
+def no_requests(ctx: Context, pipeline: Pipeline) -> bool:
+    """Check that the user
+    - didn't ask to switch the topic,
+    - didn't ask to talk about something particular,
+    - didn't request a high-priority intent (like what_is_your_name),
+    - didn't request any special intents,
+    - didn't ask questions
+    """
+    contain_no_special_requests = no_special_switch_off_requests(ctx, pipeline)
+
+    request_intents = [
+        "opinion_request",
+        "topic_switching",
+        "lets_chat_about",
+        "what_are_you_talking_about",
+        "Information_RequestIntent",
+        "Topic_SwitchIntent",
+        "Opinion_RequestIntent",
+    ]
+    intents = common_utils.get_intents(int_ctx.get_last_human_utterance(ctx, pipeline), which="all")
+    is_not_request_intent = all([intent not in request_intents for intent in intents])
+    is_no_question = "?" not in int_ctx.get_last_human_utterance(ctx, pipeline)["text"]
+
+    if contain_no_special_requests and is_not_request_intent and is_no_question:
+        return True
+    return False
+
+
+def is_yes_vars(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = True
+    flag = flag and common_utils.is_yes(int_ctx.get_last_human_utterance(ctx, pipeline))
+    return bool(flag)
+
+
+def is_no_vars(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = True
+    flag = flag and common_utils.is_no(int_ctx.get_last_human_utterance(ctx, pipeline))
+    return bool(flag)
+
+
+def is_do_not_know_vars(ctx: Context, pipeline: Pipeline) -> bool:
+    flag = True
+    flag = flag and common_utils.is_donot_know(int_ctx.get_last_human_utterance(ctx, pipeline))
+    return bool(flag)
+
+
+def is_passive_user(ctx: Context, pipeline: Pipeline, passive_threshold=3, history_len=2) -> bool:
+    """Check the number of tokens in the last `history_len` human utterances.
+    If the number of tokens in ALL `history_len` utterances is less than or equal to the given
+    threshold, consider the user passive and return True.
+ """ + user_utterances = int_ctx.get_human_utterances(ctx, pipeline)[-history_len:] + user_utterances = [utt["text"] for utt in user_utterances] + + uttrs_lens = [len(uttr.split()) <= passive_threshold for uttr in user_utterances] + if all(uttrs_lens): + return True + return False + + +def get_not_used_and_save_sentiment_acknowledgement(ctx: Context, pipeline: Pipeline, sentiment=None, lang="EN"): + if sentiment is None: + sentiment = int_ctx.get_human_sentiment(ctx, pipeline) + if is_yes_vars(ctx, pipeline) or is_no_vars(ctx, pipeline): + sentiment = "neutral" + + shared_memory = int_ctx.get_shared_memory(ctx, pipeline) + last_acknowledgements = shared_memory.get("last_acknowledgements", []) + + ack = common_utils.get_not_used_template( + used_templates=last_acknowledgements, all_templates=GENERAL_ACKNOWLEDGEMENTS[lang][sentiment] + ) + + used_acks = last_acknowledgements + [ack] + int_ctx.save_to_shared_memory(ctx, pipeline, last_acknowledgements=used_acks[-2:]) + return ack + + +def set_conf_and_can_cont_by_universal_policy(ctx: Context, pipeline: Pipeline): + DIALOG_BEGINNING_START_CONFIDENCE = 0.98 + DIALOG_BEGINNING_CONTINUE_CONFIDENCE = 0.9 + DIALOG_BEGINNING_SHORT_ANSWER_CONFIDENCE = 0.98 + MIDDLE_DIALOG_START_CONFIDENCE = 0.7 + + if not is_begin_of_dialog(ctx, pipeline, begin_dialog_n=10): + confidence = 0.0 + can_continue_flag = CAN_NOT_CONTINUE + elif is_first_our_response(ctx, pipeline): + confidence = DIALOG_BEGINNING_START_CONFIDENCE + can_continue_flag = CAN_CONTINUE_SCENARIO + elif not is_interrupted(ctx, pipeline) and common_greeting.dont_tell_you_answer( + int_ctx.get_last_human_utterance(ctx, pipeline) + ): + confidence = DIALOG_BEGINNING_SHORT_ANSWER_CONFIDENCE + can_continue_flag = CAN_CONTINUE_SCENARIO + elif not is_interrupted(ctx, pipeline): + confidence = DIALOG_BEGINNING_CONTINUE_CONFIDENCE + can_continue_flag = CAN_CONTINUE_SCENARIO + else: + confidence = MIDDLE_DIALOG_START_CONFIDENCE + can_continue_flag = CAN_CONTINUE_SCENARIO + + int_ctx.set_can_continue(ctx, pipeline, can_continue_flag) + int_ctx.set_confidence(ctx, pipeline, confidence) + + +def facts(ctx, pipeline): + return provide_facts_request(ctx, pipeline) diff --git a/common/dff_api_v1/integration/context.py b/common/dff_api_v1/integration/context.py new file mode 100644 index 0000000000..a05816ef8f --- /dev/null +++ b/common/dff_api_v1/integration/context.py @@ -0,0 +1,333 @@ +import logging +import os +import random + +from dff.script import Context +from dff.pipeline import Pipeline + +import common.constants as common_constants +import common.link as common_link +import common.news as common_news +import common.utils as common_utils + +logger = logging.getLogger(__name__) +SERVICE_NAME = os.getenv("SERVICE_NAME") + + +NEWS_API_ANNOTATOR_URL = os.getenv("NEWS_API_ANNOTATOR_URL") + + +def get_new_human_labeled_noun_phrase(ctx: Context, pipeline: Pipeline) -> list: + return ( + [] + if ctx.validation + else ( + get_last_human_utterance(ctx, pipeline).get("annotations", {}).get("cobot_entities", {}).get("entities", []) + ) + ) + + +def get_human_sentiment(ctx: Context, pipeline: Pipeline, negative_threshold=0.5, positive_threshold=0.333) -> str: + sentiment_probs = ( + None if ctx.validation else common_utils.get_sentiment(get_last_human_utterance(ctx, pipeline), probs=True) + ) + if sentiment_probs and isinstance(sentiment_probs, dict): + max_sentiment_prob = max(sentiment_probs.values()) + max_sentiments = [ + sentiment for sentiment in sentiment_probs if sentiment_probs[sentiment] == 
max_sentiment_prob + ] + if max_sentiments: + max_sentiment = max_sentiments[0] + return_negative = max_sentiment == "negative" and max_sentiment_prob >= negative_threshold + return_positive = max_sentiment == "positive" and max_sentiment_prob >= positive_threshold + if return_negative or return_positive: + return max_sentiment + return "neutral" + + +def get_cross_state(ctx: Context, _, service_name=SERVICE_NAME.replace("-", "_")) -> dict: + return {} if ctx.validation else ctx.misc["agent"]["dff_shared_state"]["cross_states"].get(service_name, {}) + + +def save_cross_state(ctx: Context, _, service_name=SERVICE_NAME.replace("-", "_"), new_state={}): + if not ctx.validation: + ctx.misc["agent"]["dff_shared_state"]["cross_states"][service_name] = new_state + + +def get_cross_link(ctx: Context, pipeline: Pipeline, service_name=SERVICE_NAME.replace("-", "_")) -> dict: + links = {} if ctx.validation else ctx.misc["agent"]["dff_shared_state"]["cross_links"].get(service_name, {}) + cur_human_index = get_human_utter_index(ctx, pipeline) + cross_link = [cross_link for human_index, cross_link in links.items() if (cur_human_index - int(human_index)) == 1] + cross_link = cross_link[0] if cross_link else {} + return cross_link + + +def set_cross_link( + ctx: Context, + pipeline: Pipeline, + to_service_name, + cross_link_additional_data={}, + from_service_name=SERVICE_NAME.replace("-", "_"), +): + cur_human_index = get_human_utter_index(ctx, pipeline) + if not ctx.validation: + ctx.misc["agent"]["dff_shared_state"]["cross_links"][to_service_name] = { + cur_human_index: { + "from_service": from_service_name, + **cross_link_additional_data, + } + } + + +def reset_response_parts(ctx: Context, _): + if not ctx.validation and "response_parts" in ctx.misc["agent"]: + del ctx.misc["agent"]["response_parts"] + + +def add_parts_to_response_parts(ctx: Context, _, parts=[]): + response_parts = set([] if ctx.validation else ctx.misc["agent"].get("response_parts", [])) + response_parts.update(parts) + if not ctx.validation: + ctx.misc["agent"]["response_parts"] = sorted(list(response_parts)) + + +def set_acknowledgement_to_response_parts(ctx: Context, pipeline: Pipeline): + reset_response_parts(ctx, pipeline) + add_parts_to_response_parts(ctx, pipeline, parts=["acknowledgement"]) + + +def add_acknowledgement_to_response_parts(ctx: Context, pipeline: Pipeline): + if not ctx.validation and ctx.misc["agent"].get("response_parts") is None: + add_parts_to_response_parts(ctx, pipeline, parts=["body"]) + add_parts_to_response_parts(ctx, pipeline, parts=["acknowledgement"]) + + +def set_body_to_response_parts(ctx: Context, pipeline: Pipeline): + reset_response_parts(ctx, pipeline) + add_parts_to_response_parts(ctx, pipeline, parts=["body"]) + + +def add_body_to_response_parts(ctx: Context, pipeline: Pipeline): + add_parts_to_response_parts(ctx, pipeline, parts=["body"]) + + +def set_prompt_to_response_parts(ctx: Context, pipeline: Pipeline): + reset_response_parts(ctx, pipeline) + add_parts_to_response_parts(ctx, pipeline, parts=["prompt"]) + + +def add_prompt_to_response_parts(ctx: Context, pipeline: Pipeline): + add_parts_to_response_parts(ctx, pipeline, parts=["prompt"]) + + +def get_shared_memory(ctx: Context, _) -> dict: + return {} if ctx.validation else ctx.misc["agent"]["shared_memory"] + + +def get_used_links(ctx: Context, _) -> dict: + return {} if ctx.validation else ctx.misc["agent"]["used_links"] + + +def get_age_group(ctx: Context, _) -> dict: + return {} if ctx.validation else ctx.misc["agent"]["age_group"] 
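+
+
+# All getters above follow one convention: when ctx.validation is True (DFF's dry
+# run at service startup), they return a type-appropriate default instead of
+# reading ctx.misc["agent"]. A minimal usage sketch (hypothetical skill code):
+#
+#     def greet(ctx: Context, pipeline: Pipeline) -> str:
+#         if get_bot_utterances(ctx, pipeline):  # [] during validation
+#             return "Hello again!"
+#         return "Hello!"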
+
+
+def set_age_group(ctx: Context, _, age_group):
+    if not ctx.validation:
+        ctx.misc["agent"]["age_group"] = age_group
+
+
+def get_disliked_skills(ctx: Context, _) -> list:
+    return [] if ctx.validation else ctx.misc["agent"]["disliked_skills"]
+
+
+def get_human_utter_index(ctx: Context, _) -> int:
+    return 0 if ctx.validation else ctx.misc["agent"]["human_utter_index"]
+
+
+def get_previous_human_utter_index(ctx: Context, _) -> int:
+    return 0 if ctx.validation else ctx.misc["agent"]["previous_human_utter_index"]
+
+
+def get_dialog(ctx: Context, _) -> dict:
+    return {} if ctx.validation else ctx.misc["agent"]["dialog"]
+
+
+def get_utterances(ctx: Context, _) -> list:
+    return [] if ctx.validation else ctx.misc["agent"]["dialog"]["utterances"]
+
+
+def get_human_utterances(ctx: Context, _) -> list:
+    return [] if ctx.validation else ctx.misc["agent"]["dialog"]["human_utterances"]
+
+
+def get_last_human_utterance(ctx: Context, _) -> dict:
+    return {"text": "", "annotations": {}} if ctx.validation else ctx.misc["agent"]["dialog"]["human_utterances"][-1]
+
+
+def get_bot_utterances(ctx: Context, _) -> list:
+    return [] if ctx.validation else ctx.misc["agent"]["dialog"]["bot_utterances"]
+
+
+def get_last_bot_utterance(ctx: Context, _) -> dict:
+    if not ctx.validation and ctx.misc["agent"]["dialog"]["bot_utterances"]:
+        return ctx.misc["agent"]["dialog"]["bot_utterances"][-1]
+    else:
+        return {"text": "", "annotations": {}}
+
+
+def save_to_shared_memory(ctx: Context, _, **kwargs):
+    if not ctx.validation:
+        ctx.misc["agent"]["shared_memory"].update(kwargs)
+
+
+def update_used_links(ctx: Context, _, linked_skill_name, linking_phrase):
+    if not ctx.validation:
+        agent = ctx.misc["agent"]
+        agent["used_links"][linked_skill_name] = agent["used_links"].get(linked_skill_name, []) + [linking_phrase]
+
+
+def get_new_link_to(ctx: Context, pipeline: Pipeline, skill_names):
+    used_links = get_used_links(ctx, pipeline)
+    disliked_skills = get_disliked_skills(ctx, pipeline)
+
+    link = common_link.link_to(
+        skill_names, human_attributes={"used_links": used_links, "disliked_skills": disliked_skills}
+    )
+    update_used_links(ctx, pipeline, link["skill"], link["phrase"])
+    return link
+
+
+def set_dff_suspension(ctx: Context, _):
+    if not ctx.validation:
+        ctx.misc["agent"]["current_turn_dff_suspended"] = True
+
+
+def reset_dff_suspension(ctx: Context, _):
+    if not ctx.validation:
+        ctx.misc["agent"]["current_turn_dff_suspended"] = False
+
+
+def set_confidence(ctx: Context, pipeline: Pipeline, confidence=1.0):
+    if not ctx.validation:
+        ctx.misc["agent"]["response"].update({"confidence": confidence})
+    if confidence == 0.0:
+        reset_can_continue(ctx, pipeline)
+
+
+def set_can_continue(ctx: Context, _, continue_flag=common_constants.CAN_CONTINUE_SCENARIO):
+    if not ctx.validation:
+        ctx.misc["agent"]["response"].update({"can_continue": continue_flag})
+
+
+def reset_can_continue(ctx: Context, _):
+    if not ctx.validation and "can_continue" in ctx.misc["agent"]["response"]:
+        del ctx.misc["agent"]["response"]["can_continue"]
+
+
+def get_named_entities_from_human_utterance(ctx: Context, pipeline: Pipeline):
+    # each ent is a dict, e.g. ent = {"text": "London", "type": "LOC"}
+    entities = common_utils.get_entities(
+        get_last_human_utterance(ctx, pipeline),
+        only_named=True,
+        with_labels=True,
+    )
+    return entities
+
+
+def get_nounphrases_from_human_utterance(ctx: Context, pipeline: Pipeline):
+    nps = common_utils.get_entities(
+        get_last_human_utterance(ctx, pipeline),
+        only_named=False,
+        with_labels=False,
+    )
+    return nps
+
+
+def get_fact_random_annotations_from_human_utterance(ctx: Context, pipeline: Pipeline) -> dict:
+    if not ctx.validation:
+        return (
+            get_last_human_utterance(ctx, pipeline)
+            .get("annotations", {})
+            .get("fact_random", {"facts": [], "response": ""})
+        )
+    else:
+        return {"facts": [], "response": ""}
+
+
+def get_fact_for_particular_entity_from_human_utterance(ctx: Context, pipeline: Pipeline, entity) -> list:
+    fact_random_results = get_fact_random_annotations_from_human_utterance(ctx, pipeline)
+    facts_for_entity = []
+    for fact in fact_random_results.get("facts", []):
+        is_same_entity = fact.get("entity_substr", "").lower() == entity.lower()
+        is_sorry = "Sorry, I don't know" in fact.get("fact", "")
+        if is_same_entity and not is_sorry:
+            facts_for_entity += [fact["fact"]]
+
+    return facts_for_entity
+
+
+def get_news_about_particular_entity_from_human_utterance(ctx: Context, pipeline: Pipeline, entity) -> dict:
+    last_uttr = get_last_human_utterance(ctx, pipeline)
+    last_uttr_entities_news = last_uttr.get("annotations", {}).get("news_api_annotator", [])
+    curr_news = {}
+    for news_entity in last_uttr_entities_news:
+        if news_entity["entity"] == entity:
+            curr_news = news_entity["news"]
+            break
+    if not curr_news:
+        curr_news = common_news.get_news_about_topic(entity, NEWS_API_ANNOTATOR_URL)
+
+    return curr_news
+
+
+def get_facts_from_fact_retrieval(ctx: Context, pipeline: Pipeline) -> list:
+    annotations = get_last_human_utterance(ctx, pipeline).get("annotations", {})
+    if "fact_retrieval" in annotations:
+        if isinstance(annotations["fact_retrieval"], dict):
+            return annotations["fact_retrieval"].get("facts", [])
+        elif isinstance(annotations["fact_retrieval"], list):
+            return annotations["fact_retrieval"]
+    return []
+
+
+def get_unrepeatable_index_from_rand_seq(
+    ctx: Context, pipeline: Pipeline, seq_name, seq_max, renew_seq_if_empty=False
+) -> int:
+    """Return an unrepeatable index from RANDOM_SEQUENCE.
+    RANDOM_SEQUENCE is stored in shared memory under the name `seq_name`.
+    RANDOM_SEQUENCE is a shuffled list of the indices [0, `seq_max`).
+    The sequence is regenerated once it is exhausted if `renew_seq_if_empty` is True.
+    """
+    shared_memory = get_shared_memory(ctx, pipeline)
+    seq = shared_memory.get(seq_name, random.sample(list(range(seq_max)), seq_max))
+    if renew_seq_if_empty or seq:
+        seq = seq if seq else random.sample(list(range(seq_max)), seq_max)
+        next_index = seq[-1] if seq else None
+        save_to_shared_memory(ctx, pipeline, **{seq_name: seq[:-1]})
+        return next_index
+
+
+def get_history(ctx: Context, _):
+    if not ctx.validation:
+        return ctx.misc["agent"]["history"]
+    return {}
+
+
+def get_n_last_state(ctx: Context, pipeline: Pipeline, n) -> str:
+    last_state = ""
+    history = list(get_history(ctx, pipeline).items())
+    if history:
+        history_sorted = sorted(history, key=lambda x: x[0])
+        if len(history_sorted) >= n:
+            last_state = history_sorted[-n][1]
+    return last_state
+
+
+def get_last_state(ctx: Context, pipeline: Pipeline) -> str:
+    last_state = ""
+    history = list(get_history(ctx, pipeline).items())
+    if history:
+        history_sorted = sorted(history, key=lambda x: x[0])
+        last_state = history_sorted[-1][1]
+    return last_state
diff --git a/common/dff_api_v1/integration/facts_utils.py b/common/dff_api_v1/integration/facts_utils.py
new file mode 100644
index 0000000000..29fec8f5dc
--- /dev/null
+++ b/common/dff_api_v1/integration/facts_utils.py
@@ -0,0 +1,316 @@
+import json
+import logging
+import os
+import random
+import re
+import nltk
+import requests
+import sentry_sdk
+import common.constants as common_constants
+import common.dff_api_v1.integration.context as context
+from dff.script import Context
+from dff.pipeline import Pipeline
+
+from common.wiki_skill import (
+    find_page_title,
+    find_all_titles,
+    find_paragraph,
+    used_types_dict,
+    delete_hyperlinks,
+    NEWS_MORE,
+    QUESTION_TEMPLATES,
+    QUESTION_TEMPLATES_SHORT,
+    WIKI_BADLIST,
+)
+from common.universal_templates import CONTINUE_PATTERN
+from common.utils import is_no, is_yes
+
+nltk.download("punkt")
+
+sentry_sdk.init(os.getenv("SENTRY_DSN"))
+logger = logging.getLogger(__name__)
+WIKI_FACTS_URL = os.getenv("WIKI_FACTS_URL")
+
+with open("/src/common/wikihow_cache.json", "r") as fl:
+    wikihow_cache = json.load(fl)
+
+memory = {}
+
+titles_by_type = {}
+for elem in used_types_dict:
+    types = elem.get("types", [])
+    titles = elem["titles"]
+    for tp in types:
+        titles_by_type[tp] = titles
+
+titles_by_entity_substr = {}
+page_titles_by_entity_substr = {}
+for elem in used_types_dict:
+    entity_substrings = elem.get("entity_substr", [])
+    titles = elem["titles"]
+    page_title = elem.get("page_title", "")
+    for substr in entity_substrings:
+        titles_by_entity_substr[substr] = titles
+        if page_title:
+            page_titles_by_entity_substr[substr] = page_title
+
+questions_by_entity_substr = {}
+for elem in used_types_dict:
+    entity_substrings = elem.get("entity_substr", [])
+    question = elem.get("intro_question", "")
+    if question:
+        for substr in entity_substrings:
+            questions_by_entity_substr[substr] = question
+
+wikihowq_by_substr = {}
+for elem in used_types_dict:
+    entity_substrings = elem.get("entity_substr", [])
+    wikihow_info = elem.get("wikihow_info", {})
+    if wikihow_info:
+        for substr in entity_substrings:
+            wikihowq_by_substr[substr] = wikihow_info
+
+
+def get_wikipedia_content(page_title, cache_page_dict=None):
+    page_content = {}
+    main_pages = {}
+    try:
+        if page_title:
+            if cache_page_dict and page_title in cache_page_dict:
+                page_content = cache_page_dict[page_title]["page_content"]
+                main_pages =
cache_page_dict[page_title]["main_pages"] + else: + res = requests.post(WIKI_FACTS_URL, json={"wikipedia_titles": [[page_title]]}, timeout=1.0).json() + if res and res[0]["main_pages"] and res[0]["wikipedia_content"]: + page_content = res[0]["wikipedia_content"][0] + main_pages = res[0]["main_pages"][0] + except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + + return page_content, main_pages + + +def get_wikihow_content(page_title, cache_page_dict=None): + page_content = {} + try: + if page_title: + res = requests.post(WIKI_FACTS_URL, json={"wikihow_titles": [[page_title]]}, timeout=1.0).json() + if res and res[0]["wikihow_content"]: + page_content = res[0]["wikihow_content"][0] + except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + + return page_content + + +def get_titles(found_entity_substr, found_entity_types, page_content): + all_titles = find_all_titles([], page_content) + titles_we_use = [] + titles_q = {} + for tp in found_entity_types: + tp_titles = titles_by_type.get(tp, {}) + titles_we_use += list(tp_titles.keys()) + titles_q = {**titles_q, **tp_titles} + substr_titles = titles_by_entity_substr.get(found_entity_substr, {}) + titles_we_use += list(substr_titles.keys()) + titles_q = {**titles_q, **substr_titles} + return titles_q, titles_we_use, all_titles + + +def make_facts_str(paragraphs): + facts_str = "" + mentions_list = [] + mention_pages_list = [] + paragraph = "" + if paragraphs: + paragraph = paragraphs[0] + sentences = nltk.sent_tokenize(paragraph) + sentences_list = [] + cur_len = 0 + max_len = 50 + for sentence in sentences: + sanitized_sentence, mentions, mention_pages = delete_hyperlinks(sentence) + words = nltk.word_tokenize(sanitized_sentence) + if cur_len + len(words) < max_len and not re.findall(WIKI_BADLIST, sanitized_sentence): + sentences_list.append(sanitized_sentence) + cur_len += len(words) + mentions_list += mentions + mention_pages_list += mention_pages + if sentences_list: + facts_str = " ".join(sentences_list) + cur_len = 0 + if sentences and not sentences_list: + sentence = sentences[0] + sanitized_sentence, mentions, mention_pages = delete_hyperlinks(sentence) + sentence_parts = sanitized_sentence.split(", ") + mentions_list += mentions + mention_pages_list += mention_pages + for part in sentence_parts: + words = nltk.word_tokenize(part) + if cur_len + len(words) < max_len and not re.findall(WIKI_BADLIST, part): + sentences_list.append(part) + cur_len += len(words) + facts_str = ", ".join(sentences_list) + if facts_str and not facts_str.endswith("."): + facts_str = f"{facts_str}." 
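+    # The logic above compresses a paragraph to roughly max_len (50) tokens: whole
+    # sentences are kept while they fit; if even the first sentence is too long,
+    # it is rebuilt from its comma-separated parts instead.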
+    return facts_str, mentions_list, mention_pages_list
+
+
+def preprocess_wikihow_page(article_content):
+    page_content_list = []
+    article_content = list(article_content.items())
+    for title_num, (title, paragraphs) in enumerate(article_content):
+        if title != "intro":
+            for n, paragraph in enumerate(paragraphs):
+                facts_str = ""
+                question = ""
+                sentences = nltk.sent_tokenize(paragraph)
+                sentences_list = []
+                cur_len = 0
+                max_len = 50
+                for sentence in sentences:
+                    words = nltk.word_tokenize(sentence)
+                    if cur_len + len(words) < max_len:
+                        sentences_list.append(sentence)
+                        cur_len += len(words)
+                if sentences_list:
+                    facts_str = " ".join(sentences_list)
+                else:
+                    cur_len = 0
+                    sentence_parts = sentences[0].split(", ")
+                    for part in sentence_parts:
+                        words = nltk.word_tokenize(part)
+                        if cur_len + len(words) < max_len:
+                            sentences_list.append(part)
+                            cur_len += len(words)
+                    facts_str = ", ".join(sentences_list)
+
+                if n == len(paragraphs) - 1 and title_num != len(article_content) - 1:
+                    next_title = article_content[title_num + 1][0]
+                    question = f"Would you like to know about {next_title.lower()}?"
+                elif n != len(paragraphs) - 1:
+                    question = random.choice(NEWS_MORE)
+                response_dict = {"facts_str": facts_str, "question": question}
+                response = f"{facts_str} {question}".strip().replace("  ", " ")
+                if response:
+                    page_content_list.append(response_dict)
+    return page_content_list
+
+
+def preprocess_wikipedia_page(found_entity_substr, found_entity_types, article_content, predefined_titles=None):
+    logger.info(f"found_entity_substr {found_entity_substr} found_entity_types {found_entity_types}")
+    titles_q, titles_we_use, all_titles = get_titles(found_entity_substr, found_entity_types, article_content)
+    if predefined_titles:
+        titles_we_use = predefined_titles
+    logger.info(f"titles_we_use {titles_we_use} all_titles {all_titles}")
+    facts_list = []
+    for n, title in enumerate(titles_we_use):
+        page_title = find_page_title(all_titles, title)
+        paragraphs = find_paragraph(article_content, page_title)
+        logger.info(f"page_title {page_title} paragraphs {paragraphs[:2]}")
+        count_par = 0
+        for num, paragraph in enumerate(paragraphs):
+            facts_str, *_ = make_facts_str([paragraph])
+            if facts_str and facts_str.endswith(".") and len(facts_str.split()) > 4:
+                facts_list.append((title, facts_str))
+                count_par += 1
+            if count_par == 2:
+                break
+    logger.info(f"facts_list {facts_list[:3]}")
+    page_content_list = []
+    for n, (title, facts_str) in enumerate(facts_list):
+        if n != len(facts_list) - 1:
+            next_title = facts_list[n + 1][0]
+            if next_title != title:
+                if found_entity_substr.lower() in next_title.lower():
+                    question_template = random.choice(QUESTION_TEMPLATES_SHORT)
+                    question = question_template.format(next_title)
+                else:
+                    question_template = random.choice(QUESTION_TEMPLATES)
+                    question = question_template.format(next_title, found_entity_substr)
+            else:
+                question = random.choice(NEWS_MORE)
+            response_dict = {"facts_str": facts_str, "question": question}
+            response = f"{facts_str} {question}".strip().replace("  ", " ")
+            if response:
+                page_content_list.append(response_dict)
+        else:
+            page_content_list.append(
+                {"facts_str": facts_str, "question": f"I was very happy to tell you more about {found_entity_substr}."}
+            )
+    logger.info(f"page_content_list {page_content_list}")
+    return page_content_list
+
+
+def provide_facts_request(ctx: Context, _):
+    flag = False
+    wiki = ctx.misc.get("wiki", {})
+    user_uttr: dict = ctx.misc.get("agent", {}).get("dialog", {}).get("human_utterances", [{}])[-1]
+    isno = is_no(user_uttr)
+    cur_wiki_page = wiki.get("cur_wiki_page", "")
+    if cur_wiki_page:
+        # `memory` is module-level, so the preprocessed page content persists across turns
+        wiki_page_content_list = memory.get("wiki_page_content", [])
+        used_wiki_page_nums = wiki.get("used_wiki_page_nums", {}).get(cur_wiki_page, [])
+        logger.info(f"request, used_wiki_page_nums {used_wiki_page_nums}")
+        logger.info(f"request, wiki_page_content_list {wiki_page_content_list[:3]}")
+        if len(wiki_page_content_list) > 0 and len(used_wiki_page_nums) < len(wiki_page_content_list) and not isno:
+            flag = True
+    return flag
+
+
+def provide_facts_response(ctx: Context, pipeline: Pipeline, page_source, wiki_page):
+    wiki = ctx.misc.get("wiki", {})
+    user_uttr: dict = ctx.misc.get("agent", {}).get("dialog", {}).get("human_utterances", [{}])[-1]
+    isyes = is_yes(user_uttr) or re.findall(CONTINUE_PATTERN, user_uttr["text"])
+    response = ""
+    cur_wiki_page = wiki.get("cur_wiki_page", "")
+    if not cur_wiki_page:
+        wiki["cur_wiki_page"] = wiki_page
+        if page_source == "wikiHow":
+            page_content = get_wikihow_content(wiki_page, wikihow_cache)
+            wiki_page_content_list = preprocess_wikihow_page(page_content)
+            memory["wiki_page_content"] = wiki_page_content_list
+        elif page_source == "wikipedia":
+            page_content = get_wikipedia_content(wiki_page)
+            wiki_page_content_list = preprocess_wikipedia_page(wiki_page.lower(), [], page_content)
+            memory["wiki_page_content"] = wiki_page_content_list
+        logger.info(f"wiki_page {wiki_page}")
+
+    used_wiki_page_nums_dict = wiki.get("used_wiki_page_nums", {})
+    used_wiki_page_nums = used_wiki_page_nums_dict.get(wiki_page, [])
+    wiki_page_content_list = memory.get("wiki_page_content", [])
+    logger.info(f"response, used_wiki_page_nums {used_wiki_page_nums}")
+    logger.info(f"response, wiki_page_content_list {wiki_page_content_list[:3]}")
+
+    if wiki_page_content_list:
+        for num, fact in enumerate(wiki_page_content_list):
+            if num not in used_wiki_page_nums:
+                facts_str = fact.get("facts_str", "")
+                question = fact.get("question", "")
+                response = f"{facts_str} {question}".strip().replace("  ", " ")
+                used_wiki_page_nums.append(num)
+                used_wiki_page_nums_dict[wiki_page] = used_wiki_page_nums
+                wiki["used_wiki_page_nums"] = used_wiki_page_nums_dict
+                break
+
+        if len(wiki_page_content_list) == len(used_wiki_page_nums):
+            wiki["wiki_page"] = ""
+            memory["wiki_page_content"] = []
+    logger.info(f"response, final {response}")
+    if response:
+        if isyes:
+            context.set_confidence(ctx, pipeline, 1.0)
+            context.set_can_continue(ctx, pipeline, common_constants.MUST_CONTINUE)
+        else:
+            context.set_confidence(ctx, pipeline, 0.99)
+            context.set_can_continue(ctx, pipeline, common_constants.CAN_CONTINUE_SCENARIO)
+    else:
+        context.set_confidence(ctx, pipeline, 0.0)
+        context.set_can_continue(ctx, pipeline, common_constants.CAN_NOT_CONTINUE)
+    processed_node = ctx.framework_states["actor"].get("processed_node", ctx.framework_states["actor"]["next_node"])
+    processed_node.response = response
+    ctx.framework_states["actor"]["processed_node"] = processed_node
+    return ctx
diff --git a/common/dff_api_v1/integration/message.py b/common/dff_api_v1/integration/message.py
new file mode 100644
index 0000000000..3c65b8bad6
--- /dev/null
+++ b/common/dff_api_v1/integration/message.py
@@ -0,0 +1,9 @@
+from typing import Optional
+from dff.script import Message
+
+
+class DreamMessage(Message):
+    confidence: Optional[float] = None
+    human_attr: Optional[dict] = None
+    bot_attr: Optional[dict] = None
+    hype_attr: Optional[dict] = None
diff --git a/common/dff_api_v1/integration/processing.py b/common/dff_api_v1/integration/processing.py
new file mode 100644
index 0000000000..3bff6e3297
--- /dev/null
+++ b/common/dff_api_v1/integration/processing.py
@@ -0,0 +1,92 @@
+import logging
+
+from dff.script import Context
+from dff.pipeline import Pipeline
+import common.constants as common_constants
+from common.wiki_skill import extract_entity
+from .facts_utils import provide_facts_response
+from . import context
+
+
+DREAM_SLOTS_KEY = "slots"
+logger = logging.getLogger(__name__)
+
+
+def save_slots_to_ctx(slots: dict):
+    def save_slots_to_ctx_processing(
+        ctx: Context,
+        _,
+        *args,
+        **kwargs,
+    ) -> Context:
+        ctx.misc[DREAM_SLOTS_KEY] = ctx.misc.get(DREAM_SLOTS_KEY, {}) | slots
+        return ctx
+
+    return save_slots_to_ctx_processing
+
+
+def entities(**kwargs):
+    slot_info = list(kwargs.items())
+
+    def extract_entities(
+        ctx: Context,
+        pipeline: Pipeline,
+        *args,
+        **kwargs,
+    ) -> Context:
+        for slot_name, slot_types in slot_info:
+            if isinstance(slot_types, str):
+                extracted_entity = extract_entity(ctx, slot_types)
+                if extracted_entity:
+                    ctx = save_slots_to_ctx({slot_name: extracted_entity})(ctx, pipeline)
+            elif isinstance(slot_types, list):
+                found = False
+                for slot_type in slot_types:
+                    if not slot_type.startswith("default:"):
+                        extracted_entity = extract_entity(ctx, slot_type)
+                        if extracted_entity:
+                            ctx = save_slots_to_ctx({slot_name: extracted_entity})(ctx, pipeline)
+                            found = True
+                if not found:
+                    default_value = ""
+                    for slot_type in slot_types:
+                        if slot_type.startswith("default:"):
+                            default_value = slot_type.split("default:")[1]
+                    if default_value:
+                        ctx = save_slots_to_ctx({slot_name: default_value})(ctx, pipeline)
+        return ctx
+
+    return extract_entities
+
+
+def fact_provider(page_source, wiki_page):
+    def response(ctx: Context, pipeline: Pipeline, *args, **kwargs):
+        return provide_facts_response(ctx, pipeline, page_source, wiki_page)
+
+    return response
+
+
+def set_confidence(confidence: float = 1.0):
+    def set_confidence_processing(
+        ctx: Context,
+        pipeline: Pipeline,
+        *args,
+        **kwargs,
+    ) -> Context:
+        context.set_confidence(ctx, pipeline, confidence)
+        return ctx
+
+    return set_confidence_processing
+
+
+def set_can_continue(continue_flag: str = common_constants.CAN_CONTINUE_SCENARIO):
+    def set_can_continue_processing(
+        ctx: Context,
+        pipeline: Pipeline,
+        *args,
+        **kwargs,
+    ) -> Context:
+        context.set_can_continue(ctx, pipeline, continue_flag)
+        return ctx
+
+    return set_can_continue_processing
diff --git a/common/dff_api_v1/integration/response.py b/common/dff_api_v1/integration/response.py
new file mode 100644
index 0000000000..e43a78b5f6
--- /dev/null
+++ b/common/dff_api_v1/integration/response.py
@@ -0,0 +1,19 @@
+from dff.script import Context
+from .processing import DREAM_SLOTS_KEY
+from .message import DreamMessage
+
+
+def fill_by_slots(response: DreamMessage):
+    def fill_responses_by_slots_inner(
+        ctx: Context,
+        _,
+        *args,
+        **kwargs,
+    ) -> DreamMessage:
+        if not response.text:
+            return response
+        for slot_name, slot_value in ctx.misc.get(DREAM_SLOTS_KEY, {}).items():
+            response.text = response.text.replace("{" f"{slot_name}" "}", slot_value)
+        return response
+
+    return fill_responses_by_slots_inner
diff --git a/common/dff_api_v1/integration/serializer.py b/common/dff_api_v1/integration/serializer.py
new file mode 100644
index 0000000000..bc4f4e5dce
--- /dev/null
+++ b/common/dff_api_v1/integration/serializer.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+
+import logging
+import os
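+# This module bridges Dream's agent state and DFF: load_ctxs() builds one DFF
+# Context per dialog in the batch, while run_dff()/get_response() turn pipeline
+# output back into Dream hypothesis tuples
+# (text, confidence, human_attr, bot_attr, hype_attr).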
+from typing import List, Any + +import sentry_sdk +from sentry_sdk.integrations.logging import ignore_logger +from dff.script import Context, MultiMessage +from dff.pipeline import Pipeline +from pydantic import BaseModel, Field, Extra, root_validator + +from common.constants import CAN_NOT_CONTINUE, CAN_CONTINUE_SCENARIO +from common.dff_api_v1.integration.context import get_last_human_utterance +from common.dff_api_v1.integration.message import DreamMessage + + +ignore_logger("root") + +sentry_sdk.init(os.getenv("SENTRY_DSN")) +SERVICE_NAME = os.getenv("SERVICE_NAME") + +logger = logging.getLogger(__name__) + + +class ExtraIgnoreModel(BaseModel): + class Config: + extra = Extra.ignore + + +class HumanAttr(ExtraIgnoreModel): + dff_shared_state: dict + used_links: dict + age_group: str + disliked_skills: list + + +class HypeAttr(ExtraIgnoreModel): + can_continue: str + + @root_validator(pre=True) + def calculate_can_continue(cls, values): + confidence = values["response"].get("confidence", 0.85) + can_continue = CAN_CONTINUE_SCENARIO if confidence else CAN_NOT_CONTINUE + values["can_continue"] = values["response"].get("can_continue", can_continue) + return values + + +class State(ExtraIgnoreModel): + context: Context + previous_human_utter_index: int = -1 + current_turn_dff_suspended: bool = False + history: dict = Field(default_factory=dict) + shared_memory: dict = Field(default_factory=dict) + + @root_validator(pre=True) + def shift_state_history(cls, values): + values["previous_human_utter_index"] = values["human_utter_index"] + values["history"][str(values["human_utter_index"])] = list(values["context"].labels.values())[-1] + return values + + @root_validator(pre=True) + def validate_context(cls, values): + context = values["context"] + context.clear(2, ["requests", "responses", "labels"]) + del context.misc["agent"] + return values + + +class Agent(ExtraIgnoreModel): + previous_human_utter_index: int = -1 + human_utter_index: int + dialog: Any + entities: dict = Field(default_factory=dict) + shared_memory: dict = Field(default_factory=dict) + current_turn_dff_suspended: bool = False + previous_turn_dff_suspended: bool = False + response: dict = Field(default_factory=dict) + dff_shared_state: dict = Field(default_factory=dict) + cache: dict = Field(default_factory=dict) + history: dict = Field(default_factory=dict) + used_links: dict = Field(default_factory=dict) + age_group: str = "" + disliked_skills: list = Field(default_factory=list) + clarification_request_flag: bool = False + + @root_validator(pre=True) + def get_state_props(cls, values): + state = values.get("state", {}) + values = values | state + return values + + +def load_ctxs(requested_data) -> List[Context]: + dialog_batch = requested_data.get("dialog_batch", []) + human_utter_index_batch = requested_data.get("human_utter_index_batch", [0] * len(dialog_batch)) + state_batch = requested_data.get(f"{SERVICE_NAME}_state_batch", [{}] * len(dialog_batch)) + dff_shared_state_batch = requested_data.get("dff_shared_state_batch", [{}] * len(dialog_batch)) + entities_batch = requested_data.get("entities_batch", [{}] * len(dialog_batch)) + used_links_batch = requested_data.get("used_links_batch", [{}] * len(dialog_batch)) + age_group_batch = requested_data.get("age_group_batch", [""] * len(dialog_batch)) + disliked_skills_batch = requested_data.get("disliked_skills_batch", [{}] * len(dialog_batch)) + clarification_request_flag_batch = requested_data.get( + "clarification_request_flag_batch", + [False] * len(dialog_batch), + ) + ctxs 
= [] + for ( + human_utter_index, + dialog, + state, + dff_shared_state, + entities, + used_links, + age_group, + disliked_skills, + clarification_request_flag, + ) in zip( + human_utter_index_batch, + dialog_batch, + state_batch, + dff_shared_state_batch, + entities_batch, + used_links_batch, + age_group_batch, + disliked_skills_batch, + clarification_request_flag_batch, + ): + ctx = Context.cast(state.get("context", {})) + agent = Agent( + human_utter_index=human_utter_index, + dialog=dialog, + state=state, + dff_shared_state=dff_shared_state, + entities=entities, + used_links=used_links, + age_group=age_group, + disliked_skills=disliked_skills, + clarification_request_flag=clarification_request_flag, + ) + ctx.misc["agent"] = agent.dict() + ctxs += [ctx] + return ctxs + + +def get_response(ctx: Context, _): + agent = ctx.misc["agent"] + response_parts = agent.get("response_parts", []) + confidence = agent["response"].get("confidence", 0.85) + state = State(context=ctx, **agent).dict(exclude_none=True) + human_attr = HumanAttr.parse_obj(agent).dict() | {f"{SERVICE_NAME}_state": state} + hype_attr = HypeAttr.parse_obj(agent).dict() | ({"response_parts": response_parts} if response_parts else {}) + response = ctx.last_response + if isinstance(response, MultiMessage): + responses = [] + message: dict + for message in response.messages: + reply = message.text or "" + conf = message.confidence or confidence + h_a = human_attr | (message.human_attr or {}) + attr = hype_attr | (message.hype_attr or {}) + b_a = message.bot_attr or {} + responses += [(reply, conf, h_a, b_a, attr)] + return list(zip(*responses)) + else: + return (response.text, confidence, human_attr, {}, hype_attr) + + +def run_dff(ctx: Context, pipeline: Pipeline): + last_request = get_last_human_utterance(ctx, pipeline.actor)["text"] + pipeline.context_storage[ctx.id] = ctx + ctx = pipeline(DreamMessage(text=last_request), ctx.id) + response = get_response(ctx, pipeline.actor) + del pipeline.context_storage[ctx.id] + return response diff --git a/common/dff_api_v1/requirements.txt b/common/dff_api_v1/requirements.txt new file mode 100644 index 0000000000..36c99b7133 --- /dev/null +++ b/common/dff_api_v1/requirements.txt @@ -0,0 +1,13 @@ +# base +six==1.16.0 +sentry_sdk[flask]==0.14.3 +flask==1.1.1 +itsdangerous==2.0.1 +gunicorn==19.9.0 +healthcheck==1.3.3 +# bound to a specific commit temporarily +dff==0.4.1 +# test +requests==2.22.0 +jinja2<=3.0.3 +Werkzeug<=2.0.3 diff --git a/skills/dff_science_skill/tests/link_to_in.json b/skills/dff_science_skill/tests/link_to_in.json index fc68e4ddda..e37a6f880e 100644 --- a/skills/dff_science_skill/tests/link_to_in.json +++ b/skills/dff_science_skill/tests/link_to_in.json @@ -550,6 +550,9 @@ "disliked_skills_batch": [ [] ], + "prompts_goals_batch": [ + {} + ], "clarification_request_flag_batch": [ false ], diff --git a/skills/dff_template_skill/Dockerfile b/skills/dff_template_skill/Dockerfile index 998fc6c9f8..10b259dc38 100644 --- a/skills/dff_template_skill/Dockerfile +++ b/skills/dff_template_skill/Dockerfile @@ -3,7 +3,7 @@ FROM python:3.9.1 # Do not change anything in this section WORKDIR /src -COPY common/dff/requirements.txt . +COPY common/dff_api_v1/requirements.txt . 
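+# Installs the api_v1 integration pins listed above
+# (common/dff_api_v1/requirements.txt, including dff==0.4.1).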
RUN pip install -r requirements.txt # ###################### CUSTOM SECTION ###################################### diff --git a/skills/dff_template_skill/scenario/condition.py b/skills/dff_template_skill/scenario/condition.py index 7ef3b07924..2c266b6abd 100644 --- a/skills/dff_template_skill/scenario/condition.py +++ b/skills/dff_template_skill/scenario/condition.py @@ -1,15 +1,16 @@ import logging -from df_engine.core import Context, Actor +from dff.script import Context +from dff.pipeline import Pipeline -from common.dff.integration import condition as int_cnd +from common.dff_api_v1.integration import condition as int_cnd logger = logging.getLogger(__name__) # .... def example_lets_talk_about(): - def example_lets_talk_about_handler(ctx: Context, actor: Actor, *args, **kwargs) -> str: - return int_cnd.is_lets_chat_about_topic_human_initiative(ctx, actor) + def example_lets_talk_about_handler(ctx: Context, pipeline: Pipeline) -> str: + return int_cnd.is_lets_chat_about_topic_human_initiative(ctx, pipeline) return example_lets_talk_about_handler diff --git a/skills/dff_template_skill/scenario/main.py b/skills/dff_template_skill/scenario/main.py index 8be7bc6ec7..c831ad4b22 100644 --- a/skills/dff_template_skill/scenario/main.py +++ b/skills/dff_template_skill/scenario/main.py @@ -1,14 +1,24 @@ import logging import re -from df_engine.core.keywords import LOCAL, PROCESSING, TRANSITIONS, RESPONSE, GLOBAL -from df_engine.core import Actor -import df_engine.conditions as cnd -import df_engine.labels as lbl -import common.dff.integration.condition as int_cnd -import common.dff.integration.processing as int_prs -import common.dff.integration.response as int_rsp +import dff.script.conditions as cnd +import dff.script.labels as lbl +from dff.script import ( + MultiMessage, + LOCAL, + PRE_TRANSITIONS_PROCESSING, + PRE_RESPONSE_PROCESSING, + TRANSITIONS, + RESPONSE, + GLOBAL, +) +from dff.pipeline import Pipeline + +import common.dff_api_v1.integration.condition as int_cnd +import common.dff_api_v1.integration.processing as int_prs +import common.dff_api_v1.integration.response as int_rsp +from common.dff_api_v1.integration.message import DreamMessage import common.constants as common_constants @@ -19,36 +29,38 @@ logger = logging.getLogger(__name__) # First of all, to create a dialog agent, we need to create a dialog script. -# Below, `flows` is the dialog script. -# A dialog script is a flow dictionary that can contain multiple flows . -# Flows are needed in order to divide a dialog into sub-dialogs and process them separately. +# Below, we lay out our script and assign it to the `script` variable. +# A dialog script is a dictionary each item of which corresponds to a different namespace, aka flow. +# Flows allow you to divide a dialog into sub-dialogs and process them separately. # For example, the separation can be tied to the topic of the dialog. -# In our example, there is one flow called greeting_flow. - -# Inside each flow, we can describe a sub-dialog using keyword `GRAPH` from df_engine.core.keywords module. -# Here we can also use keyword `GLOBAL_TRANSITIONS`, which we have considered in other examples. +# In our example, there is one flow called greeting. -# `GRAPH` describes a sub-dialog using linked nodes, each node has the keywords `RESPONSE` and `TRANSITIONS`. +# A flow describes a sub-dialog using linked nodes, each node has the keywords `RESPONSE` and `TRANSITIONS`. -# `RESPONSE` - contains the response that the dialog agent will return when transitioning to this node. 
+# `RESPONSE` - contains the response that the dialog agent will return when visiting this node. # `TRANSITIONS` - describes transitions from the current node to other nodes. # `TRANSITIONS` are described in pairs: -# - the node to which the agent will perform the transition +# - the node to visit # - the condition under which to make the transition -flows = { +script = { GLOBAL: { TRANSITIONS: { ("greeting", "node1"): cnd.regexp(r"\bhi\b"), }, }, - "sevice": { + "service": { "start": { - RESPONSE: "", + RESPONSE: DreamMessage(text=""), + # We simulate extraction of two slots at the dialog start. + # Slot values are then used in the dialog. + PRE_TRANSITIONS_PROCESSING: { + "save_slots_to_ctx": int_prs.save_slots_to_ctx({"topic": "science", "user_name": "Gordon Freeman"}) + }, TRANSITIONS: {("greeting", "node1"): cnd.true()}, }, "fallback": { - RESPONSE: "Ooops", + RESPONSE: DreamMessage(text="Ooops"), TRANSITIONS: { lbl.previous(): cnd.regexp(r"previous", re.IGNORECASE), lbl.repeat(0.2): cnd.true(), @@ -57,62 +69,71 @@ }, "greeting": { LOCAL: { - PROCESSING: { + PRE_RESPONSE_PROCESSING: { "set_confidence": int_prs.set_confidence(1.0), "set_can_continue": int_prs.set_can_continue(), - # "fill_responses_by_slots": int_prs.fill_responses_by_slots(), }, }, "node1": { - RESPONSE: int_rsp.multi_response(replies=["Hi, how are you?", "Hi, what's up?"]), # several hypothesis - PROCESSING: { - "save_slots_to_ctx": int_prs.save_slots_to_ctx({"topic": "science", "user_name": "Gordon Freeman"}) - }, + RESPONSE: MultiMessage( + messages=[DreamMessage(text="Hi, how are you?"), DreamMessage(text="Hi, what's up?")] + ), # several hypotheses TRANSITIONS: {"node2": cnd.regexp(r"how are you", re.IGNORECASE)}, }, "node2": { - RESPONSE: loc_rsp.example_response("Good. What do you want to talk about?"), - # loc_rsp.example_response is just for an example, you can use just str without example_response func + # The response section can contain any function that returns a DreamMessage. + RESPONSE: loc_rsp.example_response(DreamMessage(text="Good. What do you want to talk about?")), TRANSITIONS: {"node3": loc_cnd.example_lets_talk_about()}, }, "node3": { - RESPONSE: "Sorry, I can not talk about that now. Maybe late. Do you like {topic}?", - PROCESSING: {"fill_responses_by_slots": int_prs.fill_responses_by_slots()}, + RESPONSE: int_rsp.fill_by_slots( + DreamMessage(text="Sorry, I can not talk about that now. Maybe later. Do you like {topic}?") + ), TRANSITIONS: { "node4": int_cnd.is_yes_vars, "node5": int_cnd.is_no_vars, "node6": int_cnd.is_do_not_know_vars, - "node7": cnd.true(), # it will be chosen if other conditions are False + "node7": cnd.true(), # this option will be chosen if no other condition is met. }, }, "node4": { - RESPONSE: "I like {topic} too, {user_name}", - PROCESSING: {"fill_responses_by_slots": int_prs.fill_responses_by_slots()}, + # Invoke a special function to insert slots to the response. 
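+            # With the slots saved at the start node ({"topic": "science",
+            # "user_name": "Gordon Freeman"}), this template renders as
+            # "I like science too, Gordon Freeman".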
+ RESPONSE: int_rsp.fill_by_slots(DreamMessage(text="I like {topic} too, {user_name}")), TRANSITIONS: {("node7", 0.1): cnd.true()}, }, "node5": { - RESPONSE: "I do not like {topic} too, {user_name}", - PROCESSING: {"fill_responses_by_slots": int_prs.fill_responses_by_slots()}, + RESPONSE: int_rsp.fill_by_slots(DreamMessage(text="I do not like {topic} too, {user_name}")), TRANSITIONS: {("node7", 0.1): cnd.true()}, }, "node6": { - RESPONSE: "I have no opinion about {topic} too, {user_name}", - PROCESSING: {"fill_responses_by_slots": int_prs.fill_responses_by_slots()}, + RESPONSE: int_rsp.fill_by_slots(DreamMessage(text="I have no opinion about {topic} too, {user_name}")), TRANSITIONS: {("node7", 0.1): cnd.true()}, }, "node7": { - RESPONSE: int_rsp.multi_response( - replies=["bye", "goodbye"], - confidences=[1.0, 0.5], - hype_attr=[ - {"can_continue": common_constants.MUST_CONTINUE}, # for the first hyp - {"can_continue": common_constants.CAN_CONTINUE_SCENARIO}, # for the second hyp - ], + RESPONSE: MultiMessage( + # DreamMessage attributes, like confidence and can_continue status + # can be used to override the parameters extracted by Dream engine. + messages=[ + DreamMessage( + text="bye", confidence=1.0, hype_attr={"can_continue": common_constants.MUST_CONTINUE} + ), + DreamMessage( + text="goodbye", + confidence=0.5, + hype_attr={"can_continue": common_constants.CAN_CONTINUE_SCENARIO}, + ), + ] ), - PROCESSING: {"set_confidence": int_prs.set_confidence(0.0)}, + PRE_RESPONSE_PROCESSING: {"set_confidence": int_prs.set_confidence(0.0)}, }, }, } -actor = Actor(flows, start_label=("sevice", "start"), fallback_label=("sevice", "fallback")) +db = dict() +pipeline = Pipeline.from_script( + script=script, + start_label=("service", "start"), + fallback_label=("service", "fallback"), + context_storage=db, +) diff --git a/skills/dff_template_skill/scenario/response.py b/skills/dff_template_skill/scenario/response.py index 50a849c902..7c2551a5e8 100644 --- a/skills/dff_template_skill/scenario/response.py +++ b/skills/dff_template_skill/scenario/response.py @@ -1,6 +1,6 @@ import logging -from df_engine.core import Context, Actor +from dff.script import Context logger = logging.getLogger(__name__) @@ -8,7 +8,7 @@ def example_response(reply: str): - def example_response_handler(ctx: Context, actor: Actor, *args, **kwargs) -> str: + def example_response_handler(ctx: Context, _) -> str: return reply return example_response_handler diff --git a/skills/dff_template_skill/server.py b/skills/dff_template_skill/server.py index 72e91e68c5..8c7fbddad7 100644 --- a/skills/dff_template_skill/server.py +++ b/skills/dff_template_skill/server.py @@ -4,16 +4,14 @@ import time import os import random +from common.dff_api_v1.integration.serializer import load_ctxs, run_dff from flask import Flask, request, jsonify from healthcheck import HealthCheck import sentry_sdk from sentry_sdk.integrations.logging import ignore_logger - -from common.dff.integration.actor import load_ctxs, get_response - -from scenario.main import actor +from scenario.main import pipeline # import test_server @@ -45,8 +43,7 @@ def handler(requested_data, random_seed=None): # for tests if random_seed: random.seed(int(random_seed)) - ctx = actor(ctx) - responses.append(get_response(ctx, actor)) + responses.append(run_dff(ctx, pipeline)) except Exception as exc: sentry_sdk.capture_exception(exc) logger.exception(exc) diff --git a/skills/dff_template_skill/tests/lets_talk_out.json b/skills/dff_template_skill/tests/lets_talk_out.json index 
229271fb1e..83ff55d7d0 100644 --- a/skills/dff_template_skill/tests/lets_talk_out.json +++ b/skills/dff_template_skill/tests/lets_talk_out.json @@ -29,25 +29,21 @@ ] }, "requests": { - "0": "hi." + "0": { + "text": "hi." + } }, "responses": { - "0": [ - [ - "Hi, how are you?", - 0.0, - {}, - {}, - {} - ], - [ - "Hi, what's up?", - 0.0, - {}, - {}, - {} + "0": { + "messages": [ + { + "text": "Hi, how are you?" + }, + { + "text": "Hi, what's up?" + } ] - ] + } }, "misc": { "slots": { @@ -56,7 +52,7 @@ } }, "validation": false, - "actor_state": {} + "framework_states": {} } }, "dff_shared_state": { @@ -88,25 +84,21 @@ ] }, "requests": { - "0": "hi." + "0": { + "text": "hi." + } }, "responses": { - "0": [ - [ - "Hi, how are you?", - 0.0, - {}, - {}, - {} - ], - [ - "Hi, what's up?", - 0.0, - {}, - {}, - {} + "0": { + "messages": [ + { + "text": "Hi, how are you?" + }, + { + "text": "Hi, what's up?" + } ] - ] + } }, "misc": { "slots": { @@ -115,7 +107,7 @@ } }, "validation": false, - "actor_state": {} + "framework_states": {} } }, "dff_shared_state": {
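
With the serializer in place, a skill's request handler reduces to a few lines. A minimal sketch of the new flow, using the names introduced in this PR (server.py above shows the full version with error handling and the random seed used by tests):

    from common.dff_api_v1.integration.serializer import load_ctxs, run_dff
    from scenario.main import pipeline

    def handler(requested_data):
        responses = []
        for ctx in load_ctxs(requested_data):
            # run_dff feeds the last human utterance through the pipeline and returns
            # (text, confidence, human_attr, bot_attr, hype_attr) hypothesis tuples
            responses.append(run_dff(ctx, pipeline))
        return responses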