Feat/update dialogpt (#170)
* feat: update dialogpt

* fix: codestyle

* fix: book skill false start
dilyararimovna authored Jun 16, 2022
1 parent 726d80c commit 149be3b
Showing 6 changed files with 33 additions and 21 deletions.
3 changes: 2 additions & 1 deletion assistant_dists/dream/docker-compose.override.yml
@@ -1151,7 +1151,8 @@ services:
       args:
         SERVICE_PORT: 8125
         SERVICE_NAME: dialogpt
-        PRETRAINED_MODEL_NAME_OR_PATH: microsoft/DialoGPT-small
+        PRETRAINED_MODEL_NAME_OR_PATH: microsoft/DialoGPT-medium
+        N_HYPOTHESES_TO_GENERATE: 5
       context: ./services/dialogpt/
     command: flask run -h 0.0.0.0 -p 8125
     environment:
3 changes: 2 additions & 1 deletion assistant_dists/dream_mini/docker-compose.override.yml
@@ -136,7 +136,8 @@ services:
       args:
         SERVICE_PORT: 8125
         SERVICE_NAME: dialogpt
-        PRETRAINED_MODEL_NAME_OR_PATH: microsoft/DialoGPT-small
+        PRETRAINED_MODEL_NAME_OR_PATH: microsoft/DialoGPT-medium
+        N_HYPOTHESES_TO_GENERATE: 5
       context: ./services/dialogpt/
     command: flask run -h 0.0.0.0 -p 8125
     environment:
3 changes: 3 additions & 0 deletions services/dialogpt/Dockerfile
@@ -8,6 +8,9 @@ ARG PRETRAINED_MODEL_NAME_OR_PATH
 ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH}
 ARG SERVICE_PORT
 ENV SERVICE_PORT ${SERVICE_PORT}
+ARG N_HYPOTHESES_TO_GENERATE
+ENV N_HYPOTHESES_TO_GENERATE ${N_HYPOTHESES_TO_GENERATE}
+
 
 COPY ./requirements.txt /src/requirements.txt
 RUN pip install -r /src/requirements.txt
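The new setting travels through three layers: the args entry in each compose file becomes a build-time ARG, the Dockerfile copies it into a runtime ENV, and the Flask service reads it from the process environment. A minimal sketch of the run-time end of that chain, mirroring the read that appears in server.py below:

import os

# Reads the value baked in by the Dockerfile's ENV instruction;
# falls back to a single hypothesis if the variable is unset.
N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE", 1))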
30 changes: 19 additions & 11 deletions services/dialogpt/server.py
@@ -17,6 +17,7 @@
 PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
 logging.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
 DEFAULT_CONFIDENCE = 0.9
+N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE", 1))
 ZERO_CONFIDENCE = 0.0
 MAX_HISTORY_DEPTH = 3
 
@@ -63,20 +64,27 @@ def respond():
         responses = []
         confidences = []
         for context in contexts:
-            response = generate_response(context, model, tokenizer)
-            if len(response) > 3:
-                # drop too short responses
-                responses += [response]
-                confidences += [DEFAULT_CONFIDENCE]
-            else:
-                responses += [""]
-                confidences += [ZERO_CONFIDENCE]
+            curr_responses = []
+            curr_confidences = []
+            for i in range(N_HYPOTHESES_TO_GENERATE):
+                response = generate_response(context, model, tokenizer)
+                if len(response) > 3:
+                    # drop too short responses
+                    curr_responses += [response]
+                    curr_confidences += [DEFAULT_CONFIDENCE]
+                else:
+                    curr_responses += [""]
+                    curr_confidences += [ZERO_CONFIDENCE]
+
+            responses += [curr_responses]
+            confidences += [curr_confidences]
 
     except Exception as exc:
         logger.exception(exc)
         sentry_sdk.capture_exception(exc)
-        responses = [""] * len(contexts)
-        confidences = [ZERO_CONFIDENCE] * len(contexts)
+        responses = [[""]] * len(contexts)
+        confidences = [[ZERO_CONFIDENCE]] * len(contexts)
 
     total_time = time.time() - st_time
-    logger.info(f"masked_lm exec time: {total_time:.3f}s")
+    logger.info(f"dialogpt exec time: {total_time:.3f}s")
     return jsonify(list(zip(responses, confidences)))
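generate_response is called once per hypothesis, so decoding must be stochastic for the loop over N_HYPOTHESES_TO_GENERATE to yield distinct candidates. A minimal sketch of such a helper, assuming the standard Hugging Face transformers API; the real generate_response in server.py is not shown in this diff, and the generation parameters below are illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")


def generate_response(context, model, tokenizer, max_history_depth=3):
    # DialoGPT expects dialog turns joined by the EOS token; keep only the
    # last few utterances, as MAX_HISTORY_DEPTH does in server.py.
    history = tokenizer.eos_token.join(context[-max_history_depth:]) + tokenizer.eos_token
    input_ids = tokenizer.encode(history, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_length=input_ids.shape[-1] + 40,
            do_sample=True,  # sampling is what makes repeated calls differ
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the prompt.
    return tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)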
7 changes: 6 additions & 1 deletion services/dialogpt/test.py
@@ -1,14 +1,19 @@
+import os
 import requests
 
 
+N_HYPOTHESES_TO_GENERATE = int(os.environ.get("N_HYPOTHESES_TO_GENERATE", 1))
+
+
 def test_respond():
     url = "http://0.0.0.0:8125/respond"
 
     contexts = [["hi", "hi. how are you?"], ["let's chat about movies", "cool. what movies do you like?"]]
     gold_result = [["I'm good, how are you?", 0.9], ["I like the new one.", 0.9]]
     result = requests.post(url, json={"utterances_histories": contexts}).json()
     assert [
-        len(sample[0]) > 0 and sample[1] > 0.0 for sample in result
+        len(sample[0]) > 0 and all([len(text) > 0 for text in sample[0]]) and all([conf > 0.0 for conf in sample[1]])
+        for sample in result
     ], f"Got\n{result}\n, but expected:\n{gold_result}"
     print("Success")
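With the nested lists in place, each element of the service's JSON reply pairs a list of hypothesis texts with a parallel list of confidences, instead of one string and one float. An illustrative shape that the updated assertion accepts; the texts and scores below are made up, not real model output:

# One (hypotheses, confidences) pair per input context; values are
# illustrative only.
result = [
    [["I'm good, how are you?", "Doing great, thanks!"], [0.9, 0.9]],
    [["I like the new one.", "The remake was fun."], [0.9, 0.9]],
]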
8 changes: 1 addition & 7 deletions skills/dff_book_skill/scenario/main.py
@@ -46,13 +46,7 @@
     TRANSITIONS: {
         ("global_flow", "fallback", 1.5): loc_cnd.exit_skill,
         ("books_general", "dislikes_reading", 1.5): loc_cnd.dislikes_reading,
-        ("books_general", "book_start", 5): cnd.all(
-            [
-                loc_cnd.is_proposed_skill,
-                cnd.neg(loc_cnd.check_flag("book_skill_active")),
-                cnd.neg(loc_cnd.check_flag("book_start_visited")),
-            ]
-        ),
+        ("books_general", "book_start"): loc_cnd.start_condition,
         ("books_general", "book_restart"): cnd.all(
             [
                 loc_cnd.is_proposed_skill,
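The inline cnd.all guard on the book_start transition is replaced by loc_cnd.start_condition, whose definition is not part of this diff. A hypothetical reconstruction, assuming the previous checks were factored into the skill's local conditions module; the import path, names, and any extra false-start check are assumptions, not the commit's actual code:

from df_engine import conditions as cnd  # assumed import path

# Hypothetical start_condition in the skill's local conditions module;
# the real definition, which also carries the false-start fix, is not
# shown in this diff. is_proposed_skill and check_flag are the skill's
# existing local helpers referenced in the transition table above.
start_condition = cnd.all(
    [
        is_proposed_skill,
        cnd.neg(check_flag("book_skill_active")),
        cnd.neg(check_flag("book_start_visited")),
    ]
)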
