-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat/dialogpt ru and dff-generative-skill (#97)
* Fix requirements.txt (#84) * feat: initialize dialogpt_RU * feat: files init * feat: basic integration of dialogpt_RU * fix: rename dialogpt * fix: dialogpt to device * fix: dialogpt final version * fix: dialogpt test * fix: dialogpt test * fix: dialogpt resources consumption * fix: dialogpt to tests * feat: dff generative skill * feat: dff generative skill * fix: remove extra files * fix: input to dialogpt * fix: input to dialogpt * fix: logging * fix: turn on tests * fix: get dialog from context * fix: get uttrs from context * fix: geempty uttrs * fix: return empty resp * fix: test file * fix: tests * fix: test ratio * add speech_function_* dist * add speech_function_* dist readme * added sf_functions * fix ports * fix:m codestyle * fix deployment config * fix: tests for generative skill * fix: codestyle * add formatters, fix pipeline * update speech function * sources * fix: check if dialogpt is ready * fix: wait services * rename book skill * remove old book skill, update usages * fix readme * fix codestyle * fix codestyle * fix codestyle * fix codestyle line length * move res_cor.json to shared files * fix itsdangerous requirements * pin itsdangerous requirements for all flask==1.1.1 servers * fix cpu.yml, dockerfiles and test for sfc, sfp * fix codestyle issues * blacked with -l 120 * following Dilya's holy orders * following Dilya's not so holy orders * fix formatters * fix pipeline * fix pipeline and formatters * Adding timeouts + mapping of book skill * removed old & irrelevant tests * we've set confidence to super level * feat: midas cls sent tokenize only if needed (#101) * feat: midas cls sent tokenize only if needed * feat: take into account tokenized uttrs by bot * fix: codestyle * fix: itsdangerous reqs * fix: docker reqs * fix: check another container * fix: rights for file * fix: coestyle * fix: return tests for intent responder * fix: revert intent responder * fix: review fixes * fix: codestyle Co-authored-by: Andrii.Hura <[email protected]> Co-authored-by: mtalimanchuk <[email protected]> Co-authored-by: Daniel Kornev <[email protected]>
- Loading branch information
1 parent
3a7f50c
commit 47461ec
Showing
28 changed files
with
979 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,6 @@ services: | |
entity-detection: | ||
environment: | ||
CUDA_VISIBLE_DEVICES: "" | ||
|
||
dialogpt: | ||
environment: | ||
CUDA_VISIBLE_DEVICES: "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# syntax=docker/dockerfile:experimental | ||
|
||
FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime | ||
|
||
WORKDIR /src | ||
|
||
ARG PRETRAINED_MODEL_NAME_OR_PATH | ||
ENV PRETRAINED_MODEL_NAME_OR_PATH ${PRETRAINED_MODEL_NAME_OR_PATH} | ||
ARG SERVICE_PORT | ||
ENV SERVICE_PORT ${SERVICE_PORT} | ||
|
||
COPY ./requirements.txt /src/requirements.txt | ||
RUN pip install -r /src/requirements.txt | ||
|
||
COPY . /src | ||
|
||
HEALTHCHECK --interval=5s --timeout=90s --retries=3 CMD curl --fail 127.0.0.1:${SERVICE_PORT}/healthcheck || exit 1 | ||
|
||
|
||
CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
GPU RAM = 1Gb | ||
cpu time = 0.15 sec | ||
gpu time = 0.05 sec |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
transformers==4.0.1 | ||
sentencepiece==0.1.94 | ||
flask==1.1.1 | ||
gunicorn==19.9.0 | ||
requests==2.22.0 | ||
sentry-sdk[flask]==0.14.1 | ||
healthcheck==1.3.3 | ||
itsdangerous==2.0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
""" | ||
Source code is https://github.com/Grossmend/DialoGPT/blob/master/src/service/service.py | ||
""" | ||
import logging | ||
import time | ||
import os | ||
from typing import Dict, List | ||
|
||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
import torch | ||
from flask import Flask, request, jsonify | ||
from healthcheck import HealthCheck | ||
import sentry_sdk | ||
from sentry_sdk.integrations.flask import FlaskIntegration | ||
|
||
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()]) | ||
|
||
|
||
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get( | ||
"PRETRAINED_MODEL_NAME_OR_PATH", "Grossmend/rudialogpt3_medium_based_on_gpt2" | ||
) | ||
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}") | ||
|
||
cuda = torch.cuda.is_available() | ||
if cuda: | ||
torch.cuda.set_device(0) | ||
device = "cuda" | ||
else: | ||
device = "cpu" | ||
|
||
logger.info(f"dialogpt is set to run on {device}") | ||
|
||
params_default = { | ||
"max_length": 256, | ||
"no_repeat_ngram_size": 3, | ||
"do_sample": True, | ||
"top_k": 100, | ||
"top_p": 0.9, | ||
"temperature": 0.6, | ||
"num_return_sequences": 3, | ||
"device": device, | ||
"is_always_use_length": True, | ||
"length_generate": "1", | ||
} | ||
|
||
|
||
class RussianDialogGPT: | ||
def __init__(self, path_model: str): | ||
self.path_model = path_model | ||
self.tokenizer = None | ||
self.model = None | ||
self._load_model() | ||
|
||
def _load_model(self): | ||
logger.info(f"dialogpt Loading model: {self.path_model} ...") | ||
self.tokenizer = AutoTokenizer.from_pretrained(self.path_model) | ||
self.model = AutoModelForCausalLM.from_pretrained(self.path_model) | ||
|
||
def get_responses(self, inputs: List[Dict], params: Dict) -> List[str]: | ||
|
||
params_ = { | ||
"max_length": params.get("max_length", params_default["max_length"]), | ||
"no_repeat_ngram_size": params.get("no_repeat_ngram_size", params_default["no_repeat_ngram_size"]), | ||
"do_sample": params.get("do_sample", params_default["do_sample"]), | ||
"top_k": params.get("top_k", params_default["top_k"]), | ||
"top_p": params.get("top_p", params_default["top_p"]), | ||
"temperature": params.get("temperature", params_default["temperature"]), | ||
"num_return_sequences": params.get("num_return_sequences", params_default["num_return_sequences"]), | ||
"device": params.get("device", params_default["device"]), | ||
"is_always_use_length": params.get("is_always_use_length", params_default["is_always_use_length"]), | ||
"length_generate": params.get("length_generate", params_default["length_generate"]), | ||
} | ||
|
||
inputs_text = "" | ||
for input_ in inputs: | ||
if params_["is_always_use_length"]: | ||
length_rep = len(self.tokenizer.encode(input_["text"])) | ||
if length_rep <= 15: | ||
length_param = "1" | ||
elif length_rep <= 50: | ||
length_param = "2" | ||
elif length_rep <= 256: | ||
length_param = "3" | ||
else: | ||
length_param = "-" | ||
else: | ||
length_param = "-" | ||
inputs_text += f"|{input_['speaker']}|{length_param}|{input_['text']}" | ||
inputs_text += f"|1|{params_['length_generate']}|" | ||
|
||
inputs_token_ids = self.tokenizer.encode(inputs_text, return_tensors="pt") | ||
inputs_token_ids = inputs_token_ids.cuda() if cuda else inputs | ||
|
||
try: | ||
outputs_token_ids = self.model.generate( | ||
inputs_token_ids, | ||
max_length=params_["max_length"], | ||
no_repeat_ngram_size=params_["no_repeat_ngram_size"], | ||
do_sample=params_["do_sample"], | ||
top_k=params_["top_k"], | ||
top_p=params_["top_p"], | ||
temperature=params_["temperature"], | ||
num_return_sequences=params_["num_return_sequences"], | ||
device=params_["device"], | ||
mask_token_id=self.tokenizer.mask_token_id, | ||
eos_token_id=self.tokenizer.eos_token_id, | ||
unk_token_id=self.tokenizer.unk_token_id, | ||
pad_token_id=self.tokenizer.pad_token_id, | ||
) | ||
except Exception as e: | ||
logger.info(f"dialogpt Error generate: {str(e)}") | ||
return "" | ||
|
||
outputs = [self.tokenizer.decode(x, skip_special_tokens=True) for x in outputs_token_ids] | ||
outputs = [x.split("|")[-1] for x in outputs] | ||
# outputs contains list of strings of possible hypotheses | ||
return outputs | ||
|
||
|
||
try: | ||
model = RussianDialogGPT(PRETRAINED_MODEL_NAME_OR_PATH) | ||
model.model.eval() | ||
if cuda: | ||
model.model.cuda() | ||
|
||
logger.info("dialogpt model is ready") | ||
except Exception as e: | ||
sentry_sdk.capture_exception(e) | ||
logger.exception(e) | ||
raise e | ||
|
||
app = Flask(__name__) | ||
health = HealthCheck(app, "/healthcheck") | ||
logging.getLogger("werkzeug").setLevel("WARNING") | ||
|
||
|
||
@app.route("/respond", methods=["POST"]) | ||
def respond(): | ||
st_time = time.time() | ||
|
||
dialog_contexts = request.json.get("dialog_contexts", []) | ||
num_return_sequences = request.json.get("num_return_sequences", 3) | ||
|
||
try: | ||
batch_generated_responses = [] | ||
for context in dialog_contexts: | ||
# context is a list of dicts, each dict contains text and speaker label | ||
# context = [{"text": "utterance text", "speaker": "human"}, ...] | ||
inputs = [{"text": uttr["text"], "speaker": 1 if uttr["speaker"] == "bot" else 0} for uttr in context][-3:] | ||
hypotheses = model.get_responses(inputs, params={"num_return_sequences": num_return_sequences}) | ||
logger.info(f"dialogpt hypotheses: {hypotheses}") | ||
batch_generated_responses.append(hypotheses) | ||
|
||
except Exception as exc: | ||
logger.exception(exc) | ||
sentry_sdk.capture_exception(exc) | ||
batch_generated_responses = [[]] * len(dialog_contexts) | ||
|
||
total_time = time.time() - st_time | ||
logger.info(f"dialogpt exec time: {total_time:.3f}s") | ||
|
||
return jsonify({"generated_responses": batch_generated_responses}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import requests | ||
|
||
|
||
def test_respond(): | ||
url = "http://0.0.0.0:8091/respond" | ||
|
||
dialog_contexts = [ | ||
[ | ||
{"speaker": "human", "text": "Привет, как день прошел?"}, | ||
{"speaker": "bot", "text": "Хорошо, а у тебя как?"}, | ||
{"speaker": "human", "text": "Нормально, посоветуй фильм посмотреть"}, | ||
] | ||
] | ||
|
||
request_data = {"dialog_contexts": dialog_contexts} | ||
result = requests.post(url, json=request_data).json()["generated_responses"][0] | ||
|
||
assert len(result) == 3 and len(result[0]) > 0, f"Got\n{result}" | ||
print("Success!") | ||
|
||
|
||
if __name__ == "__main__": | ||
test_respond() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash | ||
|
||
python test.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,2 @@ | ||
GPU RAM = 1Gb | ||
cpu time = 0.15 sec | ||
gpu time = 0.05 sec | ||
GPU RAM = 2.1 Gb | ||
gpu time = 0.5 sec |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
FROM python:3.9.1 | ||
# ###################### IMMUTABLE SECTION ###################################### | ||
# Do not change anything in this section | ||
WORKDIR /src | ||
|
||
COPY common/dff/requirements.txt . | ||
RUN pip install -r requirements.txt | ||
|
||
# ###################### CUSTOM SECTION ###################################### | ||
# Here you can make changes | ||
|
||
ARG SERVICE_NAME | ||
ENV SERVICE_NAME ${SERVICE_NAME} | ||
|
||
COPY skills/${SERVICE_NAME}/requirements.txt . | ||
RUN pip install -r requirements.txt | ||
RUN python -m nltk.downloader wordnet | ||
|
||
COPY skills/${SERVICE_NAME}/ ./ | ||
COPY ./common/ ./common/ | ||
|
||
ARG SERVICE_PORT | ||
ENV SERVICE_PORT ${SERVICE_PORT} | ||
|
||
# wait for a server answer ( INTERVAL + TIMEOUT ) * RETRIES seconds after that change stutus to unhealthy | ||
HEALTHCHECK --interval=5s --timeout=5s --retries=3 CMD curl --fail 127.0.0.1:${SERVICE_PORT}/healthcheck || exit 1 | ||
|
||
|
||
CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} |
Oops, something went wrong.