Fix/refactor transformers lm #317

Merged · 19 commits · Feb 9, 2023
44 changes: 32 additions & 12 deletions assistant_dists/dream_persona_prompted/README.md
@@ -1,20 +1,22 @@
# Dream Prompted Distribution

-**_One may consider this distribution as a TEMPLATE for a prompt-based distribution which may contain any number of prompt-based skills each of which is conditioned on a single prompt during the whole conversation_**
+**_One may consider this distribution as a TEMPLATE for a prompt-based distribution which may contain any number of
+prompt-based skills, each of which is conditioned on a single prompt during the whole conversation_**

-Each Prompt-based Skill utilizes the **same prompt during the whole dialog**!
+**Note!** Each Prompt-based Skill utilizes the **same prompt during the whole dialog**!

-# What is Dream Prompted distribution
+# What is Dream Prompted Distribution

-Dream Prompted distribution is an example of the prompt-based dialogue system which contains one prompt-based skill, in particular, prompt is a persona description.
+Dream Prompted distribution is an example of a prompt-based dialogue system which contains one prompt-based skill;
+in particular, the prompt is a persona description.

Dream Prompted distribution contains the following skills:
* Dummy Skill (`dummy_skill`) is a fallback skill (it is also part of the agent container, so no separate container is required)
* DFF Dream Persona Prompted Skill (`dff_dream_persona_prompted_skill`) is a skill created via DFF (Dialog Flow Framework)
-which generates a response to the current dialogue context taking into account the given prompt
-(the **prompt is the same for all the dialogue steps**).
+which generates a response to the current dialogue context taking into account the given prompt, e.g., the bot's persona description.

### DFF Dream Persona Prompted Skill

The **DFF Dream Persona Prompted Skill** is a light-weight container sending requests to the generative service
which utilizes a neural network for prompt-based generation.
DFF Dream Persona Prompted Skill accepts two main environmental variables:
@@ -23,26 +25,44 @@ DFF Dream Persona Prompted Skill accepts two main environmental variables:
The service must utilize the same input-output format as Transformers-LM (`transformers_lm`).
* `N_UTTERANCES_CONTEXT` contains the length of the considered context in terms of the number of dialogue utterances.

+**Note!** DFF Dream Persona Prompted Skill utilizes a special universal template `skills/dff_template_prompted_skill`
+which does not require creating a new skill directory. For your convenience, when creating a new skill,
+you should utilize the same template folder but specify another prompt file, service port, and container name.

### Prompt Selector
-The distribution may contain several Prompt-based skills. Therefore, the **Prompt Selector** component is presented.
+
+The distribution may contain **several Prompt-based skills.** Therefore, the **Prompt Selector** component is presented.
The Prompt Selector is also a light-weight container utilizing the **Sentence Ranker** component
(its URL is given in the `.env` file as `SENTENCE_RANKER_SERVICE_URL`) to select the `N_SENTENCES_TO_RETURN`
most relevant prompts (precisely, it returns an ordered list of prompt names) among the given ones.
The comma-separated list of the prompt names to be considered is given as an environmental variable `PROMPTS_TO_CONSIDER`.
Each considered prompt should be located as `dream/common/prompts/<prompt_name>.json`.

+**Note!** In the Dream Persona Prompted Distribution we give a comma-separated list of prompts to the Prompt Selector
+(`dream_persona,pizza`) just to demonstrate the input format of `PROMPTS_TO_CONSIDER`. Actually,
+Dream Persona Prompted Distribution contains only one prompted skill, which utilizes the Dream Persona prompt.

### Skill Selector

-**Important!** If Prompt Selector annotations are detected in the user utterance,
+You do not need to make any changes to the Skill Selector; it automatically calls the skills with the most relevant
+prompts according to the Prompt Selector. If Prompt Selector annotations are detected in the user utterance,
the Skill Selector turns on skills with names `dff_<prompt_name>_prompted_skill` for each prompt_name from
the `N_SENTENCES_TO_RETURN` most relevant prompts detected by Prompt Selector.
Therefore, a prompt name can contain `'_'` but not `'-'`.

+**Note!** Pay attention that you may specify prompt names for the Prompt Selector even if the corresponding skills
+are not presented in the distribution. For example, if you specify 5 prompt names while your distribution contains
+only 2 prompted skills, and you set the number of returned most relevant prompts (`N_SENTENCES_TO_RETURN`) to 3,
+you may face a situation where the Prompt Selector chooses only prompts for which you have no skills; the response
+on that step will then be provided by the other skills presented in the distribution (in particular, by Dummy Skill
+for the current version of Dream Prompted distribution).

# How to Create a New Prompted Distribution

If one wants to create a new prompted distribution (distribution containing prompt-based skill(s)), one should:

1. Copy the `dream/assistant_dists/dream_persona_prompted` directory to `dream/assistant_dists/dream_custom_prompted`
-(this name is an example!).
+(the name is an example!).
2. **For each prompt-based skill, one needs to**:
1. create a `dream/common/prompts/<prompt_name>.json` file containing a prompt.
**Important!** `<prompt_name>` should only contain letters, numbers and underscores (`_`) but no dashes (`-`)!
@@ -97,7 +117,7 @@ If one wants to create a new prompted distribution (distribution containing prompt-based skill(s)), one should:
6. If one does not want to keep DFF Dream Persona Prompted Skill in their distribution, one should remove all mentions
of DFF Dream Persona Prompted Skill container from `yml`-configs and `pipeline_conf.json` files.

-**Important!** Please, take into account that naming the skill utilizing `<prompt_name>` according to the instruction above
+**Note!** Please, take into account that naming the skill utilizing `<prompt_name>` according to the instruction above
is very important: it lets the Skill Selector automatically turn on the prompt-based skills which are returned among
the `N_SENTENCES_TO_RETURN` most relevant prompts.

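To make the `PROMPTS_TO_CONSIDER` contract described above concrete, here is a minimal sketch of how a comma-separated prompt list resolves to prompt files under the documented convention. This is an illustration, not the actual Prompt Selector code; the environment value is the example from the README.

```python
import os

# Illustration only: PROMPTS_TO_CONSIDER is a comma-separated list of prompt
# names; each name is expected to map to dream/common/prompts/<prompt_name>.json.
os.environ.setdefault("PROMPTS_TO_CONSIDER", "dream_persona,pizza")

prompt_names = os.environ["PROMPTS_TO_CONSIDER"].split(",")
prompt_paths = [f"dream/common/prompts/{name}.json" for name in prompt_names]
print(prompt_paths)
# ['dream/common/prompts/dream_persona.json', 'dream/common/prompts/pizza.json']
```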
2 changes: 1 addition & 1 deletion assistant_dists/dream_russian/docker-compose.override.yml
@@ -8,7 +8,7 @@ services:
spelling-preprocessing:8074, entity-linking:8075, wiki-parser:8077, dff-generative-skill:8092,
dff-friendship-skill:8086, entity-detection:8103, dialogpt:8091,
dff-template-skill:8120, spacy-annotator:8125, dialogrpt:8122, toxic-classification:8126"
-WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-480}
+WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-600}
HIGH_PRIORITY_INTENTS: 1
RESTRICTION_FOR_SENSITIVE_CASE: 1
ALWAYS_TURN_ON_ALL_SKILLS: 0
5 changes: 2 additions & 3 deletions common/containers.py
@@ -1,10 +1,9 @@
 import requests
 
 
-def is_container_running(model_url, timeout=4):
+def is_container_running(model_url, json_data, timeout=4):
     try:
-        requested_data = [{"speaker": "human", "text": "hi"}]
-        response = requests.post(model_url, json={"dialog_contexts": [requested_data]}, timeout=timeout)
+        response = requests.post(model_url, json=json_data, timeout=timeout)
         if response.status_code == 200:
             return True
     except Exception as exc:
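A hedged usage sketch for the refactored helper above: callers now pass the service-specific health-check payload themselves instead of relying on the hard-coded DialoGPT-style request. The URL and payload below are assumptions for illustration, mirroring the new `transformers_lm` request format.

```python
from common.containers import is_container_running

# Assumed values for illustration; not taken from this PR.
SERVICE_URL = "http://0.0.0.0:8130/respond"
payload = {"dialog_contexts": [["hi"]], "prompts": ["Respond like a friendly chatbot."]}

if is_container_running(SERVICE_URL, payload, timeout=4):
    print("generative service is up")
```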
2 changes: 1 addition & 1 deletion services/dialogpt_RU/server.py
@@ -199,4 +199,4 @@ def respond():
     total_time = time.time() - st_time
     logger.info(f"dialogpt exec time: {total_time:.3f}s")
 
-    return jsonify({"generated_responses": batch_generated_responses})
+    return jsonify(batch_generated_responses)
2 changes: 1 addition & 1 deletion services/dialogpt_RU/test.py
@@ -13,7 +13,7 @@ def test_respond():
     ]
 
     request_data = {"dialog_contexts": dialog_contexts, "num_return_sequences": 5}
-    result = requests.post(url, json=request_data).json()["generated_responses"][0]
+    result = requests.post(url, json=request_data).json()[0]
 
     assert len(result) == 5 and len(result[0]) > 0, f"Got\n{result}"
     print("Success!")
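The two DialoGPT diffs above change the response envelope: the service now returns the batch of hypothesis lists directly instead of wrapping it in `generated_responses`. A hedged client-side sketch follows; the port matches the `dialogpt:8091` entry in the compose file above, and the context is a made-up example.

```python
import requests

url = "http://0.0.0.0:8091/respond"  # assumed host; port per docker-compose above
request_data = {
    "dialog_contexts": [[{"speaker": "human", "text": "hi"}]],
    "num_return_sequences": 5,
}

# Before: requests.post(url, json=request_data).json()["generated_responses"][0]
# After: the JSON body is the batch itself, one list of hypotheses per context.
hypotheses = requests.post(url, json=request_data).json()[0]
print(hypotheses)
```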
2 changes: 1 addition & 1 deletion services/infilling/requirements.txt
@@ -1,4 +1,4 @@
-transformers==4.0.1
+transformers==4.6.0
 sentencepiece==0.1.94
 flask==1.1.1
 itsdangerous==2.0.1
2 changes: 1 addition & 1 deletion services/masked_lm/requirements.txt
@@ -1,4 +1,4 @@
-transformers==4.0.1
+transformers==4.6.0
 sentencepiece==0.1.94
 flask==1.1.1
 itsdangerous==2.0.1
2 changes: 1 addition & 1 deletion services/question_generator/requirements.txt
@@ -1,4 +1,4 @@
-transformers==4.0.1
+transformers==4.6.0
 sentencepiece==0.1.94
 flask==1.1.1
 itsdangerous==2.0.1
2 changes: 1 addition & 1 deletion services/sentence_ranker/requirements.txt
@@ -1,4 +1,4 @@
-transformers==4.0.1
+transformers==4.6.0
 sentencepiece==0.1.94
 flask==1.1.1
 gunicorn==19.9.0
2 changes: 1 addition & 1 deletion services/transformers_lm/gpt_j_6b.json
@@ -4,5 +4,5 @@
   "top_p": 0.9,
   "temperature": 0.9,
   "do_sample": true,
-  "num_return_sequences": 1
+  "num_return_sequences": 3
 }
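For context, this config is the keyword-argument dict that the server unpacks into Hugging Face's `generate()`; with `do_sample` enabled, raising `num_return_sequences` from 1 to 3 yields three sampled hypotheses per context. A hedged sketch, where the GPT-J checkpoint name and the prompt are assumptions:

```python
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint for a config named gpt_j_6b.json; the server actually loads
# whatever PRETRAINED_MODEL_NAME_OR_PATH points to.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

with open("services/transformers_lm/gpt_j_6b.json") as f:
    generation_params = json.load(f)  # top_p, temperature, do_sample, num_return_sequences, ...

input_ids = tokenizer(["Human: hi\nAI:"], return_tensors="pt").input_ids
with torch.no_grad():
    # num_return_sequences=3 produces three sampled continuations per input
    outputs = model.generate(input_ids, **generation_params)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```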
52 changes: 29 additions & 23 deletions services/transformers_lm/server.py
@@ -19,8 +19,8 @@
 CONFIG_NAME = os.environ.get("CONFIG_NAME")
+HALF_PRECISION = bool(os.environ.get("HALF_PRECISION", 0))
 logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
-DEFAULT_CONFIDENCE = 0.9
-ZERO_CONFIDENCE = 0.0
+NAMING = ["AI", "Human"]
 
 with open(CONFIG_NAME, "r") as f:
     generation_params = json.load(f)
     max_length = generation_params.get("max_length", 50)
@@ -30,10 +30,19 @@
 logging.getLogger("werkzeug").setLevel("WARNING")
 
 
-def generate_responses(instruction, context, model, tokenizer, continue_last_uttr=False):
+def generate_responses(context, model, tokenizer, prompt, continue_last_uttr=False):
     outputs = []
-    dialog_context = instruction + "\n" + "\n".join(context) + "\n" + "AI:"
-    logger.info(f"context inside generate_responses seen as: {[dialog_context]}")
+    dialog_context = ""
+    if prompt:
+        dialog_context += prompt + "\n"
+    s = len(context) % 2
+    context = [f"{NAMING[(s + uttr_id) % 2]}: {uttr}" for uttr_id, uttr in enumerate(context)]
+    if continue_last_uttr:
+        dialog_context += "\n".join(context)
+    else:
+        dialog_context += "\n".join(context) + f"\n{NAMING[0]}:"
+
+    logger.info(f"context inside generate_responses seen as: {dialog_context}")
     bot_input_ids = tokenizer([dialog_context], return_tensors="pt").input_ids
     with torch.no_grad():
         if torch.cuda.is_available():
@@ -48,8 +57,8 @@ def generate_responses(instruction, context, model, tokenizer, continue_last_uttr=False):
         chat_history_ids = chat_history_ids.cpu()
     for result in chat_history_ids:
         output = tokenizer.decode(result, skip_special_tokens=True)
+        logger.info(f"full output: {[output]}")
         result_cut = output.replace(dialog_context + " ", "").split("\n")[0]
-        logger.info(f"hypothesis: {result_cut}")
         outputs.append(result_cut)
     return outputs

@@ -64,11 +73,7 @@ def generate_responses(instruction, context, model, tokenizer, continue_last_uttr=False):
     model.to("cuda")
     logger.info("transformers_lm is set to run on cuda")
 example_response = generate_responses(
-    "",
-    ["Question: What is the goal of SpaceX? Answer: To revolutionize space transportation. "],
-    model,
-    tokenizer,
-    continue_last_uttr=False,
+    ["What is the goal of SpaceX?"], model, tokenizer, "You are a SpaceX Assistant."
 )
 logger.info(f"example response: {example_response}")
 logger.info("transformers_lm is ready")
@@ -82,27 +87,28 @@
 def respond():
     st_time = time.time()
     contexts = request.json.get("dialog_contexts", [])
+    prompts = request.json.get("prompts", [])
+    if len(contexts) > 0 and len(prompts) == 0:
+        prompts = [""] * len(contexts)
 
     try:
         responses = []
-        confidences = []
-        for context in contexts:
-            outputs = generate_responses("", context, model, tokenizer)
-            logger.info(f"outputs: {outputs}")
+        for context, prompt in zip(contexts, prompts):
+            curr_responses = []
+            outputs = generate_responses(context, model, tokenizer, prompt)
             for response in outputs:
-                if len(response) >= 3:
-                    # drop too short responses
-                    responses += [response]
-                    confidences += [DEFAULT_CONFIDENCE]
+                if len(response) >= 2:
+                    curr_responses += [response]
                 else:
-                    responses += [""]
-                    confidences += [ZERO_CONFIDENCE]
+                    curr_responses += [""]
+            responses += [curr_responses]
 
     except Exception as exc:
         logger.exception(exc)
         sentry_sdk.capture_exception(exc)
         responses = [[""]] * len(contexts)
-        confidences = [[ZERO_CONFIDENCE]] * len(contexts)
 
     logger.info(f"transformers_lm output: {responses}")
     total_time = time.time() - st_time
     logger.info(f"transformers_lm exec time: {total_time:.3f}s")
-    return jsonify(list(zip(responses, confidences)))
+    return jsonify(responses)
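A small sketch isolating the new speaker-labeling logic in `generate_responses` above: the `len(context) % 2` offset guarantees the last utterance in the context is labeled `Human`, so the trailing `AI:` cues the model to answer as the bot. This is a standalone illustration of the same arithmetic, not server code.

```python
NAMING = ["AI", "Human"]


def format_context(context):
    # Offset by context-length parity so the final utterance is always "Human".
    s = len(context) % 2
    return [f"{NAMING[(s + i) % 2]}: {uttr}" for i, uttr in enumerate(context)]


print(format_context(["hi"]))
# ['Human: hi']
print(format_context(["hi", "hello! how are you?", "great, you?"]))
# ['Human: hi', 'AI: hello! how are you?', 'Human: great, you?']
```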
18 changes: 10 additions & 8 deletions services/transformers_lm/test.py
@@ -1,19 +1,21 @@
-import os
 import requests
 
 
-N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE", 1))
-
-
 def test_respond():
     url = "http://0.0.0.0:8130/respond"
     contexts = [
         [
-            "Respond like a friendly chatbot",
-            "Human: Hi! I am Marcus. How are you today?",
-        ]
+            "Hi! I am Marcus. How are you today?",
+            "Hi Marcus! I am fine. How are you?",
+            "I am great. What are your plans for today?",
+        ],
+        ["Hi Marcus! I am fine. How are you?", "I am great. What are your plans for today?"],
     ]
+    prompts = [
+        "Respond like a friendly chatbot.",
+        "Respond like a friendly chatbot.",
+    ]
-    result = requests.post(url, json={"dialog_contexts": contexts}).json()
+    result = requests.post(url, json={"dialog_contexts": contexts, "prompts": prompts}).json()
     print(result)
     assert all(len(sample[0]) > 0 for sample in result), f"Got\n{result}\n, something is wrong"
     print("Success")
40 changes: 30 additions & 10 deletions skills/dff_generative_skill/scenario/response.py
@@ -1,4 +1,5 @@
 import logging
+import re
 import requests
 import sentry_sdk
 from os import getenv
@@ -17,7 +18,13 @@
 assert DIALOGPT_SERVICE_URL
 
 
-def compose_data_for_dialogpt(ctx, actor):
+FIX_PUNCTUATION = re.compile(r"\s(?=[\.,:;])")
+GENERATIVE_TIMEOUT = 4
+DEFAULT_CONFIDENCE = 0.9
+LOW_CONFIDENCE = 0.5
+
+
+def compose_data_for_model(ctx, actor):
     data = []
     # for uttr in dialog["utterances"][-3:]:
     #     curr_uttr = {"speaker": uttr["user"]["user_type"], "text": uttr["text"]}
Expand All @@ -38,7 +45,13 @@ def compose_data_for_dialogpt(ctx, actor):


 def generative_response(ctx: Context, actor: Actor, *args, **kwargs) -> Any:
-    curr_responses, curr_confidences, curr_human_attrs, curr_bot_attrs, curr_attrs = [], [], [], [], []
+    curr_responses, curr_confidences, curr_human_attrs, curr_bot_attrs, curr_attrs = (
+        [],
+        [],
+        [],
+        [],
+        [],
+    )
 
     def gathering_responses(reply, confidence, human_attr, bot_attr, attr):
         nonlocal curr_responses, curr_confidences, curr_human_attrs, curr_bot_attrs, curr_attrs
@@ -48,19 +61,26 @@ def gathering_responses(reply, confidence, human_attr, bot_attr, attr):
         curr_human_attrs += [human_attr]
         curr_bot_attrs += [bot_attr]
         curr_attrs += [attr]
+        logger.info(f"dff-generative-skill: {reply}")
 
-    request_data = compose_data_for_dialogpt(ctx, actor)
+    request_data = compose_data_for_model(ctx, actor)
     logger.info(f"request_data: {request_data}")
     if len(request_data) > 0:
-        response = requests.post(DIALOGPT_SERVICE_URL, json={"dialog_contexts": [request_data]}, timeout=3.8)
-        hypotheses = response.json()["generated_responses"][0]
+        response = requests.post(
+            DIALOGPT_SERVICE_URL,
+            json={"dialog_contexts": [request_data]},
+            timeout=3.8,
+        )
+        hypotheses = response.json()[0]
     else:
         hypotheses = []
 
     logger.info(f"hyps: {hypotheses}")
     for hyp in hypotheses:
-        if hyp[-1] not in [".", "?", "!"]:
-            hyp += "."
-        gathering_responses(hyp, 0.99, {}, {}, {"can_continue": CAN_NOT_CONTINUE})
+        confidence = DEFAULT_CONFIDENCE
+        hyp_text = " ".join(hyp.split())
+        if len(hyp_text) and hyp_text[-1] not in [".", "?", "!"]:
+            hyp_text += "."
+            confidence = LOW_CONFIDENCE
+        gathering_responses(hyp_text, confidence, {}, {}, {"can_continue": CAN_NOT_CONTINUE})
 
     if len(curr_responses) == 0:
         return ""
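`FIX_PUNCTUATION` is defined in the hunk above but not exercised in the visible lines; presumably it is meant to strip the stray spaces generative models sometimes emit before punctuation. A hedged demonstration of what the pattern matches:

```python
import re

FIX_PUNCTUATION = re.compile(r"\s(?=[\.,:;])")

# The lookahead matches whitespace directly before . , : or ; so substituting
# an empty string deletes the space while keeping the punctuation.
print(FIX_PUNCTUATION.sub("", "Well , I think so : yes ."))
# -> "Well, I think so: yes."
```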
4 changes: 3 additions & 1 deletion skills/dff_generative_skill/server.py
@@ -61,7 +61,9 @@ def handler(requested_data, random_seed=None):
 
 
 while True:
-    result = containers.is_container_running(DIALOGPT_SERVICE_URL)
+    result = containers.is_container_running(
+        DIALOGPT_SERVICE_URL, {"dialog_contexts": [[{"speaker": "human", "text": "hi"}]]}
+    )
     if result:
         break
     else:
4 changes: 0 additions & 4 deletions skills/dff_template_prompted_skill/scenario/main.py
@@ -15,10 +15,6 @@
     "generation": {
         "start_node": {
             RESPONSE: "",
-            TRANSITIONS: {"greeting": cnd.true()},
-        },
-        "greeting": {
-            RESPONSE: loc_rsp.generative_response,
             TRANSITIONS: {"generative_response_node": cnd.true()},
         },
         "generative_response_node": {